diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 625a5a30..a62b1e70 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -133,7 +133,7 @@ void RealGarbleWire::input(party_id_t from, char input) protocol.init_mul(party.shared_proc); protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask); protocol.exchange(); - if (party.MC->POpen(protocol.finalize_mul(), *party.P) != 0) + if (party.MC->open(protocol.finalize_mul(), *party.P) != 0) throw runtime_error("input mask not a bit"); } #ifdef DEBUG_MASK @@ -168,7 +168,7 @@ void RealGarbleWire::output() auto& party = RealProgramParty::s(); assert(party.MC != 0); assert(party.P != 0); - auto m = party.MC->POpen(mask, *party.P); + auto m = party.MC->open(mask, *party.P); party.output_masks.push_back(m.get_bit(0)); party.taint(); #ifdef DEBUG_MASK diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 011552d6..ff9aef2c 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -90,7 +90,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : { mac_key.randomize(prng); if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({{{*P, true}}}); + BaseMachine::s().ot_setups.push_back({*P, true}); prep = Preprocessing::get_live_prep(0, usage); } else diff --git a/CHANGELOG.md b/CHANGELOG.md index 17a3f3d0..30f21701 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.3 (Nov 21, 2019) + +- Python 3 +- Semi-honest computation based on semi-homomorphic encryption +- Access to player information in high-level language + ## 0.1.2 (Oct 11, 2019) - Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission diff --git a/Compiler/GC/program.py b/Compiler/GC/program.py index d0953f4d..c77d4281 100644 --- a/Compiler/GC/program.py +++ b/Compiler/GC/program.py @@ -5,6 +5,6 @@ class Program(object): types.program = self instructions.program = self self.curr_tape = None - execfile(progname) + exec(compile(open(progname).read(), progname, 'exec')) def malloc(self, *args): pass diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index a015aa2b..988f0a30 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -5,6 +5,7 @@ from Compiler.exceptions import * from Compiler import util, oram, floatingpoint, library import Compiler.GC.instructions as inst import operator +from functools import reduce class bits(Tape.Register, _structure): n = 40 @@ -82,7 +83,7 @@ class bits(Tape.Register, _structure): cls.load_inst[util.is_constant(address)](res, address) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, (int, long))](self, address) + self.store_inst[isinstance(address, int)](self, address) def __init__(self, value=None, n=None, size=None): if size != 1 and size is not None: raise Exception('invalid size for bit type: %s' % size) @@ -92,11 +93,11 @@ class bits(Tape.Register, _structure): self.load_other(value) def set_length(self, n): if n > self.max_length: - print self.max_length + print(self.max_length) raise Exception('too long: %d' % n) self.n = n def load_other(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): @@ -115,6 +116,7 @@ class bits(Tape.Register, _structure): def __repr__(self): return '%s(%d/%d)' % \ (super(bits, self).__repr__(), self.n, type(self).n) + __str__ = __repr__ class cbits(bits): max_length = 64 @@ -219,13 +221,13 @@ class sbits(bits): @classmethod def load_dynamic_mem(cls, address): res = cls() - if isinstance(address, (long, int)): + if isinstance(address, int): inst.ldmsd(res, address, cls.n) else: inst.ldmsdi(res, address, cls.n) return res def store_in_dynamic_mem(self, address): - if isinstance(address, (long, int)): + if isinstance(address, int): inst.stmsd(self, address) else: inst.stmsdi(self, cbits.conv(address)) @@ -322,7 +324,7 @@ class sbits(bits): mul_bits = [self if b else zero for b in bits] return self.bit_compose(mul_bits) else: - print self.n, other + print(self.n, other) return NotImplemented def __lshift__(self, i): return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i]) @@ -478,7 +480,7 @@ class bitsBlock(oram.Block): self.start_demux = oram.demux_list(self.start_bits) self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \ for i in range(entries_per_block)] - self.mul_entries = map(operator.mul, self.start_demux, self.entries) + self.mul_entries = list(map(operator.mul, self.start_demux, self.entries)) self.bits = sum(self.mul_entries).bit_decompose() self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths)) self.anti_value = self.mul_value + self.value @@ -662,6 +664,12 @@ class sbitfix(_fix): return super(sbitfix, self).__mul__(other) __rxor__ = __xor__ __rmul__ = __mul__ + @staticmethod + def multipliable(other, k, f): + class cls(_fix): + int_type = sbitint.get_type(k) + cls.set_precision(f, k) + return cls._new(cls.int_type(other), k, f) sbitfix.set_precision(20, 41) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 1f52e36f..9a22da46 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -1,15 +1,15 @@ -import compilerLib, program, instructions, types, library, floatingpoint -import GC.types +from . import compilerLib, program, instructions, types, library, floatingpoint +from .GC import types as GC_types import inspect -from config import * -from compilerLib import run +from .config import * +from .compilerLib import run # add all instructions to the program VARS dictionary compilerLib.VARS = {} instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] -for mod in (types, GC.types): +for mod in (types, GC_types): instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ if t[1].__module__ == mod.__name__] diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 2ed6016b..f6d1f0e0 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -10,16 +10,17 @@ import Compiler.program import heapq, itertools import operator import sys +from functools import reduce class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" def __init__(self, n): - self.alloc = {} + self.alloc = dict_by_id() self.usage = Compiler.program.RegType.create_dict(lambda: 0) - self.defined = {} - self.dealloc = set() + self.defined = dict_by_id() + self.dealloc = set_by_id() self.n = n def alloc_reg(self, reg, free): @@ -77,8 +78,8 @@ class StraightlineAllocator: unused_regs.append(j) if unused_regs and len(unused_regs) == len(list(i.get_def())): # only report if all assigned registers are unused - print "Register(s) %s never used, assigned by '%s' in %s" % \ - (unused_regs,i,format_trace(i.caller)) + print("Register(s) %s never used, assigned by '%s' in %s" % \ + (unused_regs,i,format_trace(i.caller))) for j in i.get_used(): self.alloc_reg(j, alloc_pool) @@ -86,7 +87,7 @@ class StraightlineAllocator: self.dealloc_reg(j, i, alloc_pool) if k % 1000000 == 0 and k > 0: - print "Allocated registers for %d instructions at" % k, time.asctime() + print("Allocated registers for %d instructions at" % k, time.asctime()) # print "Successfully allocated registers" # print "modp usage: %d clear, %d secret" % \ @@ -97,8 +98,8 @@ class StraightlineAllocator: def determine_scope(block, options): - last_def = defaultdict(lambda: -1) - used_from_scope = set() + last_def = defaultdict_by_id(lambda: -1) + used_from_scope = set_by_id() def find_in_scope(reg, scope): while True: @@ -114,18 +115,18 @@ def determine_scope(block, options): used_from_scope.add(reg) reg.can_eliminate = False else: - print 'Warning: read before write at register', reg - print '\tline %d: %s' % (n, instr) - print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t') - print '\tregister trace: %s' % format_trace(reg.caller, '\t\t') + print('Warning: read before write at register', reg) + print('\tline %d: %s' % (n, instr)) + print('\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')) + print('\tregister trace: %s' % format_trace(reg.caller, '\t\t')) if options.stop: sys.exit(1) def write(reg, n): if last_def[reg] != -1: - print 'Warning: double write at register', reg - print '\tline %d: %s' % (n, instr) - print '\ttrace: %s' % format_trace(instr.caller, '\t\t') + print('Warning: double write at register', reg) + print('\tline %d: %s' % (n, instr)) + print('\ttrace: %s' % format_trace(instr.caller, '\t\t')) if options.stop: sys.exit(1) last_def[reg] = n @@ -146,7 +147,7 @@ def determine_scope(block, options): write(reg, n) block.used_from_scope = used_from_scope - block.defined_registers = set(last_def.iterkeys()) + block.defined_registers = set_by_id(last_def.keys()) class Merger: def __init__(self, block, options, merge_classes): @@ -178,7 +179,7 @@ class Merger: if inst.is_vec(): for arg in inst.args: arg.create_vector_elements() - res = sum(zip(*inst.args), ()) + res = sum(list(zip(*inst.args)), ()) return list(res) else: return inst.args @@ -241,7 +242,7 @@ class Merger: remaining_input_nodes = [] def do_merge(nodes): if len(nodes) > 1000: - print 'Merging %d inputs...' % len(nodes) + print('Merging %d inputs...' % len(nodes)) self.do_merge(iter(nodes)) for n in self.input_nodes: inst = self.instructions[n] @@ -252,7 +253,7 @@ class Merger: if len(merge) >= self.max_parallel_open: do_merge(merge) merge[:] = [] - for merge in reversed(sorted(merges.itervalues())): + for merge in reversed(sorted(merges.values())): if merge: do_merge(merge) self.input_nodes = remaining_input_nodes @@ -266,7 +267,7 @@ class Merger: instructions = self.instructions flex_nodes = defaultdict(dict) starters = [] - for n in xrange(len(G)): + for n in range(len(G)): if n not in merge_nodes_set and \ depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction): #print n, depth_of[n], rev_depth_of[n] @@ -275,19 +276,19 @@ class Merger: not isinstance(self.instructions[n], RawInputInstruction): starters.append(n) if n % 10000000 == 0 and n > 0: - print "Processed %d nodes at" % n, time.asctime() + print("Processed %d nodes at" % n, time.asctime()) inputs = defaultdict(list) for node in self.input_nodes: player = self.instructions[node].args[0] inputs[player].append(node) - first_inputs = [l[0] for l in inputs.itervalues()] + first_inputs = [l[0] for l in inputs.values()] other_inputs = [] i = 0 while True: i += 1 found = False - for l in inputs.itervalues(): + for l in inputs.values(): if i < len(l): other_inputs.append(l[i]) found = True @@ -299,20 +300,20 @@ class Merger: # magical preorder for topological search max_depth = max(merges) if max_depth > 10000: - print "Computing pre-ordering ..." - for i in xrange(max_depth, 0, -1): + print("Computing pre-ordering ...") + for i in range(max_depth, 0, -1): preorder.append(G.get_attr(merges[i], 'stop')) - for j in flex_nodes[i-1].itervalues(): + for j in flex_nodes[i-1].values(): preorder.extend(j) preorder.extend(flex_nodes[0].get(i, [])) preorder.append(merges[i]) if i % 100000 == 0 and i > 0: - print "Done level %d at" % i, time.asctime() + print("Done level %d at" % i, time.asctime()) preorder.extend(other_inputs) preorder.extend(starters) preorder.extend(first_inputs) if max_depth > 10000: - print "Done at", time.asctime() + print("Done at", time.asctime()) return preorder def longest_paths_merge(self): @@ -343,8 +344,8 @@ class Merger: t = type(self.instructions[merge[0]]) self.counter[t] += len(merge) if len(merge) > 1000: - print 'Merging %d %s in round %d/%d' % \ - (len(merge), t.__name__, i, len(merges)) + print('Merging %d %s in round %d/%d' % \ + (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) self.merge_inputs() @@ -352,11 +353,11 @@ class Merger: preorder = None if len(instructions) > 100000: - print "Topological sort ..." + print("Topological sort ...") order = Compiler.graph.topological_sort(G, preorder) instructions[:] = [instructions[i] for i in order if instructions[i] is not None] if len(instructions) > 100000: - print "Done at", time.asctime() + print("Done at", time.asctime()) return len(merges) @@ -377,7 +378,7 @@ class Merger: self.G = G reg_nodes = {} - last_def = defaultdict(lambda: -1) + last_def = defaultdict_by_id(lambda: -1) last_mem_write = [] last_mem_read = [] warned_about_mem = [] @@ -411,8 +412,8 @@ class Merger: def handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind): - this = last_access_this_kind[addr,reg_type] - other = last_access_other_kind[addr,reg_type] + this = last_access_this_kind[str(addr),reg_type] + other = last_access_other_kind[str(addr),reg_type] if this and other: if this[-1] < other[0]: del this[:] @@ -429,15 +430,15 @@ class Merger: handle_mem_access(addr_i, reg_type, last_access_this_kind, last_access_other_kind) if not warned_about_mem and (instr.get_size() > 100): - print 'WARNING: Order of memory instructions ' \ - 'not preserved due to long vector, errors possible' + print('WARNING: Order of memory instructions ' \ + 'not preserved due to long vector, errors possible') warned_about_mem.append(True) else: handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind) if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction): - print 'WARNING: Order of memory instructions ' \ - 'not preserved, errors possible' + print('WARNING: Order of memory instructions ' \ + 'not preserved, errors possible') # hack warned_about_mem.append(True) @@ -553,11 +554,11 @@ class Merger: self.sources.append(n) if n % 100000 == 0 and n > 0: - print "Processed dependency of %d/%d instructions at" % \ - (n, len(block.instructions)), time.asctime() + print("Processed dependency of %d/%d instructions at" % \ + (n, len(block.instructions)), time.asctime()) if len(open_nodes) > 1000: - print "Program has %d %s instructions" % (len(open_nodes), merge_classes) + print("Program has %d %s instructions" % (len(open_nodes), merge_classes)) def merge_nodes(self, i, j): """ Merge node j into i, removing node j """ @@ -566,8 +567,8 @@ class Merger: G.remove_edge(i, j) if i in G[j]: G.remove_edge(j, i) - G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]])) - G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]])) + G.add_edges_from(list(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))) + G.add_edges_from(list(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))) G.get_attr(i, 'merges').append(j) G.remove_node(j) @@ -578,7 +579,7 @@ class Merger: count = 0 open_count = 0 stats = defaultdict(lambda: 0) - for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)): + for i,inst in zip(range(len(instructions) - 1, -1, -1), reversed(instructions)): # remove if instruction has result that isn't used unused_result = not G.degree(i) and len(list(inst.get_def())) \ and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \ @@ -608,21 +609,21 @@ class Merger: eliminate(i) count += 2 if count > 0: - print 'Eliminated %d dead instructions, among which %d opens: %s' \ - % (count, open_count, dict(stats)) + print('Eliminated %d dead instructions, among which %d opens: %s' \ + % (count, open_count, dict(stats))) def print_graph(self, filename): f = open(filename, 'w') - print >>f, 'digraph G {' + print('digraph G {', file=f) for i in range(self.G.n): for j in self.G[i]: - print >>f, '"%d: %s" -> "%d: %s";' % \ - (i, self.instructions[i], j, self.instructions[j]) - print >>f, '}' + print('"%d: %s" -> "%d: %s";' % \ + (i, self.instructions[i], j, self.instructions[j]), file=f) + print('}', file=f) f.close() def print_depth(self, filename): f = open(filename, 'w') for i in range(self.G.n): - print >>f, '%d: %s' % (self.depths[i], self.instructions[i]) + print('%d: %s' % (self.depths[i], self.instructions[i]), file=f) f.close() diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py index 3f2539c9..30c0d7b8 100644 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -26,7 +26,7 @@ def find_deeper(a, b, path, start, length, compute_level=True): any_empty = OR(a.empty, b.empty) a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)] b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)] - diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]] + diff = [XOR(ab, bb) for ab,bb in list(zip(a_bits, b_bits))[start:length]] diff_preor = type(a.value).PreOR([any_empty] + diff) diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])] winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty) @@ -38,7 +38,7 @@ def find_deeper(a, b, path, start, length, compute_level=True): def find_deepest(paths, search_path, start, length, compute_level=True): if len(paths) == 1: return None, paths[0], 1 - l = len(paths) / 2 + l = len(paths) // 2 _, a, a_index = find_deepest(paths[:l], search_path, start, length, False) _, b, b_index = find_deepest(paths[l:], search_path, start, length, False) level, winner = find_deeper(a, b, search_path, start, length, compute_level) @@ -57,7 +57,7 @@ def greater_unary(a, b): if len(a) == 1: return a[0], b[0] else: - l = len(a) / 2 + l = len(a) // 2 return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l])) def comp_step(high, low): @@ -75,7 +75,7 @@ def comp_binary(a, b): if len(a) == 1: return a[0], b[0] else: - l = len(a) / 2 + l = len(a) // 2 return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l])) def unary_to_binary(l): @@ -89,8 +89,8 @@ class CircuitORAM(PathORAM): self.D = log2(size) self.logD = log2(self.D) self.L = self.D + 1 - print 'create oram of size %d with depth %d and %d buckets' \ - % (size, self.D, self.n_buckets()) + print('create oram of size %d with depth %d and %d buckets' \ + % (size, self.D, self.n_buckets())) self.value_type = value_type self.index_type = value_type.get_type(self.D) if entry_size is not None: @@ -245,7 +245,7 @@ class CircuitORAM(PathORAM): for i,_ in enumerate(self.recursive_evict_rounds()): Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size)) def recursive_evict_rounds(self): - for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()): + for _ in zip(self.evict_rounds(), self.index.l.recursive_evict_rounds()): yield def bucket_indices_on_path_to(self, leaf): # root is at 1, different to PathORAM @@ -272,10 +272,10 @@ threshold = 2**10 def OptimalCircuitORAM(size, value_type, *args, **kwargs): if size <= threshold: - print size, 'below threshold', threshold + print(size, 'below threshold', threshold) return LinearORAM(size, value_type, *args, **kwargs) else: - print size, 'above threshold', threshold + print(size, 'above threshold', threshold) return RecursiveCircuitORAM(size, value_type, *args, **kwargs) class RecursiveCircuitIndexStructure(PackedIndexStructure): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 80d54c6a..e9cb21da 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -28,8 +28,8 @@ use_inv = True # (r[i], r[i]^-1, r[i] * r[i-1]^-1) do_precomp = True -import instructions_base -import util +from . import instructions_base +from . import util def set_variant(options): """ Set flags based on the command-line option provided """ @@ -55,7 +55,7 @@ def ld2i(c, n): """ Load immediate 2^n into clear GF(p) register c """ t1 = program.curr_block.new_reg('c') ldi(t1, 2 ** (n % 30)) - for i in range(n / 30): + for i in range(n // 30): t2 = program.curr_block.new_reg('c') mulci(t2, t1, 2 ** 30) t1 = t2 @@ -75,13 +75,13 @@ def LTZ(s, a, k, kappa): k: bit length of a """ - from types import sint + from .types import sint t = sint() Trunc(t, a, k, k - 1, kappa, True) subsfi(s, t, 0) def LessThanZero(a, k, kappa): - import types + from . import types res = types.sint() LTZ(res, a, k, kappa) return res @@ -124,7 +124,7 @@ def TruncZeroes(a, k, m, signed): if program.options.ring: return TruncLeakyInRing(a, k, m, signed) else: - import types + from . import types tmp = types.cint() inv2m(tmp, m) return a * tmp @@ -136,7 +136,7 @@ def TruncLeakyInRing(a, k, m, signed): """ assert k > m assert int(program.options.ring) >= k - from types import sint, intbitint, cint, cgf2n + from .types import sint, intbitint, cint, cgf2n n_bits = k - m n_shift = int(program.options.ring) - n_bits r_bits = [sint.get_random_bit() for i in range(n_bits)] @@ -165,7 +165,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): # cannot work with bit length k+1 tmp = TruncRing(None, a, k, m - 1, signed) return TruncRing(None, tmp + 1, k - m + 1, 1, signed) - from types import sint + from .types import sint res = sint() Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed) return res @@ -277,7 +277,7 @@ def BitLTC1(u, a, b, kappa): """ k = len(b) p = [program.curr_block.new_reg('s') for i in range(k)] - import floatingpoint + from . import floatingpoint a_bits = floatingpoint.bits(a, k) if instructions_base.get_global_vector_size() == 1: a_ = a_bits @@ -357,12 +357,12 @@ def CarryOutAux(d, a, kappa): if k > 1 and k % 2 == 1: a.append(None) k += 1 - u = [None]*(k/2) + u = [None]*(k//2) a = a[::-1] if k > 1: - for i in range(k/2): - u[i] = carry(a[2*i+1], a[2*i], i != k/2-1) - CarryOutAux(d, u[:k/2][::-1], kappa) + for i in range(k//2): + u[i] = carry(a[2*i+1], a[2*i], i != k//2-1) + CarryOutAux(d, u[:k//2][::-1], kappa) else: movs(d, a[0][1]) @@ -376,7 +376,7 @@ def CarryOut(res, a, b, c=0, kappa=None): c: initial carry-in bit """ k = len(a) - import types + from . import types d = [program.curr_block.new_reg('s') for i in range(k)] t = [[types.sint() for i in range(k)] for i in range(4)] s = [program.curr_block.new_reg('s') for i in range(3)] @@ -394,7 +394,7 @@ def CarryOut(res, a, b, c=0, kappa=None): def CarryOutLE(a, b, c=0): """ Little-endian version """ - import types + from . import types res = types.sint() CarryOut(res, a[::-1], b[::-1], c) return res @@ -407,7 +407,7 @@ def BitLTL(res, a, b, kappa): b: array of secret bits (same length as a) """ k = len(b) - import floatingpoint + from . import floatingpoint a_bits = floatingpoint.bits(a, k) s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)] t = [program.curr_block.new_reg('s') for i in range(1)] @@ -547,7 +547,7 @@ def KMulC(a): """ Return just the product of all items in a """ - from types import sint, cint + from .types import sint, cint p = sint() if use_inv: PreMulC_with_inverses(p, a) @@ -582,7 +582,7 @@ def Mod2(a_0, a, k, kappa, signed): adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) asm_open(c, t[3]) - import floatingpoint + from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) mulm(t[4], r_0, tc) @@ -591,4 +591,4 @@ def Mod2(a_0, a, k, kappa, signed): # hack for circular dependency -from instructions import * +from .instructions import * diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index a2773a64..7a9fcbe9 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,8 +1,8 @@ from Compiler.program import Program from Compiler.config import * from Compiler.exceptions import * -import instructions, instructions_base, types, comparison, library -import GC.types +from . import instructions, instructions_base, types, comparison, library +from .GC import types as GC_types import random import time @@ -25,36 +25,36 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ prog.DEBUG = debug VARS['program'] = prog if options.binary: - VARS['sint'] = GC.types.sbitint.get_type(int(options.binary)) - VARS['sfix'] = GC.types.sbitfix + VARS['sint'] = GC_types.sbitint.get_type(int(options.binary)) + VARS['sfix'] = GC_types.sbitfix comparison.set_variant(options) - print 'Compiling file', prog.infile + print('Compiling file', prog.infile) # no longer needed, but may want to support assembly in future (?) if assemblymode: prog.restart_main_thread() - for i in xrange(INIT_REG_MAX): + for i in range(INIT_REG_MAX): VARS['c%d'%i] = prog.curr_block.new_reg('c') VARS['s%d'%i] = prog.curr_block.new_reg('s') VARS['cg%d'%i] = prog.curr_block.new_reg('cg') VARS['sg%d'%i] = prog.curr_block.new_reg('sg') if i % 10000000 == 0 and i > 0: - print "Initialized %d register variables at" % i, time.asctime() + print("Initialized %d register variables at" % i, time.asctime()) # first pass determines how many assembler registers are used prog.FIRST_PASS = True - execfile(prog.infile, VARS) + exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) if instructions_base.Instruction.count != 0: - print 'instructions count', instructions_base.Instruction.count + print('instructions count', instructions_base.Instruction.count) instructions_base.Instruction.count = 0 prog.FIRST_PASS = False prog.reset_values() # make compiler modules directly accessible sys.path.insert(0, 'Compiler') # create the tapes - execfile(prog.infile, VARS) + exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) # optimize the tapes for tape in prog.tapes: @@ -66,14 +66,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ sharedmem = list(prog.mem_s) prog.emulate() if prog.mem_c != clearmem or prog.mem_s != sharedmem: - print 'Warning: emulated memory values changed after compiler optimization' + print('Warning: emulated memory values changed after compiler optimization') # raise CompilerError('Compiler optimization caused incorrect memory write.') if prog.main_thread_running: prog.update_req(prog.curr_tape) - print 'Program requires:', repr(prog.req_num) - print 'Cost:', 0 if prog.req_num is None else prog.req_num.cost() - print 'Memory size:', dict(prog.allocated_mem) + print('Program requires:', repr(prog.req_num)) + print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) + print('Memory size:', dict(prog.allocated_mem)) # finalize the memory prog.finalize_memory() diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 1011cd58..61a65247 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -86,8 +86,8 @@ class HeapQ(object): self.size = MemValue(int_type(0)) self.int_type = int_type self.basic_type = basic_type - print 'heap: %d levels, depth %d, size %d, index size %d' % \ - (self.levels, self.depth, self.heap.oram.size, self.value_index.size) + print('heap: %d levels, depth %d, size %d, index size %d' % \ + (self.levels, self.depth, self.heap.oram.size, self.value_index.size)) def update(self, value, prio, for_real=True): self._update(self.basic_type.hard_conv(value), \ self.basic_type.hard_conv(prio), \ @@ -217,7 +217,7 @@ class HeapQ(object): def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): basic_type = int_type.basic_type - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ init_rounds=vert_loops, value_type=basic_type) @@ -287,7 +287,7 @@ def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint): cint(i).print_reg('edge') time() edges[i] = edges_list[i] - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else e_index.size for i in range(vert_loops): cint(i).print_reg('vert') @@ -307,7 +307,7 @@ def test_dijkstra_on_cycle(n, oram_type=ORAM, n_loops=None, int_type=sint): time() neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n edges[i] = (neighbour, 1, i % 2) - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else e_index.size @for_range(vert_loops) def f(i): @@ -390,14 +390,14 @@ class ExtInt(object): class Vector(object): """ Works like a vector. """ def __add__(self, other): - print 'add', type(self) + print('add', type(self)) res = type(self)(len(self)) @for_range(len(self)) def f(i): res[i] = self[i] + other[i] return res def __sub__(self, other): - print 'sub', type(other) + print('sub', type(other)) res = type(other)(len(self)) @for_range(len(self)) def f(i): @@ -412,7 +412,7 @@ class Vector(object): res[0] += self[i] * other[i] return res[0] else: - print 'mul', type(self) + print('mul', type(self)) res = type(self)(len(self)) @for_range_parallel(1024, len(self)) def f(i): @@ -477,7 +477,7 @@ def binarymin(A): if len(A) == 1: return [1], A[0] else: - half = len(A) / 2 + half = len(A) // 2 A_prime = VectorArray(half) B = IntVectorArray(half) i = IntVectorArray(len(A)) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index e673fe61..2a51ebd4 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -1,9 +1,9 @@ from math import log, floor, ceil from Compiler.instructions import * -import types -import comparison -import program -import util +from . import types +from . import comparison +from . import program +from . import util ## ## Helper functions for floating point arithmetic @@ -16,7 +16,7 @@ def two_power(n): else: max = types.cint(1) << 31 res = 2**(n%31) - for i in range(n / 31): + for i in range(n // 31): res *= max return res @@ -25,7 +25,7 @@ def shift_two(n, pos): return n >> pos else: res = (n >> (pos%63)) - for i in range(pos / 63): + for i in range(pos // 63): res >>= 63 return res @@ -139,7 +139,7 @@ def PreOpL(op, items): kmax = 2**logk output = list(items) for i in range(logk): - for j in range(kmax/(2**(i+1))): + for j in range(kmax//(2**(i+1))): y = two_power(i) + j*two_power(i+1) - 1 for z in range(1, 2**i+1): if y+z < k: @@ -153,7 +153,7 @@ def PreOpL2(op, items): op must be a binary function that outputs a new register """ k = len(items) - half = k / 2 + half = k // 2 output = list(items) if k == 0: return [] @@ -161,7 +161,7 @@ def PreOpL2(op, items): v = PreOpL2(op, u) for i in range(half): output[2 * i + 1] = v[i] - for i in range(1, (k + 1) / 2): + for i in range(1, (k + 1) // 2): output[2 * i] = op(v[i - 1], items[2 * i]) return output @@ -185,8 +185,8 @@ def KOpL(op, a): if k == 1: return a[0] else: - t1 = KOpL(op, a[:k/2]) - t2 = KOpL(op, a[k/2:]) + t1 = KOpL(op, a[:k//2]) + t2 = KOpL(op, a[k//2:]) return op(t1, t2) def KORL(a, kappa): @@ -195,8 +195,8 @@ def KORL(a, kappa): if k == 1: return a[0] else: - t1 = KORL(a[:k/2], kappa) - t2 = KORL(a[k/2:], kappa) + t1 = KORL(a[:k//2], kappa) + t2 = KORL(a[k//2:], kappa) return t1 + t2 - t1*t2 def KORC(a, kappa): @@ -234,7 +234,7 @@ def BitAdd(a, b, bits_to_compute=None): bits s[0], ... , s[k] """ k = len(a) if not bits_to_compute: - bits_to_compute = range(k) + bits_to_compute = list(range(k)) d = [None] * k for i in range(1,k): #assert(a[i].value == 0 or a[i].value == 1) @@ -248,25 +248,25 @@ def BitAdd(a, b, bits_to_compute=None): # (for testing) def print_state(): - print 'a: ', + print('a: ', end=' ') for i in range(k): - print '%d ' % a[i].value, - print '\nb: ', + print('%d ' % a[i].value, end=' ') + print('\nb: ', end=' ') for i in range(k): - print '%d ' % b[i].value, - print '\nd: ', + print('%d ' % b[i].value, end=' ') + print('\nd: ', end=' ') for i in range(k): - print '%d ' % d[i][0].value, - print '\n ', + print('%d ' % d[i][0].value, end=' ') + print('\n ', end=' ') for i in range(k): - print '%d ' % d[i][1].value, - print '\n\npg:', + print('%d ' % d[i][1].value, end=' ') + print('\n\npg:', end=' ') for i in range(k): - print '%d ' % pg[i][0].value, - print '\n ', + print('%d ' % pg[i][0].value, end=' ') + print('\n ', end=' ') for i in range(k): - print '%d ' % pg[i][1].value, - print '' + print('%d ' % pg[i][1].value, end=' ') + print('') for bit in c: pass#assert(bit.value == 0 or bit.value == 1) @@ -281,7 +281,7 @@ def BitAdd(a, b, bits_to_compute=None): try: pass#assert(s[i].value == 0 or s[i].value == 1) except AssertionError: - print '#assertion failed in BitAdd for s[%d]' % i + print('#assertion failed in BitAdd for s[%d]' % i) print_state() s[k] = c[k-1] #print_state() @@ -316,9 +316,9 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): try: pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P) except AssertionError: - print 'BitDec assertion failed' - print 'a =', a.value - print 'a mod 2^%d =' % k, (a.value % 2**k) + print('BitDec assertion failed') + print('a =', a.value) + print('a mod 2^%d =' % k, (a.value % 2**k)) return types.intbitint.bit_adder(list(bits(c,m)), r) @@ -503,7 +503,7 @@ def TruncPrRing(a, k, m, signed=True): a += types.sint.get_random_bit() << i return comparison.TruncLeakyInRing(a, k, m, signed=signed) else: - from types import sint + from .types import sint if signed: a += (1 << (k - 1)) if program.Program.prog.use_trunc_pr: diff --git a/Compiler/graph.py b/Compiler/graph.py index 9ce87be7..43bea6ba 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -19,10 +19,10 @@ class SparseDiGraph(object): if default_attributes is None: default_attributes = { 'merges': None, 'stop': -1, 'start': -1 } self.default_attributes = default_attributes - self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes)))) + self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes)))))) self.n = max_nodes # each node contains list of default attributes, followed by outoing edges - self.nodes = [self.default_attributes.values() for i in range(self.n)] + self.nodes = [list(self.default_attributes.values()) for i in range(self.n)] self.succ = [set() for i in range(self.n)] self.pred = [set() for i in range(self.n)] self.weights = {} @@ -45,7 +45,7 @@ class SparseDiGraph(object): raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n)) node = self.nodes[i] - for a,value in attr.items(): + for a,value in list(attr.items()): if a in self.default_attributes: node[self.attribute_pos[a]] = value else: @@ -72,7 +72,7 @@ class SparseDiGraph(object): #del self.weights[(v,i)] #self.nodes[v].remove(i) self.pred[i] = [] - self.nodes[i] = self.default_attributes.values() + self.nodes[i] = list(self.default_attributes.values()) def add_edge(self, i, j, weight=1): if j not in self[i]: @@ -111,7 +111,7 @@ def topological_sort(G, nbunch=None, pref=None): return G[node] else: def get_children(node): - if pref.has_key(node): + if node in pref: pref_set = set(pref[node]) for i in G[node]: if i not in pref_set: @@ -123,7 +123,7 @@ def topological_sort(G, nbunch=None, pref=None): yield i if nbunch is None: - nbunch = reversed(range(len(G))) + nbunch = reversed(list(range(len(G)))) for v in nbunch: # process all vertices in G if v in explored: continue @@ -170,8 +170,8 @@ def reverse_dag_shortest_paths(G, source): dist[source] = 0 for u in top_order: if u ==68273: - print 'dist[68273]', dist[u] - print 'pred[u]', G.pred[u] + print('dist[68273]', dist[u]) + print('pred[u]', G.pred[u]) if dist[u] is None: continue for v in G.pred[u]: @@ -207,7 +207,7 @@ def longest_paths(G, sources=None): G.weights[edge] = -G.weights[edge] dist = {} for source in sources: - print ('%s, ' % source), + print(('%s, ' % source), end=' ') dist[source] = dag_shortest_paths(G, source) #dist = johnson(G, sources) # reset weights diff --git a/Compiler/gs.py b/Compiler/gs.py index b7ad8ee6..ab3b958f 100644 --- a/Compiler/gs.py +++ b/Compiler/gs.py @@ -5,8 +5,8 @@ from Compiler import types from Compiler.util import * -from oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry -from library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str +from .oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry +from .library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str class OMatrixRow(object): def __init__(self, oram, base, add_type): @@ -27,7 +27,7 @@ class OMatrixRow(object): class OMatrix: def __init__(self, N, M=None, oram_type=OptimalORAM, int_type=types.sint): - print 'matrix', oram_type + print('matrix', oram_type) self.N = N self.M = M or N self.oram = oram_type(N * self.M, entry_size=log2(N), init_rounds=0, \ @@ -73,7 +73,7 @@ class OReverseMatrix(OMatrix): class OStack: def __init__(self, N, oram_type=OptimalORAM, int_type=types.sint): - print 'stack', oram_type + print('stack', oram_type) self.oram = oram_type(N, entry_size=log2(N), init_rounds=0, \ value_type=int_type.basic_type) self.size = types.MemValue(int_type(0)) @@ -247,4 +247,4 @@ class Matchmaker: self.reverse = reverse self.int_type = int_type self.basic_type = int_type.basic_type - print 'match', self.oram_type + print('match', self.oram_type) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 12f4f33e..2209cfac 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -11,7 +11,7 @@ documentation """ import itertools -import tools +from . import tools from random import randint from Compiler.config import * from Compiler.exceptions import * @@ -51,7 +51,7 @@ class ldsi(base.Instruction): @base.vectorize class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMC'] arg_format = ['cw','int'] @@ -62,7 +62,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @base.vectorize class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $s_i$ the value in memory \verb+S[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMS'] arg_format = ['sw','int'] @@ -73,7 +73,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @base.vectorize class stmc(base.DirectMemoryWriteInstruction): r""" Sets \verb+C[n]+ to be the value $c_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMC'] arg_format = ['c','int'] @@ -84,7 +84,7 @@ class stmc(base.DirectMemoryWriteInstruction): @base.vectorize class stms(base.DirectMemoryWriteInstruction): r""" Sets \verb+S[n]+ to be the value $s_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMS'] arg_format = ['s','int'] @@ -94,7 +94,7 @@ class stms(base.DirectMemoryWriteInstruction): @base.vectorize class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMINT'] arg_format = ['ciw','int'] @@ -104,7 +104,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): @base.vectorize class stmint(base.DirectMemoryWriteInstruction): r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMINT'] arg_format = ['ci','int'] @@ -227,7 +227,7 @@ class protectmemint(base.Instruction): @base.vectorize class movc(base.Instruction): r""" Assigns register $c_i$ the value in the register $c_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVC'] arg_format = ['cw','c'] @@ -238,7 +238,7 @@ class movc(base.Instruction): @base.vectorize class movs(base.Instruction): r""" Assigns register $s_i$ the value in the register $s_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVS'] arg_format = ['sw','s'] @@ -248,7 +248,7 @@ class movs(base.Instruction): @base.vectorize class movint(base.Instruction): r""" Assigns register $ci_i$ the value in the register $ci_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVINT'] arg_format = ['ciw','ci'] @@ -347,6 +347,21 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] +class nplayers(base.Instruction): + r""" Number of players """ + code = base.opcodes['NPLAYERS'] + arg_format = ['ciw'] + +class threshold(base.Instruction): + r""" Maximal number of corrupt players """ + code = base.opcodes['THRESHOLD'] + arg_format = ['ciw'] + +class playerid(base.Instruction): + r""" My player number """ + code = base.opcodes['PLAYERID'] + arg_format = ['ciw'] + ### ### Basic arithmetic ### @@ -738,7 +753,7 @@ class shrci(base.ClearShiftInstruction): class triple(base.DataInstruction): r""" Load secret variables $s_i$, $s_j$ and $s_k$ with the next multiplication triple. """ - __slots__ = ['data_type'] + __slots__ = [] code = base.opcodes['TRIPLE'] arg_format = ['sw','sw','sw'] data_type = 'triple' @@ -752,7 +767,7 @@ class triple(base.DataInstruction): class gbittriple(base.DataInstruction): r""" Load secret variables $s_i$, $s_j$ and $s_k$ with the next GF(2) multiplication triple. """ - __slots__ = ['data_type'] + __slots__ = [] code = base.opcodes['GBITTRIPLE'] arg_format = ['sgw','sgw','sgw'] data_type = 'bittriple' @@ -1400,7 +1415,8 @@ class convmodp(base.Instruction): bitlength = program.bit_length if bitlength is None else bitlength if bitlength > 64: raise CompilerError('%d-bit conversion requested ' \ - 'but integer registers only have 64 bits') + 'but integer registers only have 64 bits' % \ + bitlength) super(convmodp_class, self).__init__(*(args + (bitlength,))) @base.vectorize @@ -1433,7 +1449,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction): data_type = 'triple' def get_repeat(self): - return len(self.args) / 3 + return len(self.args) // 3 def merge_id(self): # can merge different sizes @@ -1508,6 +1524,8 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction): for j in range(self.args[i] - 2): yield 's' + field + gf2n_arg_format = arg_format + def bases(self): i = 0 while i < len(self.args): @@ -1515,7 +1533,7 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction): i += self.args[i] def get_repeat(self): - return sum(self.args[i] / 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() def get_def(self): return [self.args[i + 1] for i in self.bases()] @@ -1567,7 +1585,7 @@ class lts(base.CISC): arg_format = ['sw', 's', 's', 'int', 'int'] def expand(self): - from types import sint + from .types import sint a = sint() subs(a, self.args[1], self.args[2]) comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d4baa25d..8683c305 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -56,6 +56,9 @@ opcodes = dict( USE_PREP = 0x1C, STARTGRIND = 0x1D, STOPGRIND = 0x1E, + NPLAYERS = 0xE2, + THRESHOLD = 0xE3, + PLAYERID = 0xE4, # Addition ADDC = 0x20, ADDS = 0x21, @@ -277,7 +280,7 @@ def gf2n(instruction): vectorized GF_2^n instruction if a modp version exists. """ global_dict = inspect.getmodule(instruction).__dict__ - if global_dict.has_key('v' + instruction.__name__): + if 'v' + instruction.__name__ in global_dict: vectorized = True else: vectorized = False @@ -316,7 +319,7 @@ def gf2n(instruction): if 'gf2n_arg_format' in instruction_cls.__dict__: arg_format = instruction_cls.gf2n_arg_format elif isinstance(instruction_cls.arg_format, itertools.repeat): - __f = instruction_cls.arg_format.next() + __f = next(instruction_cls.arg_format) if __f != 'int' and __f != 'p': arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) else: @@ -420,7 +423,7 @@ class ClearIntAF(RegisterArgFormat): class IntArgFormat(ArgFormat): @classmethod def check(cls, arg): - if not isinstance(arg, (int, long)): + if not isinstance(arg, int): raise ArgumentError(arg, 'Expected an integer-valued argument') @classmethod @@ -512,7 +515,7 @@ class Instruction(object): Instruction.count += 1 if Instruction.count % 100000 == 0: - print "Compiled %d lines at" % self.__class__.count, time.asctime() + print("Compiled %d lines at" % self.__class__.count, time.asctime()) def get_code(self): return self.code @@ -535,7 +538,7 @@ class Instruction(object): def check_args(self): """ Check the args match up with that specified in arg_format """ - for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)): + for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)): if arg is None: if not isinstance(self.arg_format, (list, tuple)): break # end of optional arguments @@ -758,8 +761,6 @@ class ClearShiftInstruction(ClearImmediate): else: # assume 64-bit machine bits = 63 - elif program.options.ring: - bits = int(program.options.ring) - 1 if self.args[2] > bits: raise CompilerError('Shifting by more than %d bits ' 'not implemented' % bits) diff --git a/Compiler/library.py b/Compiler/library.py index 25cf861e..5314ee94 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1,4 +1,4 @@ -from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single +from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint from Compiler.instructions import * from Compiler.util import tuplify,untuplify from Compiler import instructions,instructions_base,comparison,program,util @@ -6,6 +6,7 @@ import inspect,math import random import collections import operator +from functools import reduce def get_program(): return instructions.program @@ -101,6 +102,8 @@ def print_ln_if(cond, ss, *args): else: subs = ss.split('%s') assert len(subs) == len(args) + 1 + if isinstance(cond, localint): + cond = cond._v cond = cint.conv(cond) for i, s in enumerate(subs): if i != 0: @@ -274,7 +277,7 @@ class FunctionTapeCall: def join(self): self.thread.join() instructions.program.free(self.base, 'ci') - for reg_type,addr in self.bases.iteritems(): + for reg_type,addr in self.bases.items(): get_program().free(addr, reg_type.reg_type) class Function: @@ -287,7 +290,7 @@ class Function: self.compile_args = compile_args def __call__(self, *args): args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) - get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x) + get_reg_type = lambda x: regint if isinstance(x, int) else type(x) if len(args) not in self.type_args: # first call type_args = collections.defaultdict(list) @@ -296,9 +299,11 @@ class Function: def wrapped_function(*compile_args): base = get_arg() bases = dict((t, regint.load_mem(base + i)) \ - for i,t in enumerate(sorted(type_args))) + for i,t in enumerate(sorted(type_args, + key=lambda x: + x.reg_type))) runtime_args = [None] * len(args) - for t in sorted(type_args): + for t in sorted(type_args, key=lambda x: x.reg_type): for i,i_arg in enumerate(type_args[t]): runtime_args[i_arg] = t.load_mem(bases[t] + i) return self.function(*(list(compile_args) + runtime_args)) @@ -308,7 +313,8 @@ class Function: base = instructions.program.malloc(len(type_args), 'ci') bases = dict((t, get_program().malloc(len(type_args[t]), t)) \ for t in type_args) - for i,reg_type in enumerate(sorted(type_args)): + for i,reg_type in enumerate(sorted(type_args, + key=lambda x: x.reg_type)): store_in_mem(bases[reg_type], base + i) for j,i_arg in enumerate(type_args[reg_type]): if get_reg_type(args[i_arg]) != reg_type: @@ -353,13 +359,13 @@ class FunctionBlock(Function): block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node - print 'Compiling function', self.name + print('Compiling function', self.name) result = wrapped_function(*self.compile_args) if result is not None: self.result = memorize(result) else: self.result = None - print 'Done compiling function', self.name + print('Done compiling function', self.name) p_return_address = get_tape().program.malloc(1, 'ci') get_tape().function_basicblocks[block] = p_return_address return_address = regint.load_mem(p_return_address) @@ -429,7 +435,7 @@ def sort(a): res = a for i in range(len(a)): - for j in reversed(range(i)): + for j in reversed(list(range(i))): res[j], res[j+1] = cond_swap(res[j], res[j+1]) return res @@ -443,7 +449,7 @@ def odd_even_merge(a): odd_even_merge(even) odd_even_merge(odd) a[0] = even[0] - for i in range(1, len(a) / 2): + for i in range(1, len(a) // 2): a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i]) a[-1] = odd[-1] @@ -451,8 +457,8 @@ def odd_even_merge_sort(a): if len(a) == 1: return elif len(a) % 2 == 0: - lower = a[:len(a)/2] - upper = a[len(a)/2:] + lower = a[:len(a)//2] + upper = a[len(a)//2:] odd_even_merge_sort(lower) odd_even_merge_sort(upper) a[:] = lower + upper @@ -472,10 +478,10 @@ def chunky_odd_even_merge_sort(a): def round(): for i in range(len(a)): a[i] = type(a[i]).load_mem(i * a[i].sizeof()) - for i in range(len(a) / l): - for j in range(l / k): + for i in range(len(a) // l): + for j in range(l // k): base = i * l + j - step = l / k + step = l // k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: @@ -514,7 +520,7 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use def run_chunk(size, base): if size not in chunks: def swap_list(list_base): - for i in range(size / 2): + for i in range(size // 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) @@ -526,8 +532,8 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size / n_chunks / 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + lower_size = size // n_chunks // 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size @@ -603,10 +609,10 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use k *= 2 size = 0 instructions.program.curr_tape.merge_opens = False - for i in range(n / l): - for j in range(l / k): + for i in range(n // l): + for j in range(l // k): base = i * l + j - step = l / k + step = l // k size += run_setup(k, a_base + base, step, tmp_base + size) run_threads_in_rounds(pre_threads) run_round(size) @@ -651,7 +657,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= def run_chunk(size, base): if size not in chunks: def swap_list(list_base): - for i in range(size / 2): + for i in range(size // 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) @@ -663,8 +669,8 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size / n_chunks / 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + lower_size = size // n_chunks // 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size @@ -692,7 +698,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= def outer(i): def inner(j): base = j - step = l / k + step = l // k if k == 2: tmp_addr = regint.load_mem(tmp_i) load_and_store(base, tmp_addr) @@ -704,19 +710,19 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= load_and_store(m, tmp_addr) store_in_mem(tmp_addr + 1, tmp_i) range_loop(inner2, base + step, base + (k - 1) * step, step) - range_loop(inner, a_base + i * l, a_base + i * l + l / k) + range_loop(inner, a_base + i * l, a_base + i * l + l // k) instructions.program.curr_tape.merge_opens = False to_tmp = True store_in_mem(tmp_base, tmp_i) - range_loop(outer, n / l) + range_loop(outer, n // l) if k == 2: run_round(n) else: - run_round(n / k * (k - 2)) + run_round(n // k * (k - 2)) instructions.program.curr_tape.merge_opens = False to_tmp = False store_in_mem(tmp_base, tmp_i) - range_loop(outer, n / l) + range_loop(outer, n // l) if isinstance(a, list): instructions.program.restart_main_thread() @@ -734,15 +740,15 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32): k = 1 while k < l: k *= 2 - n_outer = len(a) / l - n_inner = l / k - n_innermost = 1 if k == 2 else k / 2 - 1 - @for_range_parallel(n_parallel / n_innermost / n_inner, n_outer) + n_outer = len(a) // l + n_inner = l // k + n_innermost = 1 if k == 2 else k // 2 - 1 + @for_range_parallel(n_parallel // n_innermost // n_inner, n_outer) def loop(i): - @for_range_parallel(n_parallel / n_innermost, n_inner) + @for_range_parallel(n_parallel // n_innermost, n_inner) def inner(j): base = i*l + j - step = l/k + step = l//k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: @@ -805,7 +811,7 @@ def range_loop(loop_body, start, stop=None, step=None): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) / step) * x[0] + lambda x: ((stop - start) // step) * x[0] def for_range(start, stop=None, step=None): """ Execute loop bodies consecutively """ @@ -840,7 +846,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], my_n_parallel = n_parallel if isinstance(n_parallel, int): if isinstance(n_loops, int): - loop_rounds = n_loops / n_parallel \ + loop_rounds = n_loops // n_parallel \ if n_parallel < n_loops else 0 else: loop_rounds = n_loops / n_parallel @@ -884,7 +890,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], regint.push(k) return i + k my_n_parallel = n_opt_loops - loop_rounds = n_loops / my_n_parallel + loop_rounds = n_loops // my_n_parallel blocks = get_tape().basicblocks n_to_merge = 5 if loop_rounds == 1 and parent_block is blocks[-n_to_merge]: @@ -966,7 +972,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ indices = [] for n in reversed(split): indices.insert(0, i % n) - i /= n + i //= n return loop_body(*indices) return new_body new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req) @@ -979,7 +985,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ else: return dec def decorator(loop_body): - thread_rounds = n_loops / n_threads + thread_rounds = n_loops // n_threads remainder = n_loops % n_threads for t in thread_mem_req: if t != regint: @@ -1233,10 +1239,25 @@ def stop_timer(timer_id=0): stop(timer_id) get_tape().start_new_basicblock(name='post-stop-timer') +def get_number_of_players(): + res = regint() + nplayers(res) + return res + +def get_threshold(): + res = regint() + threshold(res) + return res + +def get_player_id(): + res = localint() + playerid(res._v) + return res + # Fixed point ops from math import ceil, log -from floatingpoint import PreOR, TruncPr, two_power, shift_two +from .floatingpoint import PreOR, TruncPr, two_power, shift_two def approximate_reciprocal(divisor, k, f, theta): """ @@ -1369,7 +1390,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): # no probabilistic truncation in binary circuits nearest = True res_f = f - f = max((k - nearest) / 2 + 1, f) + f = max((k - nearest) // 2 + 1, f) assert 2 * f > k - nearest theta = int(ceil(log(k/3.5) / log(2))) alpha = b.get_type(2 * k).two_power(2*f) @@ -1387,7 +1408,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): x = x.round(2*k, 2*f, kappa, nearest, signed=True) y = y.extend(2 * k) * (alpha + x).extend(2 * k) - y = y.round(k + 2 * f, 3 * f - res_f, kappa, nearest, signed=True) + y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False): """ diff --git a/Compiler/ml.py b/Compiler/ml.py index 4c8e59b4..9d579055 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -3,6 +3,7 @@ import mpc_math, math from Compiler.types import * from Compiler.types import _unreduced_squant from Compiler.library import * +from functools import reduce def log_e(x): return mpc_math.log_fx(x, math.e) @@ -129,7 +130,7 @@ class DenseBase(Layer): tmp[j][k] = sfix.unreduced_dot_product(a, b) if self.d_in * self.d_out < 100000: - print 'reduce at once' + print('reduce at once') @multithread(self.n_threads, self.d_in * self.d_out) def _(base, size): self.nabla_W.assign_vector( @@ -386,7 +387,7 @@ class QuantConvBase(QuantBase): s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) self.weights.input_from(player, budget=100000) self.bias.input_from(player) - print 'WARNING: assuming that bias quantization parameters are correct' + print('WARNING: assuming that bias quantization parameters are correct') self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params) @@ -404,7 +405,7 @@ class QuantConvBase(QuantBase): start_timer(2) n_outputs = reduce(operator.mul, self.output_shape) if n_outputs % self.n_threads == 0: - n_per_thread = n_outputs / self.n_threads + n_per_thread = n_outputs // self.n_threads @for_range_opt_multithread(self.n_threads, self.n_threads) def _(i): res = _unreduced_squant( @@ -556,7 +557,7 @@ class QuantAveragePool2d(QuantBase): self.filter_size = filter_size def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -567,7 +568,7 @@ class QuantAveragePool2d(QuantBase): _, output_h, output_w, n_channels_out = self.output_shape n = input_h * input_w - print 'divisor: ', n + print('divisor: ', n) assert output_h == output_w == 1 assert n_channels_in == n_channels_out @@ -599,7 +600,7 @@ class QuantAveragePool2d(QuantBase): acc += self.X[0][in_y][in_x][c].v #fc += 1 logn = int(math.log(n, 2)) - acc = (acc + n / 2) + acc = (acc + n // 2) if 2 ** logn == n: acc = acc.round(self.output_squant.params.k + logn, logn, nearest=True) @@ -614,7 +615,7 @@ class QuantReshape(QuantBase): super(QuantReshape, self).__init__(input_shape, output_shape) def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') _ = self.new_squant() for s in self.input_squant, _, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -628,7 +629,7 @@ class QuantReshape(QuantBase): class QuantSoftmax(QuantBase): def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -666,14 +667,14 @@ class Optimizer: N = self.layers[0].N assert self.layers[-1].N == N assert N % 2 == 0 - n = N / 2 + n = N // 2 @for_range(n) def _(i): self.layers[-1].Y[i] = 0 self.layers[-1].Y[i + n] = 1 n_per_epoch = int(math.ceil(1. * max(len(X) for X in self.X_by_label) / n)) - print '%d runs per epoch' % n_per_epoch + print('%d runs per epoch' % n_per_epoch) indices_by_label = [] for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) @@ -794,8 +795,8 @@ class SGD(Optimizer): x = x.reveal() print_ln_if((x > 1000) + (x < -1000), name + ': %s %s %s %s', - *[y.v.reveal() for y in old, red_old, \ - new, diff]) + *[y.v.reveal() for y in (old, red_old, \ + new, diff)]) if self.debug: d = delta_theta.get_vector().reveal() a = cfix.Array(len(d.v)) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 23926f24..b2cc4803 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -474,7 +474,7 @@ def norm_simplified_SQ(b, k): m_odd = m_odd + z[i] # construct w, - k_over_2 = k / 2 + 1 + k_over_2 = k // 2 + 1 w_array = [0] * (k_over_2) w_array[0] = z[0] for i in range(1, k_over_2): @@ -510,7 +510,7 @@ def sqrt_simplified_fx(x): m_odd = (1 - 2 * m_odd) + m_odd w = (w * 2 - w) * (1-m_odd) + w # map number to use sfix format and instantiate the number - w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2)) + w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2)) # obtains correct 2 ** (m/2) w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w # produce x/ 2^(m/2) diff --git a/Compiler/oram.py b/Compiler/oram.py index 12786e4f..93e34c84 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -4,6 +4,7 @@ import collections import itertools import operator import sys +from functools import reduce from Compiler.types import * from Compiler.types import _secret @@ -95,7 +96,7 @@ class gf2nBlock(Block): prod_bits = [start * bit for bit in value_bits] anti_bits = [v - p for v,p in zip(value_bits,prod_bits)] self.lower = sum(bit << i for i,bit in enumerate(prod_bits[:length])) - self.bits = map(operator.add, anti_bits[:length], prod_bits[length:]) + \ + self.bits = list(map(operator.add, anti_bits[:length], prod_bits[length:])) + \ anti_bits[length:] self.adjust = if_else(start, 1 << length, cgf2n(1)) elif entries_per_block < 4: @@ -105,7 +106,7 @@ class gf2nBlock(Block): choice_bits = demux(start_bits) inv_bits = [1 - bit for bit in floatingpoint.PreOR(choice_bits, None)] mask_bits = sum(([x] * length for x in inv_bits), []) - lower_bits = map(operator.mul, value_bits, mask_bits) + lower_bits = list(map(operator.mul, value_bits, mask_bits)) self.lower = sum(bit << i for i,bit in enumerate(lower_bits)) self.bits = [sum(map(operator.mul, choice_bits, value_bits[i::length])) \ for i in range(length)] @@ -124,7 +125,7 @@ class gf2nBlock(Block): pre_bits = floatingpoint.PreOpL(lambda x,y,z=None: x + y, bits) inv_bits = [1 - bit for bit in pre_bits] mask_bits = sum(([x] * length for x in inv_bits), []) - lower_bits = map(operator.mul, value_bits, mask_bits) + lower_bits = list(map(operator.mul, value_bits, mask_bits)) masked = self.value - sum(bit << i for i,bit in enumerate(lower_bits)) self.lower = sum(bit << i for i,bit in enumerate(lower_bits)) self.bits = (masked / adjust).bit_decompose(used_bits) @@ -177,12 +178,12 @@ def demux_list(x): return [1] elif n == 1: return [1 - x[0], x[0]] - a = demux_list(x[:n/2]) - b = demux_list(x[n/2:]) + a = demux_list(x[:n//2]) + b = demux_list(x[n//2:]) n_a = len(a) a *= len(b) b = reduce(operator.add, ([i] * n_a for i in b)) - res = map(operator.mul, a, b) + res = list(map(operator.mul, a, b)) return res def demux_array(x, res=None): @@ -193,12 +194,12 @@ def demux_array(x, res=None): res[0] = 1 - x[0] res[1] = x[0] else: - a = Array(2**(n/2), type(x[0])) - a.assign(demux(x[:n/2])) - b = Array(2**(n-n/2), type(x[0])) - b.assign(demux(x[n/2:])) + a = Array(2**(n//2), type(x[0])) + a.assign(demux(x[:n//2])) + b = Array(2**(n-n//2), type(x[0])) + b.assign(demux(x[n//2:])) @for_range_multithread(get_n_threads(len(res)), \ - max(1, n_parallel / len(b)), len(a)) + max(1, n_parallel // len(b)), len(a)) def f(i): @for_range_parallel(n_parallel, len(b)) def f(j): @@ -234,7 +235,7 @@ class Value(object): return Value(other * self.value, other * self.empty) __rmul__ = __mul__ def equal(self, other, length=None): - if isinstance(other, (int, long)) and isinstance(self.value, (int, long)): + if isinstance(other, int) and isinstance(self.value, int): return (1 - self.empty) * (other == self.value) return (1 - self.empty) * self.value.equal(other, length) def reveal(self): @@ -252,9 +253,9 @@ class Value(object): try: value = self.empty while True: - if value in (1, 1L): + if value == 1: return '<>' - if value in (0, 0L): + if value == 0: return '<%s>' % str(self.value) value = value.value except: @@ -297,8 +298,8 @@ class Entry(object): self.created_non_empty = False if x is None: v = iter(v) - self.is_empty = v.next() - self.v = v.next() + self.is_empty = next(v) + self.v = next(v) self.x = ValueTuple(v) else: if empty is None: @@ -332,7 +333,7 @@ class Entry(object): try: return Entry(i + j for i,j in zip(self, other)) except: - print self, other + print(self, other) raise def __sub__(self, other): return Entry(i - j for i,j in zip(self, other)) @@ -342,7 +343,7 @@ class Entry(object): try: return Entry(other * i for i in self) except: - print self, other + print(self, other) raise __rmul__ = __mul__ def reveal(self): @@ -372,8 +373,8 @@ class RefRAM(object): for t,array in zip(self.entry_type,oram.ram.l)] self.index = index def init_mem(self, empty_entry): - print 'init ram' - for a,value in zip(self.l, empty_entry.defaults.values()): + print('init ram') + for a,value in zip(self.l, list(empty_entry.defaults.values())): # don't use threads if n_threads explicitly set to 1 a.assign_all(value, n_threads != 1, conv=False) def get_empty_bits(self): @@ -392,14 +393,14 @@ class RefRAM(object): return [Value(self.l[2+index][i], self.l[0][i]) for i in range(self.size)] def __getitem__(self, index): if print_access: - print 'get', id(self), index + print('get', id(self), index) return Entry(a[index] for a in self.l) def __setitem__(self, index, value): if print_access: - print 'set', id(self), index + print('set', id(self), index) if not isinstance(value, Entry): raise Exception('entries only please: %s' % str(value)) - for i,(a,v) in enumerate(zip(self.l, value.values())): + for i,(a,v) in enumerate(zip(self.l, list(value.values()))): a[index] = v def __len__(self): return self.size @@ -524,7 +525,7 @@ class RefTrivialORAM(EndRecursiveEviction): self.value_type, self.entry_size = oram.internal_entry_size() self.size = oram.bucket_size def init_mem(self): - print 'init trivial oram' + print('init trivial oram') self.ram.init_mem(self.empty_entry(apply_type=False)) def search(self, read_index): if use_binary_search and self.value_type == sgf2n: @@ -554,7 +555,7 @@ class RefTrivialORAM(EndRecursiveEviction): self.last_index = read_index found, empty = self.search(read_index) entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, found, entries) + prod_entries = list(map(operator.mul, found, entries)) read_value = sum((entry.x.skip(skip) for entry in prod_entries), \ empty * empty_entry.x.skip(skip)) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): @@ -566,7 +567,7 @@ class RefTrivialORAM(EndRecursiveEviction): def read_and_remove_by_public(self, index): empty_entry = self.empty_entry(False) entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, index, entries) + prod_entries = list(map(operator.mul, index, entries)) read_entry = reduce(operator.add, prod_entries) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): self.ram[i] = entry - prod_entry + index[i] * empty_entry @@ -574,7 +575,7 @@ class RefTrivialORAM(EndRecursiveEviction): @method_block def _read(self, index): found, empty = self.search(index) - read_value = sum(map(operator.mul, found, self.ram.get_values()), \ + read_value = sum(list(map(operator.mul, found, self.ram.get_values())), \ empty * self.empty_entry(False).x) return read_value, empty @method_block @@ -583,8 +584,8 @@ class RefTrivialORAM(EndRecursiveEviction): found, not_found = self.search(index) add_here = self.find_first_empty() entries = [entry for entry in self.ram] - prod_values = map(operator.mul, found, \ - (entry.x for entry in entries)) + prod_values = list(map(operator.mul, found, \ + (entry.x for entry in entries))) read_value = sum(prod_values, not_found * empty_entry.x) new_value = ValueTuple(new_value) \ if isinstance(new_value, (tuple, list)) \ @@ -699,15 +700,15 @@ class RefTrivialORAM(EndRecursiveEviction): for k in range(2**(j)): t = k + 2**(j) - 1 if k % 2 == 0: - M += bit_prods[(t-1)/2] * mult_tree[t] + M += bit_prods[(t-1)//2] * mult_tree[t] b = 1 - M.equal(0, 40, expand) for k in range(2**j): t = k + 2**j - 1 if k % 2 == 0: - v = bit_prods[(t-1)/2] * b - bit_prods[t] = bit_prods[(t-1)/2] - v + v = bit_prods[(t-1)//2] * b + bit_prods[t] = bit_prods[(t-1)//2] - v else: bit_prods[t] = v return bit_prods[n-1:n-1+self.size], 1 - bit_prods[0] @@ -734,7 +735,7 @@ class RefTrivialORAM(EndRecursiveEviction): print_ln('Bucket overflow') crash() if debug and not sum(add_here) and not new_entry.empty(): - print self.empty_entry() + print(self.empty_entry()) raise Exception('no space for %s in %s' % (str(new_entry), str(self))) self.check(new_entry=new_entry, op='add') def pop(self): @@ -746,7 +747,7 @@ class RefTrivialORAM(EndRecursiveEviction): pop_here = [prefix_empty[i+1] - prefix_empty[i] \ for i in range(len(self.ram))] entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, pop_here, self.ram) + prod_entries = list(map(operator.mul, pop_here, self.ram)) result = (1 - sum(pop_here)) * empty_entry result = sum(prod_entries, result) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): @@ -980,7 +981,7 @@ class LocalIndexStructure(List): @for_range(init_rounds if init_rounds > 0 else size) def f(i): self.l[0][i] = random_block(entry_size, value_type) - print 'index size:', size + print('index size:', size) def update(self, index, value, evict=None): read_value = self[index] #print 'read', index, read_value @@ -1005,7 +1006,7 @@ class TreeORAM(AbstractORAM): """ Tree ORAM. """ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): - print 'create oram of size', size + print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size delta = 3 @@ -1013,9 +1014,9 @@ class TreeORAM(AbstractORAM): # size + 1 for bucket overflow check self.bucket_size = min(int(math.ceil((1 + delta) * k)), size + 1) self.D = log2(max(size / k, 2)) - print 'bucket size:', self.bucket_size - print 'depth:', self.D - print 'complexity:', self.bucket_size * (self.D + 1) + print('bucket size:', self.bucket_size) + print('depth:', self.D) + print('complexity:', self.bucket_size * (self.D + 1)) self.value_type = value_type if entry_size is not None: self.value_length = len(tuplify(entry_size)) @@ -1279,8 +1280,8 @@ class TreeORAM(AbstractORAM): # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: parity = (empty_positions[i]+1) % 2 - half = (empty_positions[i]+1 - parity) / 2 - half_max = self.bucket_size / 2 + half = (empty_positions[i]+1 - parity) // 2 + half_max = self.bucket_size // 2 bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0] bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0] @@ -1384,11 +1385,11 @@ def get_parallel(index_size, value_type, value_length): value_size = get_value_size(value_type) if value_type == sint: value_size *= 2 - res = max(1, min(50 * 32 / (value_length * value_size), \ - 800 * 32 / (value_length * index_size))) + res = max(1, min(50 * 32 // (value_length * value_size), \ + 800 * 32 // (value_length * index_size))) if comparison.const_rounds: - res = max(1, res / 2) - print 'Reading %d buckets in parallel' % res + res = max(1, res // 2) + print('Reading %d buckets in parallel' % res) return res class PackedIndexStructure(object): @@ -1403,7 +1404,7 @@ class PackedIndexStructure(object): self.value_type = value_type for demux_bits in range(max_demux_bits + 1): self.log_entries_per_element = min(log2(size), \ - int(math.floor(math.log(float(get_value_size(value_type)) / \ + int(math.floor(math.log(float(get_value_size(value_type)) // \ sum(self.entry_size), 2)))) self.log_elements_per_block = \ max(0, min(demux_bits, log2(size) - \ @@ -1423,24 +1424,24 @@ class PackedIndexStructure(object): self.elements_per_entry = len(self.split_sizes) self.log_elements_per_block = log2(self.elements_per_entry) self.log_entries_per_element = -self.log_elements_per_block - print 'split sizes:', self.split_sizes + print('split sizes:', self.split_sizes) self.log_entries_per_block = \ self.log_elements_per_block + self.log_entries_per_element self.elements_per_block = 2**self.log_elements_per_block self.entries_per_element = 2**self.log_entries_per_element self.entries_per_block = 2**self.log_entries_per_block self.used_bits = self.entries_per_element * sum(self.entry_size) - real_size = -(-size / self.entries_per_block) - print 'packed size:', real_size - print 'index size:', size - print 'entry size:', self.entry_size - print 'log(entries per element):', self.log_entries_per_element - print 'entries per element:', self.entries_per_element - print 'log(entries per block):', self.log_entries_per_block - print 'entries per block:', self.entries_per_block - print 'log(elements per block):', self.log_elements_per_block - print 'elements per block:', self.elements_per_block - print 'used bits:', self.used_bits + real_size = -(-size // self.entries_per_block) + print('packed size:', real_size) + print('index size:', size) + print('entry size:', self.entry_size) + print('log(entries per element):', self.log_entries_per_element) + print('entries per element:', self.entries_per_element) + print('log(entries per block):', self.log_entries_per_block) + print('entries per block:', self.entries_per_block) + print('log(elements per block):', self.log_elements_per_block) + print('elements per block:', self.elements_per_block) + print('used bits:', self.used_bits) entry_size = [self.used_bits] * self.elements_per_block if real_size > 1: # no need to init underlying ORAM, will be initialized implicitely @@ -1454,10 +1455,10 @@ class PackedIndexStructure(object): self.index_type = self.l.index_type if init_rounds: if init_rounds > 0: - real_init_rounds = init_rounds * real_size / size + real_init_rounds = init_rounds * real_size // size else: real_init_rounds = real_size - print 'packed init rounds:', real_init_rounds + print('packed init rounds:', real_init_rounds) @for_range(real_init_rounds) def f(i): if random_init: @@ -1467,7 +1468,7 @@ class PackedIndexStructure(object): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) - print 'index initialized, size', size + print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple (storage address, index with storage cell, index within @@ -1501,16 +1502,16 @@ class PackedIndexStructure(object): self.block = block self.index_vector = \ demux(bit_decompose(self.b, self.pack.log_elements_per_block)) - self.vector = map(operator.mul, self.index_vector, block) + self.vector = list(map(operator.mul, self.index_vector, block)) self.element = get_block(sum(self.vector), self.c, \ self.pack.entry_size, \ self.pack.entries_per_element) return tuple(self.element.get_slice()) def write(self, value): self.element.set_slice(value) - anti_vector = map(operator.sub, self.block, self.vector) + anti_vector = list(map(operator.sub, self.block, self.vector)) updated_vector = [self.element.value * i for i in self.index_vector] - updated_block = map(operator.add, anti_vector, updated_vector) + updated_block = list(map(operator.add, anti_vector, updated_vector)) return updated_block class MultiSlicer(object): def __init__(self, pack, index): @@ -1685,7 +1686,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100): value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() - print 'initialized' + print('initialized') print_ln('initialized') stop_timer() # synchronize @@ -1718,7 +1719,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100): def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100): oram = oram_type(N, value_type=value_type, entry_size=32, \ init_rounds=0) - print 'initialized' + print('initialized') print_reg(cint(0), 'init') stop_timer() # synchronize @@ -1731,11 +1732,11 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations= def f(i): oram.access(value_type(i % N), value_type(0), value_type(True)) oram.access(value_type(i % N), value_type(i % N), value_type(True)) - print 'first write' + print('first write') time() x = oram.access(value_type(i % N), value_type(0), value_type(False)) x[0][0].reveal().print_reg('writ') - print 'first read' + print('first read') # @for_range(iterations) # def f(i): # x = oram.access(value_type(i % N), value_type(0), value_type(False), \ @@ -1747,7 +1748,7 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations= def test_batch_init(oram_type, N): value_type = sint oram = oram_type(N, value_type) - print 'initialized' + print('initialized') print_reg(cint(0), 'init') oram.batch_init([value_type(i) for i in range(N)]) print_reg(cint(0), 'done') diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index e1265a0c..fb1601c3 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -1,9 +1,10 @@ if '_Array' not in dir(): - from oram import * - import permutation + from Compiler.oram import * + from Compiler import permutation _Array = Array -import oram +from Compiler import oram +from functools import reduce #import pdb @@ -140,7 +141,7 @@ class PathORAM(TreeORAM): bucket_size=2, init_rounds=-1): #if size <= k: # raise CompilerError('ORAM size too small') - print 'create oram of size', size + print('create oram of size', size) self.bucket_oram = bucket_oram self.bucket_size = bucket_size self.D = log2(size) @@ -240,7 +241,7 @@ class PathORAM(TreeORAM): self.state.write(self.value_type(leaf)) - print 'eviction leaf =', leaf + print('eviction leaf =', leaf) # load the path for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)): @@ -325,7 +326,7 @@ class PathORAM(TreeORAM): # at most one 1 in found empty = 1 - sum(found) - prod_entries = map(operator.mul, found, entries) + prod_entries = list(map(operator.mul, found, entries)) read_value = sum((entry.x.skip(skip) for entry in prod_entries), \ empty * empty_entry.x.skip(skip)) for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries, prod_entries)): @@ -528,7 +529,7 @@ class PathORAM(TreeORAM): values = (ValueTuple(x) for x in zip(*self.read_value)) not_empty = [1 - x for x in self.read_empty] read_empty = 1 - sum(not_empty) - read_value = sum(map(operator.mul, not_empty, values), \ + read_value = sum(list(map(operator.mul, not_empty, values)), \ ValueTuple(0 for i in range(self.value_length))) self.check(u) Program.prog.curr_tape.\ @@ -545,7 +546,7 @@ class PathORAM(TreeORAM): yield bucket def bucket_indices_on_path_to(self, leaf): leaf = regint(leaf) - yield range(self.bucket_size) + yield list(range(self.bucket_size)) index = 0 for i in range(self.D): index = 2*index + 1 + regint(cint(leaf) & 1) @@ -742,7 +743,7 @@ class PathORAM(TreeORAM): try: self.stash.add(e) except Exception: - print self + print(self) raise if evict: self.evict() diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 28896533..79c32e27 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -69,7 +69,7 @@ def odd_even_merge(a, comp): odd_even_merge(even, comp) odd_even_merge(odd, comp) a[0] = even[0] - for i in range(1, len(a) / 2): + for i in range(1, len(a) // 2): a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp) a[-1] = odd[-1] @@ -77,8 +77,8 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): if len(a) == 1: return elif len(a) % 2 == 0: - lower = a[:len(a)/2] - upper = a[len(a)/2:] + lower = a[:len(a)//2] + upper = a[len(a)//2:] odd_even_merge_sort(lower, comp) odd_even_merge_sort(upper, comp) a[:] = lower + upper @@ -137,7 +137,7 @@ def random_perm(n): if not Program.prog.options.insecure: raise CompilerError('no secure implementation of Waksman permution, ' 'use --insecure to activate') - a = range(n) + a = list(range(n)) for i in range(n-1, 0, -1): j = randint(0, i) t = a[i] @@ -155,10 +155,10 @@ def configure_waksman(perm): n = len(perm) if n == 2: return [(perm[0], perm[0])] - I = [None] * (n/2) - O = [None] * (n/2) - p0 = [None] * (n/2) - p1 = [None] * (n/2) + I = [None] * (n//2) + O = [None] * (n//2) + p0 = [None] * (n//2) + p1 = [None] * (n//2) inv_perm = [0] * n for i, p in enumerate(perm): @@ -170,7 +170,7 @@ def configure_waksman(perm): except ValueError: break #print 'j =', j - O[j/2] = 0 + O[j//2] = 0 via = 0 j0 = j while True: @@ -178,10 +178,10 @@ def configure_waksman(perm): i = inv_perm[j] #print ' p0[%d] = %d' % (inv_perm[j]/2, j/2) - p0[i/2] = j/2 + p0[i//2] = j//2 - I[i/2] = i % 2 - O[j/2] = j % 2 + I[i//2] = i % 2 + O[j//2] = j % 2 #print ' O[%d] = %d' % (j/2, j % 2) if i % 2 == 1: i -= 1 @@ -198,7 +198,7 @@ def configure_waksman(perm): j += 1 #j, via = set_swapper(O, i, via, perm) #print ' p1[%d] = %d' % (i/2, perm[i]/2) - p1[i/2] = perm[i]/2 + p1[i//2] = perm[i]//2 #print ' i = %d, j = %d' %(i,j) if j == j0: @@ -206,8 +206,8 @@ def configure_waksman(perm): if None not in p0 and None not in p1: break - assert sorted(p0) == range(n/2) - assert sorted(p1) == range(n/2) + assert sorted(p0) == list(range(n//2)) + assert sorted(p1) == list(range(n//2)) p0_config = configure_waksman(p0) p1_config = configure_waksman(p1) return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] @@ -219,22 +219,22 @@ def waksman(a, config, depth=0, start=0, reverse=False): a[0], a[1] = cond_swap_bit(a[0], a[1], config[depth][start]) return - a0 = [0] * (n/2) - a1 = [0] * (n/2) - for i in range(n/2): + a0 = [0] * (n//2) + a1 = [0] * (n//2) + for i in range(n//2): if reverse: - a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n/2 + start]) + a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n//2 + start]) else: a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + start]) waksman(a0, config, depth+1, start, reverse) - waksman(a1, config, depth+1, start + n/2, reverse) + waksman(a1, config, depth+1, start + n//2, reverse) - for i in range(n/2): + for i in range(n//2): if reverse: a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + start]) else: - a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n/2 + start]) + a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n//2 + start]) WAKSMAN_FUNCTIONS = {} @@ -263,11 +263,11 @@ def iter_waksman(a, config, reverse=False): outwards = 1 - inwards sizeval = size - #for k in range(n/2): - @for_range_parallel(200, n/2) + #for k in range(n//2): + @for_range_parallel(200, n//2) def f(k): j = cint(k) % sizeval - i = (cint(k) - j)/sizeval + i = (cint(k) - j)//sizeval base = 2*i*sizeval in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards) @@ -297,7 +297,7 @@ def iter_waksman(a, config, reverse=False): # going into middle of network @for_range(logn) def f(i): - size.write(n/(2*nblocks)) + size.write(n//(2*nblocks)) conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 1) @@ -307,20 +307,20 @@ def iter_waksman(a, config, reverse=False): nblocks.write(nblocks*2) depth.write(depth+1) - nblocks.write(nblocks/4) + nblocks.write(nblocks//4) depth.write(depth-2) # and back out @for_range(logn-1) def f(i): - size.write(n/(2*nblocks)) + size.write(n//(2*nblocks)) conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 0) for i in range(n): a[i] = a2[i] - nblocks.write(nblocks/2) + nblocks.write(nblocks//2) depth.write(depth-1) ## going into middle of network @@ -375,7 +375,7 @@ def config_shuffle(n, value_type): if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) - perm += range(n, m) + perm += list(range(n, m)) config_bits = configure_waksman(perm) # 2-D array config = Array(len(config_bits) * len(perm), value_type.reg_type) diff --git a/Compiler/program.py b/Compiler/program.py index 76c3b00b..eebc411a 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -3,15 +3,17 @@ from Compiler.exceptions import * from Compiler.instructions_base import RegType import Compiler.instructions import Compiler.instructions_base -import compilerLib -import allocator as al +from . import compilerLib +from . import allocator as al +from . import util import random import time import sys, os, errno import inspect -from collections import defaultdict +from collections import defaultdict, deque import itertools import math +from functools import reduce data_types = dict( @@ -50,11 +52,11 @@ class Program(object): self.bit_length = int(options.binary) or int(options.field) if not self.bit_length: self.bit_length = BIT_LENGTHS[param] - print 'Default bit length:', self.bit_length + print('Default bit length:', self.bit_length) self.security = 40 - print 'Default security parameter:', self.security + print('Default security parameter:', self.security) self.galois_length = int(options.galois) - print 'Galois length:', self.galois_length + print('Galois length:', self.galois_length) self.schedule = [('start', [])] self.tape_counter = 0 self.tapes = [] @@ -118,7 +120,7 @@ class Program(object): running[tape] -= 1 else: raise CompilerError('Invalid schedule action') - res = max(res, sum(running.itervalues())) + res = max(res, sum(running.values())) return res def init_names(self, args, assemblymode): @@ -129,7 +131,7 @@ class Program(object): else: # assume source is in main SPDZ directory self.programs_dir = sys.path[0] + '/Programs' - print 'Compiling program in', self.programs_dir + print('Compiling program in', self.programs_dir) # create extra directories if needed for dirname in ['Public-Input', 'Bytecode', 'Schedules']: @@ -225,7 +227,7 @@ class Program(object): def read_memory(self, filename): """ Read the clear and shared memory from a file """ f = open(filename) - n = int(f.next()) + n = int(next(f)) self.mem_c = [0]*n self.mem_s = [0]*n mem = self.mem_c @@ -253,8 +255,8 @@ class Program(object): """ Reset register and memory values. """ for tape in self.tapes: tape.reset_registers() - self.mem_c = range(USER_MEM + TMP_MEM) - self.mem_s = range(USER_MEM + TMP_MEM) + self.mem_c = list(range(USER_MEM + TMP_MEM)) + self.mem_s = list(range(USER_MEM + TMP_MEM)) def write_bytes(self, outfile=None): """ Write all non-empty threads and schedule to files. """ @@ -265,7 +267,7 @@ class Program(object): sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name sch_file = open(sch_filename, 'w') - print 'Writing to', sch_filename + print('Writing to', sch_filename) sch_file.write(str(self.max_par_tapes()) + '\n') sch_file.write(str(len(nonempty_tapes)) + '\n') sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') @@ -276,7 +278,7 @@ class Program(object): for sch in self.schedule: # schedule may still contain empty tapes: ignore these - tapes = filter(lambda x: not x[0].is_empty(), sch[1]) + tapes = [x for x in sch[1] if not x[0].is_empty()] # no empty line if not tapes: continue @@ -358,7 +360,7 @@ class Program(object): def malloc(self, size, mem_type, reg_type=None): """ Allocate memory from the top """ - if not isinstance(size, (int, long)): + if not isinstance(size, int): raise CompilerError('size must be known at compile time') if size == 0: return @@ -374,7 +376,7 @@ class Program(object): addr = self.allocated_mem[mem_type] self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)): - print "Memory of type '%s' now of size %d" % (mem_type, addr + size) + print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) self.allocated_mem_blocks[addr,mem_type] = size return addr @@ -387,11 +389,11 @@ class Program(object): self.free_mem_blocks[size,mem_type].add(addr) def finalize_memory(self): - import library + from . import library self.curr_tape.start_new_basicblock(None, 'memory-usage') # reset register counter to 0 self.curr_tape.init_registers() - for mem_type,size in self.allocated_mem.items(): + for mem_type,size in list(self.allocated_mem.items()): if size: #print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: @@ -404,11 +406,11 @@ class Program(object): def set_bit_length(self, bit_length): self.bit_length = bit_length - print 'Changed bit length for comparisons etc. to', bit_length + print('Changed bit length for comparisons etc. to', bit_length) def set_security(self, security): self.security = security - print 'Changed statistical security for comparison etc. to', security + print('Changed statistical security for comparison etc. to', security) def optimize_for_gc(self): pass @@ -500,9 +502,9 @@ class Tape: #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self): - relevant = lambda inst: inst.add_usage.__func__ is not \ - Compiler.instructions_base.Instruction.add_usage.__func__ - self.usage_instructions = filter(relevant, self.instructions) + relevant = lambda inst: inst.add_usage is not \ + Compiler.instructions_base.Instruction.add_usage + self.usage_instructions = list(filter(relevant, self.instructions)) del self.instructions del self.defined_registers self.purged = True @@ -568,8 +570,8 @@ class Tape: def unpurged(function): def wrapper(self, *args, **kwargs): if self.purged: - print '%s called on purged block %s, ignoring' % \ - (function.__name__, self.name) + print('%s called on purged block %s, ignoring' % \ + (function.__name__, self.name)) return return function(self, *args, **kwargs) return wrapper @@ -577,13 +579,13 @@ class Tape: @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print 'Tape %s is empty' % self.name + print('Tape %s is empty' % self.name) return if self.if_states: raise CompilerError('Unclosed if/else blocks') - print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks) + print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) for block in self.basicblocks: al.determine_scope(block, options) @@ -593,38 +595,38 @@ class Tape: if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: for i,block in enumerate(self.basicblocks): if len(block.instructions) > 0: - print 'Processing basic block %s, %d/%d, %d instructions' % \ + print('Processing basic block %s, %d/%d, %d instructions' % \ (block.name, i, len(self.basicblocks), \ - len(block.instructions)) + len(block.instructions))) # the next call is necessary for allocation later even without merging merger = al.Merger(block, options, \ tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 10000: - print 'Eliminate dead code...' + print('Eliminate dead code...') merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: - block.used_from_scope = set() - block.defined_registers = set() + block.used_from_scope = util.set_by_id() + block.defined_registers = util.set_by_id() continue if len(block.instructions) > 10000: - print 'Merging instructions...' + print('Merging instructions...') numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) if numrounds > 0: - print 'Program requires %d rounds of communication' % numrounds + print('Program requires %d rounds of communication' % numrounds) if merger.counter: - print 'Block requires', \ + print('Block requires', \ ', '.join('%d %s' % (y, x.__name__) \ - for x, y in merger.counter.items()) + for x, y in list(merger.counter.items()))) # free memory merger = None if options.dead_code_elimination: - block.instructions = filter(lambda x: x is not None, block.instructions) + block.instructions = [x for x in block.instructions if x is not None] if not (options.merge_opens and self.merge_opens): - print 'Not merging instructions in tape %s' % self.name + print('Not merging instructions in tape %s' % self.name) # add jumps offset = 0 @@ -640,39 +642,44 @@ class Tape: block.adjust_return() # now remove any empty blocks (must be done after setting jumps) - self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks) + self.basicblocks = [x for x in self.basicblocks if len(x.instructions) != 0] # allocate registers reg_counts = self.count_regs() if not options.noreallocate: if self.program.verbose: - print 'Tape register usage:', dict(reg_counts) - print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) - print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) - print 'Re-allocating...' + print('Tape register usage:', dict(reg_counts)) + print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) + print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) + print('Re-allocating...') allocator = al.StraightlineAllocator(REG_MAX) - def alloc_loop(block): + def alloc(block): for reg in sorted(block.used_from_scope, key=lambda x: (x.reg_type, x.i)): allocator.alloc_reg(reg, block.alloc_pool) - for child in block.children: - if child.instructions: - alloc_loop(child) + def alloc_loop(block): + left = deque([block]) + while left: + block = left.popleft() + alloc(block) + for child in block.children: + if child.instructions: + left.append(child) for i,block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 10000: - print 'Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks)) + print('Allocating %s, %d/%d' % \ + (block.name, i, len(self.basicblocks))) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, (int,long)) and jump < 0 and \ + if isinstance(jump, int) and jump < 0 and \ block.exit_block.scope is not None: alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) # offline data requirements - print 'Compile offline data requirements...' + print('Compile offline data requirements...') self.req_num = self.req_tree.aggregate() - print 'Tape requires', self.req_num + print('Tape requires', self.req_num) for req,num in sorted(self.req_num.items()): if num == float('inf') or num >= 2 ** 32: num = -1 @@ -706,8 +713,8 @@ class Tape: Compiler.instructions.reqbl(bl, add_to_prog=False)) if self.program.verbose: - print 'Tape requires prime bit length', self.req_bit_length['p'] - print 'Tape requires galois bit length', self.req_bit_length['2'] + print('Tape requires prime bit length', self.req_bit_length['p']) + print('Tape requires galois bit length', self.req_bit_length['2']) @unpurged def _get_instructions(self): @@ -722,12 +729,12 @@ class Tape: @unpurged def get_bytes(self): """ Get the byte encoding of the program as an actual string of bytes. """ - return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None) + return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) @unpurged def write_encoding(self, filename): """ Write the readable encoding to a file. """ - print 'Writing to', filename + print('Writing to', filename) f = open(filename, 'w') for line in self.get_encoding(): f.write(str(line) + '\n') @@ -736,7 +743,7 @@ class Tape: @unpurged def write_str(self, filename): """ Write the sequence of instructions to a file. """ - print 'Writing to', filename + print('Writing to', filename) f = open(filename, 'w') n = 0 for block in self.basicblocks: @@ -756,8 +763,8 @@ class Tape: filename += '.bc' if not 'Bytecode' in filename: filename = self.program.programs_dir + '/Bytecode/' + filename - print 'Writing to', filename - f = open(filename, 'w') + print('Writing to', filename) + f = open(filename, 'wb') f.write(self.get_bytes()) f.close() @@ -785,9 +792,9 @@ class Tape: super(Tape.ReqNum, self).__init__(lambda: 0, init) def __add__(self, other): res = Tape.ReqNum() - for i,count in self.items(): + for i,count in list(self.items()): res[i] += count - for i,count in other.items(): + for i,count in list(other.items()): res[i] += count return res def __mul__(self, other): @@ -798,7 +805,7 @@ class Tape: __rmul__ = __mul__ def set_all(self, value): if value == float('inf') and self['all', 'inv'] > 0: - print 'Going to unknown from %s' % self + print('Going to unknown from %s' % self) res = Tape.ReqNum() for i in self: res[i] = value @@ -811,14 +818,14 @@ class Tape: res[i] = max(self[i], other[i]) return res def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in self.items() \ + return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ if req[1] != 'input') def __str__(self): return ", ".join('%s inputs in %s from player %d' \ % (num, req[0], req[2]) \ if req[1] == 'input' \ else '%s %ss in %s' % (num, req[1], req[0]) \ - for req,num in self.items()) + for req,num in list(self.items())) def __repr__(self): return repr(dict(self)) @@ -853,8 +860,8 @@ class Tape: n_rounds = res['all', 'round'] n_invs = res['all', 'inv'] if (n_invs / n_rounds) * 1000 < n_reps: - print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs) + print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ + '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) except: pass return res @@ -892,7 +899,7 @@ class Tape: The 'value' property is for emulation. """ - __slots__ = ["reg_type", "program", "i", "value", "_is_active", \ + __slots__ = ["reg_type", "program", "i", "_is_active", \ "size", "vector", "vectorbase", "caller", \ "can_eliminate"] @@ -925,7 +932,7 @@ class Tape: else: self.caller = None if self.i % 1000000 == 0 and self.i > 0: - print "Initialized %d registers at" % self.i, time.asctime() + print("Initialized %d registers at" % self.i, time.asctime()) def set_size(self, size): if self.size == size: diff --git a/Compiler/types.py b/Compiler/types.py index 3fa5992a..f72690a7 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2,11 +2,12 @@ from Compiler.program import Tape from Compiler.exceptions import * from Compiler.instructions import * from Compiler.instructions_base import * -from floatingpoint import two_power -import comparison, floatingpoint +from .floatingpoint import two_power +from . import comparison, floatingpoint import math -import util +from . import util import operator +from functools import reduce class ClientMessageType: @@ -127,15 +128,15 @@ class _number(object): return self * self def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self else: return self.add(other) def __mul__(self, other): - if other is 0 or other is 0L: + if other is 0: return 0 - elif other is 1 or other is 1L: + elif other is 1: return self else: return self.mul(other) @@ -301,7 +302,7 @@ class _register(Tape.Register, _number, _structure): if isinstance(val, (tuple, list)): size = len(val) super(_register, self).__init__(reg_type, program.curr_tape, size=size) - if isinstance(val, (int, long)): + if isinstance(val, int): self.load_int(val) elif isinstance(val, (tuple, list)): for i, x in enumerate(val): @@ -374,7 +375,7 @@ class _clear(_register): res = self.prep_res(other) if isinstance(other, cls): c_inst(res, self, other) - elif isinstance(other, (int, long)): + elif isinstance(other, int): if self.in_immediate_range(other): ci_inst(res, self, other) else: @@ -392,7 +393,7 @@ class _clear(_register): def coerce_op(self, other, inst, reverse=False): cls = self.__class__ res = cls() - if isinstance(other, (int, long)): + if isinstance(other, int): other = cls(other) elif not isinstance(other, cls): return NotImplemented @@ -414,14 +415,14 @@ class _clear(_register): def __rsub__(self, other): return self.clear_op(other, subc, subcfi, True) - def __div__(self, other): + def __truediv__(self, other): return self.clear_op(other, divc, divci) - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.coerce_op(other, divc, True) def __eq__(self, other): - if isinstance(other, (_clear,int,long)): + if isinstance(other, (_clear,int)): return regint(self) == other else: return NotImplemented @@ -493,12 +494,12 @@ class cint(_clear, _int): ldi(self, val) else: max = 2**31 - 1 - sign = abs(val) / val + sign = abs(val) // val val = abs(val) chunks = [] while val: mod = val % max - val = (val - mod) / max + val = (val - mod) // max chunks.append(mod) sum = cint(sign * chunks.pop()) for i,chunk in enumerate(reversed(chunks)): @@ -520,13 +521,13 @@ class cint(_clear, _int): return self.coerce_op(other, modc, True) def __lt__(self, other): - if isinstance(other, (type(self),int,long)): + if isinstance(other, (type(self),int)): return regint(self) < other else: return NotImplemented def __gt__(self, other): - if isinstance(other, (type(self),int,long)): + if isinstance(other, (type(self),int)): return regint(self) > other else: return NotImplemented @@ -537,6 +538,23 @@ class cint(_clear, _int): def __ge__(self, other): return 1 - (self < other) + @vectorize + def __eq__(self, other): + if not isinstance(other, (_clear, int)): + return NotImplemented + res = 1 + remaining = program.bit_length + while remaining > 0: + if isinstance(other, cint): + o = other.to_regint(min(remaining, 64)) + else: + o = other % 2 ** 64 + res *= (self.to_regint(min(remaining, 64)) == o) + self >>= 64 + other >>= 64 + remaining -= 64 + return res + def __lshift__(self, other): return self.clear_op(other, shlc, shlci) @@ -683,7 +701,7 @@ class cgf2n(_clear, _gf2n): def bit_decompose(self, bit_length=None, step=None): bit_length = bit_length or program.galois_length step = step or 1 - res = [type(self)() for _ in range(bit_length / step)] + res = [type(self)() for _ in range(bit_length // step)] gbitdec(self, step, *res) return res @@ -817,12 +835,15 @@ class regint(_register, _int): def __neg__(self): return 0 - self - def __div__(self, other): + def __floordiv__(self, other): return self.int_op(other, divint) - def __rdiv__(self, other): + def __rfloordiv__(self, other): return self.int_op(other, divint, True) + __truediv__ = __floordiv__ + __rtruediv__ = __rfloordiv__ + def __mod__(self, other): return self - (self / other) * other @@ -851,13 +872,13 @@ class regint(_register, _int): return 1 - (self < other) def __lshift__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): return self * 2**other else: return regint(cint(self) << other) def __rshift__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): return self / 2**other else: return regint(cint(self) >> other) @@ -911,6 +932,24 @@ class regint(_register, _int): def print_if(self, string): cint(self).print_if(string) +class localint(object): + """ Local integer that must prevented from leaking into the secure + computation. Uses regint internally. """ + + def __init__(self, value=None): + self._v = regint(value) + self.size = 1 + + def output(self): + self._v.print_reg_plain() + + __lt__ = lambda self, other: localint(self._v < other) + __le__ = lambda self, other: localint(self._v <= other) + __gt__ = lambda self, other: localint(self._v > other) + __ge__ = lambda self, other: localint(self._v >= other) + __eq__ = lambda self, other: localint(self._v == other) + __ne__ = lambda self, other: localint(self._v != other) + class _secret(_register): __slots__ = [] @@ -996,11 +1035,11 @@ class _secret(_register): def matrix_mul(cls, A, B, n, res_params=None): assert len(A) % n == 0 assert len(B) % n == 0 - size = len(A) * len(B) / n**2 + size = len(A) * len(B) // n**2 res = cls(size=size) - n_rows = len(A) / n - n_cols = len(B) / n - dotprods(*sum(([res[j], [A[j / n_cols * n + k] for k in range(n)], + n_rows = len(A) // n + n_cols = len(B) // n + dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)], [B[k * n_cols + j % n_cols] for k in range(n)]] for j in range(size)), [])) return res @@ -1054,7 +1093,7 @@ class _secret(_register): m_inst(res, other, self) else: m_inst(res, self, other) - elif isinstance(other, (int, long)): + elif isinstance(other, int): if self.clear_type.in_immediate_range(other): si_inst(res, self, other) else: @@ -1086,11 +1125,11 @@ class _secret(_register): return self.secret_op(other, subs, submr, subsfi, True) @vectorize - def __div__(self, other): + def __truediv__(self, other): return self * (self.clear_type(1) / other) @vectorize - def __rdiv__(self, other): + def __rtruediv__(self, other): a,b = self.get_random_inverse() return other * a / (a * self).reveal() @@ -1253,7 +1292,7 @@ class sint(_secret, _int): @vectorize def __mod__(self, modulus): - if isinstance(modulus, (int, long)): + if isinstance(modulus, int): l = math.log(modulus, 2) if 2**int(round(l)) == modulus: return self.mod2m(int(l)) @@ -1405,7 +1444,7 @@ class sgf2n(_secret, _gf2n): return self ^ cgf2n(2**program.galois_length - 1) def __xor__(self, other): - if other is 0 or other is 0L: + if other is 0: return self else: return super(sgf2n, self).add(other) @@ -1414,7 +1453,7 @@ class sgf2n(_secret, _gf2n): @vectorize def __and__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): other_bits = [(other >> i) & 1 \ for i in range(program.galois_length)] else: @@ -1515,7 +1554,7 @@ class _bitint(object): else: pre_op = floatingpoint.PreOpL if d: - carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1] + carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1] else: carries = [] res = lower + cls.sum_from_carries(a, b, carries) @@ -1539,7 +1578,7 @@ class _bitint(object): for k in range(m, -1, -1): if sum(range(m, k - 1, -1)) + 1 >= n: break - blocks = range(m, k, -1) + blocks = list(range(m, k, -1)) blocks.append(n - sum(blocks)) blocks.reverse() #print 'blocks:', blocks @@ -1597,9 +1636,9 @@ class _bitint(object): @staticmethod def get_highest_different_bits(a, b, index): - diff = [ai + bi for (ai,bi) in reversed(zip(a,b))] + diff = [ai + bi for (ai,bi) in reversed(list(zip(a,b)))] preor = floatingpoint.PreOR(diff, raw=True) - highest_diff = [x - y for (x,y) in reversed(zip(preor, [0] + preor))] + highest_diff = [x - y for (x,y) in reversed(list(zip(preor, [0] + preor)))] raw = sum(map(operator.mul, highest_diff, (a,b)[index])) return raw.bit_decompose()[0] @@ -1622,7 +1661,7 @@ class _bitint(object): if type(other) == self.bin_type: raise CompilerError('Unclear multiplication') self_bits = self.bit_decompose() - if isinstance(other, (int, long)): + if isinstance(other, int): other_bits = util.bit_decompose(other, self.n_bits) bit_matrix = [[x * y for y in self_bits] for x in other_bits] else: @@ -1644,8 +1683,8 @@ class _bitint(object): @classmethod def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True): - columns = [filter(None, (bit_matrix[j][i-j] \ - for j in range(min(len(bit_matrix), i + 1)))) \ + columns = [[_f for _f in (bit_matrix[j][i-j] \ + for j in range(min(len(bit_matrix), i + 1))) if _f] \ for i in range(len(bit_matrix[0]))] return cls.wallace_tree_from_columns(columns, get_carry) @@ -1671,7 +1710,7 @@ class _bitint(object): columns = new_columns[:-1] for col in columns: col.extend([0] * (2 - len(col))) - return self.bit_adder(*zip(*columns)) + return self.bit_adder(*list(zip(*columns))) @classmethod def wallace_tree(cls, rows): @@ -1685,17 +1724,17 @@ class _bitint(object): d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)] borrow = lambda y,x,*args: \ (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) - borrows = (0,) + zip(*floatingpoint.PreOpL(borrow, d))[1] + borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1] return self.compose(ai + bi + borrow \ for (ai,bi,borrow) in zip(a,b,borrows)) def __rsub__(self, other): raise NotImplementedError() - def __div__(self, other): + def __truediv__(self, other): raise NotImplementedError() - def __rdiv__(self, other): + def __truerdiv__(self, other): raise NotImplementedError() def __lshift__(self, other): @@ -1953,7 +1992,7 @@ class cfix(_number, _structure): """ Clear fixed point type. """ __slots__ = ['value', 'f', 'k', 'size'] reg_type = 'c' - scalars = (int, long, float, regint) + scalars = (int, float, regint) @classmethod def set_precision(cls, f, k = None): # k is the whole bitlength of fixed point @@ -1978,7 +2017,7 @@ class cfix(_number, _structure): if n == 1: return cfix(cint_inputs) else: - return map(cfix, cint_inputs) + return list(map(cfix, cint_inputs)) @vectorize def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): @@ -1990,7 +2029,7 @@ class cfix(_number, _structure): """ Send a list of cfix values to socket. Values are sent as bit shifted cints. """ def cfix_to_cint(fix_val): return cint(fix_val.v) - cint_values = map(cfix_to_cint, values) + cint_values = list(map(cfix_to_cint, values)) writesocketc(client_id, message_type, *cint_values) @staticmethod @@ -2152,7 +2191,7 @@ class cfix(_number, _structure): raise NotImplementedError @vectorize - def __div__(self, other): + def __truediv__(self, other): other = parse_type(other) if isinstance(other, cfix): return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f)) @@ -2190,7 +2229,7 @@ class _single(_number, _structure): """ Securely obtain shares of n values input by a client. Assumes client has already run bit shift to convert fixed point to integer.""" sint_inputs = cls.int_type.receive_from_client(n, client_id, ClientMessageType.TripleShares) - return map(cls, sint_inputs) + return list(map(cls, sint_inputs)) @vectorized_classmethod def load_mem(cls, address, mem_type=None): @@ -2333,7 +2372,7 @@ class _fix(_single): @classmethod def coerce(cls, other): - if isinstance(other, (_fix, cfix)): + if isinstance(other, (_fix, cls.clear_type)): return other else: return cls.conv(other) @@ -2402,7 +2441,7 @@ class _fix(_single): @vectorize def mul(self, other): - if isinstance(other, (sint, cint, regint, int, long)): + if isinstance(other, (sint, cint, regint, int)): return self._new(self.v * other, k=self.k, f=self.f) elif isinstance(other, float): if int(other) == other: @@ -2413,13 +2452,11 @@ class _fix(_single): f = self.f while v % 2 == 0: f -= 1 - v /= 2 + v //= 2 k = len(bin(abs(v))) - 1 - other = cfix(cint(v)) - other.f = f - other.k = k + other = self.multipliable(v, k, f) other = self.coerce(other) - if isinstance(other, (_fix, cfix)): + if isinstance(other, (_fix, self.clear_type)): val = self.v.TruncMul(other.v, self.k + other.k, other.f, self.kappa, self.round_nearest) @@ -2438,7 +2475,7 @@ class _fix(_single): return type(self)(-self.v) @vectorize - def __div__(self, other): + def __truediv__(self, other): other = self.coerce(other) if isinstance(other, _fix): return type(self)(library.FPDiv(self.v, other.v, self.k, self.f, @@ -2450,7 +2487,7 @@ class _fix(_single): raise TypeError('Incompatible fixed point types in division') @vectorize - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.coerce(other) / self @vectorize @@ -2497,6 +2534,10 @@ class sfix(_fix): def unreduced(self, v, other=None, res_params=None, n_summands=1): return unreduced_sfix(v, self.k * 2, self.f, self.kappa) + @staticmethod + def multipliable(v, k, f): + return cfix(cint.conv(v), k, f) + class unreduced_sfix(_single): int_type = sint @@ -2511,7 +2552,7 @@ class unreduced_sfix(_single): self.kappa = kappa def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self assert self.k == other.k assert self.m == other.m @@ -2643,7 +2684,7 @@ class _unreduced_squant(object): self.res_params = res_params or params[0] def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self assert self.params == other.params assert self.res_params == other.res_params @@ -2807,10 +2848,10 @@ class sfloat(_number, _structure): v = int(round(abs(v) * 2 ** (-p))) if v == 2 ** vlen: p += 1 - v /= 2 + v //= 2 z = 0 if p < -2 ** (plen - 1): - print 'Warning: %e truncated to zero' % vv + print('Warning: %e truncated to zero' % vv) v, p, z = 0, 0, 1 if p >= 2 ** (plen - 1): raise CompilerError('Cannot convert %s to float ' \ @@ -2950,8 +2991,8 @@ class sfloat(_number, _structure): v = t u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest, self.vlen + 2 + sfloat.round_nearest, self.kappa, - range(1 + sfloat.round_nearest, - self.vlen + 2 + sfloat.round_nearest)) + list(range(1 + sfloat.round_nearest, + self.vlen + 2 + sfloat.round_nearest))) # using u[0] doesn't seem necessary h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa) p0 = self.vlen + 1 - sum(h) @@ -3013,7 +3054,7 @@ class sfloat(_number, _structure): def __rsub__(self, other): return -self + other - def __div__(self, other): + def __truediv__(self, other): other = self.conv(other) v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1), self.vlen, self.kappa, self.round_nearest) @@ -3029,21 +3070,16 @@ class sfloat(_number, _structure): sfloat.set_error(other.z) return sfloat(v, p, z, s) - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.conv(other) / self @vectorize def __neg__(self): return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z)) - def __abs__(self): - if self.s: - return -self - else: - return self - @vectorize def __lt__(self, other): + other = self.conv(other) if isinstance(other, sfloat): z1 = self.z z2 = other.z @@ -3066,8 +3102,15 @@ class sfloat(_number, _structure): def __ge__(self, other): return 1 - (self < other) + def __gt__(self, other): + return self.conv(other) < self + + def __le__(self, other): + return self.conv(other) >= self + @vectorize def __eq__(self, other): + other = self.conv(other) # the sign can be both ways for zeroes both_zero = self.z * other.z return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \ @@ -3151,24 +3194,25 @@ class Array(object): program.free(self.address, self.value_type.reg_type) def get_address(self, index): + key = str(index) if isinstance(index, int) and self.length is not None: index += self.length * (index < 0) if index >= self.length or index < 0: raise IndexError('index %s, length %s' % \ (str(index), str(self.length))) - if (program.curr_block, index) not in self.address_cache: + if (program.curr_block, key) not in self.address_cache: n = self.value_type.n_elements() length = self.length if n == 1: # length can be None for single-element arrays length = 0 - self.address_cache[program.curr_block, index] = \ + self.address_cache[program.curr_block, key] = \ util.untuplify([self.address + index + i * length \ for i in range(n)]) if self.debug: library.print_ln_if(index >= self.length, 'OF:' + self.debug) - library.print_ln_if(self.address_cache[program.curr_block, index] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug) - return self.address_cache[program.curr_block, index] + library.print_ln_if(self.address_cache[program.curr_block, key] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug) + return self.address_cache[program.curr_block, key] def get_slice(self, index): if index.stop is None and self.length is None: @@ -3178,7 +3222,7 @@ class Array(object): def __getitem__(self, index): if isinstance(index, slice): start, stop, step = self.get_slice(index) - res_length = (stop - start - 1) / step + 1 + res_length = (stop - start - 1) // step + 1 res = Array(res_length, self.value_type) @library.for_range(res_length) def f(i): @@ -3303,7 +3347,7 @@ class SubMultiArray(object): def __getitem__(self, index): if util.is_constant(index) and index >= self.sizes[0]: raise StopIteration - key = program.curr_block, index + key = program.curr_block, str(index) if key not in self.sub_cache: if self.debug: library.print_ln_if(index >= self.sizes[0], \ @@ -3531,7 +3575,7 @@ class _mem(_number): __add__ = lambda self,other: self.read() + other __sub__ = lambda self,other: self.read() - other __mul__ = lambda self,other: self.read() * other - __div__ = lambda self,other: self.read() / other + __truediv__ = lambda self,other: self.read() / other __mod__ = lambda self,other: self.read() % other __pow__ = lambda self,other: self.read() ** other __neg__ = lambda self,other: -self.read() @@ -3550,7 +3594,7 @@ class _mem(_number): __radd__ = lambda self,other: other + self.read() __rsub__ = lambda self,other: other - self.read() __rmul__ = lambda self,other: other * self.read() - __rdiv__ = lambda self,other: other / self.read() + __rtruediv__ = lambda self,other: other / self.read() __rmod__ = lambda self,other: other % self.read() __rand__ = lambda self,other: other & self.read() __rxor__ = lambda self,other: other ^ self.read() @@ -3627,7 +3671,7 @@ class MemValue(_mem): self.check() if isinstance(value, MemValue): self.register = value.read() - elif isinstance(value, (int,long)): + elif isinstance(value, int): self.register = self.value_type(value) else: self.register = value @@ -3717,7 +3761,7 @@ def getNamedTupleType(*names): class NamedTuple(object): class NamedTupleArray(object): def __init__(self, size, t): - import types + from . import types self.arrays = [types.Array(size, t) for i in range(len(names))] def __getitem__(self, index): return NamedTuple(array[index] for array in self.arrays) @@ -3749,4 +3793,4 @@ def getNamedTupleType(*names): return self.__type__(x.reveal() for x in self) return NamedTuple -import library +from . import library diff --git a/Compiler/util.py b/Compiler/util.py index 403a81fa..8b7ea214 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -1,5 +1,6 @@ import math import operator +from functools import reduce def format_trace(trace, prefix=' '): if trace is None: @@ -46,7 +47,7 @@ def right_shift(a, b, bits): return a.right_shift(b, bits) def bit_decompose(a, bits): - if isinstance(a, (int,long)): + if isinstance(a, int): return [int((a >> i) & 1) for i in range(bits)] else: return a.bit_decompose(bits) @@ -82,7 +83,7 @@ def if_else(cond, a, b): else: return cond.if_else(a, b) except: - print cond, a, b + print(cond, a, b) raise def cond_swap(cond, a, b): @@ -112,8 +113,8 @@ def tree_reduce(function, sequence): if n == 1: return sequence[0] else: - reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)] - return tree_reduce(function, reduced + sequence[n/2*2:]) + reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n//2)] + return tree_reduce(function, reduced + sequence[n//2*2:]) def or_op(a, b): return a + b - a * b @@ -144,7 +145,7 @@ def reveal(x): return x def is_constant(x): - return isinstance(x, (int, long, bool)) + return isinstance(x, (int, bool)) def is_constant_float(x): return isinstance(x, float) or is_constant(x) @@ -180,3 +181,44 @@ def expand(x, size): return x.expand_to_vector(size) except AttributeError: return x + +class set_by_id(object): + def __init__(self, init=[]): + self.content = {} + for x in init: + self.add(x) + + def __contains__(self, value): + return id(value) in self.content + + def __iter__(self): + return iter(self.content.values()) + + def add(self, value): + self.content[id(value)] = value + +class dict_by_id(object): + def __init__(self): + self.content = {} + + def __contains__(self, key): + return id(key) in self.content + + def __getitem__(self, key): + return self.content[id(key)][1] + + def __setitem__(self, key, value): + self.content[id(key)] = (key, value) + + def keys(self): + return (x[0] for x in self.content.values()) + +class defaultdict_by_id(dict_by_id): + def __init__(self, default): + dict_by_id.__init__(self) + self.default = default + + def __getitem__(self, key): + if key not in self: + self[key] = self.default() + return dict_by_id.__getitem__(self, key) diff --git a/ECDSA/EcdsaOptions.h b/ECDSA/EcdsaOptions.h new file mode 100644 index 00000000..619aaa03 --- /dev/null +++ b/ECDSA/EcdsaOptions.h @@ -0,0 +1,66 @@ +/* + * EcdsaOptions.h + * + */ + +#ifndef ECDSA_ECDSAOPTIONS_H_ +#define ECDSA_ECDSAOPTIONS_H_ + +#include "Tools/ezOptionParser.h" + +class EcdsaOptions +{ +public: + bool prep_mul; + bool fewer_rounds; + bool check_open; + bool check_beaver_open; + + EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv) + { + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Delay multiplication until signing", // Help description. + "-D", // Flag token. + "--delay-multiplication" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Fewer rounds, more EC", // Help description. + "-P", // Flag token. + "--parallel-open" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip checking final openings (but not necessarily openings for Beaver; only relevant with active protocols)", // Help description. + "-C", // Flag token. + "--no-open-check" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip checking Beaver openings (only relevant with active protocols)", // Help description. + "-B", // Flag token. + "--no-beaver-open-check" // Flag token. + ); + opt.parse(argc, argv); + prep_mul = not opt.isSet("-D"); + fewer_rounds = opt.isSet("-P"); + check_open = not opt.isSet("-C"); + check_beaver_open = not opt.isSet("-B"); + opt.resetArgs(); + } +}; + +#endif /* ECDSA_ECDSAOPTIONS_H_ */ diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 6ddc519c..bf2c544f 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -18,7 +18,7 @@ int main() string prefix = PREP_DIR "ECDSA/"; mkdir_p(prefix.c_str()); ofstream outf; - write_online_setup(outf, prefix, P256Element::Scalar::pr(), 0, false); + write_online_setup_without_init(outf, prefix, P256Element::Scalar::pr(), 0); generate_mac_keys>(key, key2, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 0d58db43..d0828ec0 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -77,6 +77,14 @@ P256Element& P256Element::operator +=(const P256Element& other) return *this; } +P256Element& P256Element::operator /=(const Scalar& other) +{ + auto tmp = other; + tmp.invert(); + *this = *this * tmp; + return *this; +} + bool P256Element::operator ==(const P256Element& other) const { return point == other.point; diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 603a7e0d..adb38a9d 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -10,14 +10,10 @@ #include "Math/gfp.h" -#if GFP_MOD_SZ != 4 -#error GFP_MOD_SZ must be 4 -#endif - class P256Element : public ValueInterface { public: - typedef gfp Scalar; + typedef gfp_<2, 4> Scalar; private: static CryptoPP::DL_GroupParameters_EC params; @@ -52,6 +48,7 @@ public: P256Element operator*(const Scalar& other) const; P256Element& operator+=(const P256Element& other); + P256Element& operator/=(const Scalar& other); bool operator==(const P256Element& other) const; bool operator!=(const P256Element& other) const; diff --git a/ECDSA/README.md b/ECDSA/README.md index d7db6191..5d1db349 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -5,8 +5,7 @@ in `preprocessing.hpp` and `sign.hpp`, respectively. #### Compilation -- Add `MOD = -DGFP_MOD_SZ=4` to `CONFIG.mine` -- Also consider adding either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x` +- Add either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x` - For older hardware, also add `ARCH = -march=native` - Install [Crypto++](https://www.cryptopp.com) (`libcrypto++-dev` on Ubuntu). We used version 5.6.4, which is the default on Ubuntu 18.04. - Compile the binaries: `make -j8 ecdsa` diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 9a703f37..bea4db5b 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -23,6 +23,7 @@ int main(int argc, const char** argv) { ez::ezOptionParser opt; + EcdsaOptions opts(opt, argc, argv); Names N(opt, argc, argv, 2); int n_tuples = 1000; if (not opt.lastArgs.empty()) @@ -38,7 +39,7 @@ int main(int argc, const char** argv) typedef Share pShare; DataPositions usage; Sub_Data_Files prep(N, prefix, usage); - typename pShare::MAC_Check MCp(keyp); + typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); SubProcessor proc(_, MCp, prep, P); @@ -46,7 +47,7 @@ int main(int argc, const char** argv) proc.DataF.get_two(DATA_INVERSE, sk, __); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc); + preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); - sign_benchmark(tuples, sk, MCp, P); + sign_benchmark(tuples, sk, MCp, P, opts); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 174ec672..3ab383a8 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -29,15 +29,7 @@ void run(int argc, const char** argv) { bigint::init_thread(); ez::ezOptionParser opt; - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Delay multiplication until signing", // Help description. - "-D", // Flag token. - "--delay-multiplication" // Flag token. - ); + EcdsaOptions opts(opt, argc, argv); Names N(opt, argc, argv, 3); int n_tuples = 1000; if (not opt.lastArgs.empty()) @@ -51,10 +43,12 @@ void run(int argc, const char** argv) P.Broadcast_Receive(bundle, false); Timer timer; timer.start(); + auto stats = P.comm_stats; pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; + (P.comm_stats - stats).print(true); - OnlineOptions::singleton.batch_size = n_tuples; + OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; auto& prep = *Preprocessing::get_live_prep(0, usage); typename pShare::MAC_Check MCp; @@ -63,9 +57,9 @@ void run(int argc, const char** argv) bool prep_mul = not opt.isSet("-D"); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc, prep_mul); + preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); - sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc); + sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); delete &prep; } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 998657db..f9cc9c9f 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -25,27 +25,72 @@ template class T> void run(int argc, const char** argv) { ez::ezOptionParser opt; + EcdsaOptions opts(opt, argc, argv); opt.add( "", // Default. 0, // Required? 0, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Delay multiplication until signing", // Help description. - "-D", // Flag token. - "--delay-multiplication" // Flag token. + "Use SimpleOT instead of OT extension", // Help description. + "-S", // Flag token. + "--simple-ot" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Don't check correlation in OT extension (only relevant with MASCOT)", // Help description. + "-U", // Flag token. + "--unchecked-correlation" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Fewer rounds for authentication (only relevant with MASCOT)", // Help description. + "-A", // Flag token. + "--auth-fewer-rounds" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use Fiat-Shamir for amplification (only relevant with MASCOT)", // Help description. + "-H", // Flag token. + "--fiat-shamir" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip sacrifice (only relevant with MASCOT)", // Help description. + "-E", // Flag token. + "--embrace-life" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "No MACs (only relevant with MASCOT; implies skipping MAC checks)", // Help description. + "-M", // Flag token. + "--no-macs" // Flag token. + ); + Names N(opt, argc, argv, 2); int n_tuples = 1000; if (not opt.lastArgs.empty()) n_tuples = atoi(opt.lastArgs[0]->c_str()); PlainPlayer P(N); P256Element::init(); - gfp1::init_field(P256Element::Scalar::pr(), false); + P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); BaseMachine machine; - machine.ot_setups.resize(1); - for (int i = 0; i < 2; i++) - machine.ot_setups[0].push_back({P, true}); + machine.ot_setups.push_back({P, true}); P256Element::Scalar keyp; SeededPRNG G; @@ -65,16 +110,28 @@ void run(int argc, const char** argv) P.Broadcast_Receive(bundle, false); Timer timer; timer.start(); + auto stats = P.comm_stats; sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; + (P.comm_stats - stats).print(true); - OnlineOptions::singleton.batch_size = n_tuples; + OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::LivePrep prep(0, usage); + prep.params.correlation_check &= not opt.isSet("-U"); + prep.params.fewer_rounds = opt.isSet("-A"); + prep.params.fiat_shamir = opt.isSet("-H"); + prep.params.check = not opt.isSet("-E"); + prep.params.generateMACs = not opt.isSet("-M"); + opts.check_beaver_open &= prep.params.generateMACs; + opts.check_open &= prep.params.generateMACs; SubProcessor proc(_, MCp, prep, P); + typename pShare::prep_type::Direct_MC MCpp(keyp); + prep.triple_generator->MC = &MCpp; bool prep_mul = not opt.isSet("-D"); + prep.params.use_extension = not opt.isSet("-S"); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc, prep_mul); + preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); - sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc); + sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); } diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 10306daf..449f9331 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -7,6 +7,7 @@ #define ECDSA_PREPROCESSING_HPP_ #include "P256Element.h" +#include "EcdsaOptions.h" #include "Processor/Data_Files.h" #include "Protocols/ReplicatedPrep.h" #include "Protocols/MaliciousShamirShare.h" @@ -23,40 +24,66 @@ public: template class T> void preprocessing(vector>& tuples, int buffer_size, T& sk, - SubProcessor>& proc, bool prep_mul = true) + SubProcessor>& proc, + EcdsaOptions opts) { + bool prep_mul = opts.prep_mul; Timer timer; timer.start(); Player& P = proc.P; auto& prep = proc.DataF; size_t start = P.sent + prep.data_sent(); + auto stats = P.comm_stats + prep.comm_stats(); + auto& extra_player = P; + auto& protocol = proc.protocol; auto& MCp = proc.MC; typedef T pShare; typedef T cShare; vector inv_ks; vector secret_Rs; + prep.buffer_triples(); + vector bs, cs; for (int i = 0; i < buffer_size; i++) { - pShare a, a_inv; - prep.get_two(DATA_INVERSE, a, a_inv); - inv_ks.push_back(a_inv); - secret_Rs.push_back(a); + pShare a, b, c; + prep.get_three(DATA_TRIPLE, a, b, c); + inv_ks.push_back(a); + bs.push_back(b); + cs.push_back(c); } + vector cs_opened; + MCp.POpen_Begin(cs_opened, cs, extra_player); + if (opts.fewer_rounds) + secret_Rs.insert(secret_Rs.begin(), bs.begin(), bs.end()); + else + { + MCp.POpen_End(cs_opened, cs, extra_player); + for (int i = 0; i < buffer_size; i++) + secret_Rs.push_back(bs[i] / cs_opened[i]); + } + vector opened_Rs; + typename cShare::Direct_MC MCc(MCp.get_alphai()); + MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { protocol.init_mul(&proc); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); - protocol.exchange(); + protocol.start_exchange(); } - else - prep.buffer_triples(); - vector opened_Rs; - typename cShare::MAC_Check MCc(MCp.get_alphai()); - MCc.POpen(opened_Rs, secret_Rs, P); - MCc.Check(P); - MCp.Check(P); + if (opts.fewer_rounds) + MCp.POpen_End(cs_opened, cs, extra_player); + MCc.POpen_End(opened_Rs, secret_Rs, extra_player); + if (opts.fewer_rounds) + for (int i = 0; i < buffer_size; i++) + opened_Rs[i] /= cs_opened[i]; + if (prep_mul) + protocol.stop_exchange(); + if (opts.check_open) + MCc.Check(extra_player); + if (opts.check_open or opts.check_beaver_open) + MCp.Check(extra_player); for (int i = 0; i < buffer_size; i++) { tuples.push_back( @@ -68,6 +95,7 @@ void preprocessing(vector>& tuples, int buffer_size, << " seconds, throughput " << buffer_size / timer.elapsed() << ", " << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size << " kbytes per tuple" << endl; + (P.comm_stats + prep.comm_stats() - stats).print(true); } template class T> diff --git a/ECDSA/rep-ecdsa-party.cpp b/ECDSA/rep-ecdsa-party.cpp index 1678f1de..f1ccc572 100644 --- a/ECDSA/rep-ecdsa-party.cpp +++ b/ECDSA/rep-ecdsa-party.cpp @@ -8,10 +8,10 @@ #include "hm-ecdsa-party.hpp" template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) +Preprocessing>* Preprocessing>::get_live_prep( + SubProcessor>* proc, DataPositions& usage) { - return new ReplicatedPrep>(proc, usage); + return new ReplicatedPrep>(proc, usage); } int main(int argc, const char** argv) diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index c0e8c338..f6f4d663 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -71,12 +71,12 @@ EcSignature sign(const unsigned char* message, size_t length, prod = protocol.finalize_mul(); } auto rx = tuple.R.x(); - signature.s = MC.POpen( + signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " << (P.sent - start) << " bytes" << endl; auto diff = (P.comm_stats - stats); - diff.print(); + diff.print(true); return signature; } @@ -112,26 +112,30 @@ void check(EcSignature signature, const unsigned char* message, size_t length, template class T> void sign_benchmark(vector>& tuples, T sk, typename T::MAC_Check& MCp, Player& P, + EcdsaOptions& opts, SubProcessor>* proc = 0) { unsigned char message[1024]; GlobalPRNG(P).get_octets(message, 1024); - typename T::MAC_Check MCc(MCp.get_alphai()); + typename T::Direct_MC MCc(MCp.get_alphai()); // synchronize Bundle bundle(P); P.Broadcast_Receive(bundle, true); Timer timer; timer.start(); - P256Element pk = MCc.POpen(sk, P); + auto stats = P.comm_stats; + P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - P.comm_stats.print(); + (P.comm_stats - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message, 1 << i, pk); + if (not opts.check_open) + continue; Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); diff --git a/ExternalIO/bankers-bonus-commsec-client.cpp b/ExternalIO/bankers-bonus-commsec-client.cpp index 33ab007c..365b9b63 100644 --- a/ExternalIO/bankers-bonus-commsec-client.cpp +++ b/ExternalIO/bankers-bonus-commsec-client.cpp @@ -365,6 +365,7 @@ int main(int argc, char** argv) // init static gfp string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree()); initialise_fields(prep_data_prefix); + bigint::init_thread(); // Generate session keys to decrypt data sent from each spdz engine (party) vector session_keys(nparties); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index ee1870c7..84fe88cd 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -39,6 +39,8 @@ class Ciphertext // Rely on default copy assignment/constructor + word get_pk_id() { return pk_id; } + void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id) { cc0=a0; cc1=a1; this->pk_id = pk_id; } void set(const Rq_Element& a0, const Rq_Element& a1, const FHE_PK& pk); diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 04dd9105..460fee18 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -9,13 +9,20 @@ template Multiplier::Multiplier(int offset, PairwiseGenerator& generator) : - generator(generator), machine(generator.machine), - P(generator.P, offset), - num_players(generator.P.num_players()), - my_num(generator.P.my_num()), + Multiplier(offset, generator.machine, generator.P, generator.timers) +{ +} + +template +Multiplier::Multiplier(int offset, PairwiseMachine& machine, Player& P, + map& timers) : + machine(machine), + P(P, offset), + num_players(P.num_players()), + my_num(P.my_num()), other_pk(machine.other_pks[(my_num + num_players - offset) % num_players]), other_enc_alpha(machine.enc_alphas[(my_num + num_players - offset) % num_players]), - timers(generator.timers), + timers(timers), C(machine.pk), mask(machine.pk), product_share(machine.setup().FieldD), rc(machine.pk), volatile_capacity(0) diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index 4a9ba4a5..17159a7f 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -21,7 +21,6 @@ class PairwiseMachine; template class Multiplier { - PairwiseGenerator& generator; PairwiseMachine& machine; OffsetPlayer P; int num_players, my_num; @@ -39,6 +38,9 @@ class Multiplier public: Multiplier(int offset, PairwiseGenerator& generator); + Multiplier(int offset, PairwiseMachine& machine, Player& P, + map& timers); + void multiply_and_add(Plaintext_& res, const Ciphertext& C, const Plaintext_& b); void multiply_and_add(Plaintext_& res, const Ciphertext& C, diff --git a/GC/Processor.hpp b/GC/Processor.hpp index ae23693a..9ee5c2b9 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -13,6 +13,8 @@ using namespace std; #include "Access.h" #include "Processor/FixInput.h" +#include "Processor/ProcessorBase.hpp" + namespace GC { diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 8b4de98b..37c1e411 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -25,9 +25,9 @@ void SemiPrep::set_protocol(Beaver& protocol) (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, {}, thread.P); triple_generator->multi_threaded = false; } diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index 2f96a2e7..d8a04964 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -95,7 +95,7 @@ ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : else P = new PlainPlayer(this->N, 0xFFFF); for (int i = 0; i < this->machine.nthreads; i++) - this->machine.ot_setups.push_back({{{*P, true}}}); + this->machine.ot_setups.push_back({*P, true}); delete P; } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 7a9c0b38..1e18d7e1 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -152,8 +152,7 @@ void ReplicatedSecret::reveal(size_t n_bits, Clear& x) auto& share = *this; vector opened; auto& party = ShareThread::s(); - party.MC->POpen_Begin(opened, {share}, *party.P); - party.MC->POpen_End(opened, {share}, *party.P); + party.MC->POpen(opened, {share}, *party.P); x = IntBase(opened[0]); } diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 043d9355..0a533b2e 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -59,7 +59,7 @@ void ThreadMaster::run() if (T::needs_ot) for (int i = 0; i < machine.nthreads; i++) - machine.ot_setups.push_back({{*P, true}, {*P, true}}); + machine.ot_setups.push_back({*P, true}); for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 0ff4c2a0..e332be63 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -31,18 +31,17 @@ void TinyPrep::set_protocol(Beaver& protocol) params.generateMACs = true; params.amplify = false; params.check = false; - params.set_mac_key(thread.MC->get_alphai()); triple_generator = new typename T::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, thread.MC->get_alphai(), thread.P); triple_generator->multi_threaded = false; input_generator = new typename T::part_type::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(1), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, thread.MC->get_alphai(), thread.P); input_generator->multi_threaded = false; thread.MC->get_part_MC().set_prep(*this); } diff --git a/Machines/OTMachine.cpp b/Machines/OTMachine.cpp index 7d16cd0f..e589cfc8 100644 --- a/Machines/OTMachine.cpp +++ b/Machines/OTMachine.cpp @@ -264,11 +264,8 @@ OTMachine::OTMachine(int argc, const char** argv) // convert baseOT selection bits to BitVector // (not already BitVector due to legacy PVW code) + baseReceiverInput = bot.receiver_inputs; baseReceiverInput.resize(nbase); - for (int i = 0; i < nbase; i++) - { - baseReceiverInput.set_bit(i, bot.receiver_inputs[i]); - } } OTMachine::~OTMachine() diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 69a6d0b6..08fc64d0 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -107,7 +107,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr 1, // Number of args expected. 0, // Delimiter if expecting multiple args. "Maximum number of parties to send to at once", // Help description. - "-b", // Flag token. + "-B", // Flag token. "--max-broadcast" // Flag token. ); opt.add( diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index 79a23c14..811cb482 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -118,6 +118,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) : opt.get("-l")->getInt(nloops); generateBits = opt.get("-B")->isSet; check = opt.get("-c")->isSet || generateBits; + correlation_check = opt.get("-c")->isSet; generateMACs = opt.get("-m")->isSet || check; amplify = opt.get("-a")->isSet || generateMACs; primeField = opt.get("-P")->isSet; @@ -143,21 +144,22 @@ TripleMachine::TripleMachine(int argc, const char** argv) : // doesn't work with Montgomery multiplication gfp1::init_field(p, false); + gfp::init_field(p, true); gf2n_long::init_field(128); PRNG G; G.ReSeed(); - mac_key2l.randomize(G); - mac_key2s.randomize(G); + mac_key2.randomize(G); mac_keyp.randomize(G); mac_keyz.randomize(G); } template -GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i) +GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i, + typename T::mac_key_type mac_key) { return new typename T::TripleGenerator(setup, N[i % nConnections], i, - nTriplesPerThread, nloops, *this); + nTriplesPerThread, nloops, *this, mac_key); } void TripleMachine::run() @@ -180,24 +182,24 @@ void TripleMachine::run() for (int i = 0; i < nthreads; i++) { if (primeField) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyp); else if (z2k) { if (z2k == 32 and z2s == 32) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 64 and z2s == 64) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 64 and z2s == 48) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 66 and z2s == 64) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 66 and z2s == 48) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else throw runtime_error("not compiled for k=" + to_string(z2k) + " and s=" + to_string(z2s)); } else - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_key2); } ntriples = generators[0]->nTriples * nthreads; cout <<"Setup generators\n"; @@ -251,10 +253,8 @@ void TripleMachine::run() void TripleMachine::output_mac_keys() { if (z2k) { - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2l); + write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2); } - else if (gf2n::degree() > 64) - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2l); else - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s); + write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2); } diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp new file mode 100644 index 00000000..2ec4618b --- /dev/null +++ b/Machines/hemi-party.cpp @@ -0,0 +1,29 @@ +/* + * hemi-party.cpp + * + */ + +#include "Protocols/HemiShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "FHE/P2Data.h" +#include "Tools/ezOptionParser.h" + +#include "Player-Online.hpp" +#include "Protocols/HemiPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/SemiPrep.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/Beaver.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + spdz_main, HemiShare>(argc, argv, opt); +} diff --git a/Makefile b/Makefile index ed1bfe48..443280f1 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ DEPS := $(wildcard */*.d) all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x ifeq ($(USE_NTL),1) -all: overdrive she-offline cowgear-party.x +all: overdrive she-offline cowgear-party.x hemi-party.x endif -include $(DEPS) @@ -165,6 +165,7 @@ malicious-shamir-party.x: Machines/ShamirMachine.o spdz2k-party.x: $(OT) semi-party.x: $(OT) semi2k-party.x: $(OT) +hemi-party.x: $(FHEOFFLINE) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o mascot-party.x: Machines/SPDZ.o $(OT) Player-Online.x: Machines/SPDZ.o $(OT) diff --git a/Math/Setup.cpp b/Math/Setup.cpp index e103353d..ebe18dd0 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -110,6 +110,13 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i } void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont) +{ + write_online_setup_without_init(outf, dirname, p, lg2); + gfp::init_field(p, mont); + init_gf2n(lg2); +} + +void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2) { if (p == 0) throw runtime_error("prime cannot be 0"); @@ -132,9 +139,6 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2 // Fix as a negative lg2 is a ``signal'' to choose slightly weaker // LWE parameters outf << abs(lg2) << endl; - - gfp::init_field(p, mont); - init_gf2n(lg2); } void init_gf2n(int lg2) diff --git a/Math/Setup.h b/Math/Setup.h index d9a196df..b9574764 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -22,6 +22,7 @@ using namespace std; // Create setup file for gfp and gf2n void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2); void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont = true); +void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2); // Setup primes only // Chooses a p of at least lgp bits diff --git a/Math/Square.cpp b/Math/Square.cpp index b7f3ece8..83e1d97f 100644 --- a/Math/Square.cpp +++ b/Math/Square.cpp @@ -15,14 +15,14 @@ void Square::to(gf2n_short& result) result = sum; } -template <> -void Square::to(gfp1& result) +template +template +void Square::to(gfp_& result) { - const int L = gfp1::N_LIMBS; mp_limb_t product[2 * L], sum[2 * L], tmp[L][2 * L]; memset(tmp, 0, sizeof(tmp)); memset(sum, 0, sizeof(sum)); - for (int i = 0; i < gfp1::length(); i++) + for (int i = 0; i < gfp_::length(); i++) { memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i])); if (i % 64 == 0) @@ -32,10 +32,22 @@ void Square::to(gfp1& result) mpn_add_fixed_n<2 * L>(sum, product, sum); } mp_limb_t q[2 * L], ans[2 * L]; - mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L); + mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp_::get_ZpD().get_prA(), L); result.assign((void*) ans); } +template<> +void Square::to(gfp1& result) +{ + to<1, GFP_MOD_SZ>(result); +} + +template<> +void Square::to(gfp3& result) +{ + to<3, 4>(result); +} + template<> void Square::to(BitVec& result) { diff --git a/Math/Square.h b/Math/Square.h index b33d8134..28484dbd 100644 --- a/Math/Square.h +++ b/Math/Square.h @@ -31,6 +31,8 @@ public: void conditional_add(BitVector& conditions, Square& other, int offset); void to(U& result); + template + void to(gfp_& result); void pack(octetStream& os) const; void unpack(octetStream& os); diff --git a/Math/Z2k.h b/Math/Z2k.h index 05530d8d..2bd9fbed 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -283,7 +283,8 @@ Z2 Z2::operator>>(int i) const { Z2 res; int n_byte_shift = i / 8; - memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift); + if (N_BYTES - n_byte_shift > 0) + memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift); int n_inside_shift = i % 8; if (n_inside_shift > 0) mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 41e18f0b..2bc5f652 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -147,4 +147,17 @@ ostream& operator<<(ostream& o, const Z2& x) return o; } +template +istream& operator>>(istream& i, SignedZ2& x) +{ + auto& tmp = bigint::tmp; + i >> tmp; + if (tmp.numBits() > K + 1) + throw runtime_error( + tmp.get_str() + " out of range for signed " + to_string(K) + + "-bit numbers"); + x = tmp; + return i; +} + #endif diff --git a/Math/gfp.h b/Math/gfp.h index de79d04f..8360bee8 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -262,6 +262,8 @@ typedef gfp_<0, GFP_MOD_SZ> gfp; typedef gfp_<1, GFP_MOD_SZ> gfp1; // enough for Brain protocol with 64-bit computation and 40-bit security typedef gfp_<2, 4> gfp2; +// for OT-based ECDSA +typedef gfp_<3, 4> gfp3; void to_signed_bigint(bigint& ans,const gfp& x); diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index 0dbb9449..83aa4745 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -110,7 +110,7 @@ inline mp_limb_t mpn_add_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const char carry = 0; for (int i = 0; i < n; i++) #if defined(__clang__) -#if __clang_major__ < 8 || defined(__APPLE__) +#if __clang_major__ < 8 || (defined(__APPLE__) && __clang_major__ < 11) carry = __builtin_ia32_addcarry_u64(carry, x[i], y[i], (unsigned long long*)&res[i]); #else carry = __builtin_ia32_addcarryx_u64(carry, x[i], y[i], (unsigned long long*)&res[i]); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 9e853ee8..5edd4950 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -97,12 +97,20 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv, 1, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "This player's number", // Help description. + "This player's number (required)", // Help description. "-p", // Flag token. "--player" // Flag token. ); opt.parse(argc, argv); opt.get("-p")->getInt(player_no); + vector missing; + if (not opt.gotRequired(missing)) + { + string usage; + opt.getUsage(usage); + cerr << usage; + exit(1); + } global_server = network_opts.start_networking(*this, player_no); } @@ -123,7 +131,7 @@ void Names::setup_names(const char *servername, int my_port) set_up_client_socket(socket_num, servername, pn); send(socket_num, (octet*)&player_no, sizeof(player_no)); #ifdef DEBUG_NETWORKING - fprintf(stderr, "Sent %d to %s:%d\n", player_no, servername, pn); + cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; #endif int inst=-1; // wait until instruction to start. @@ -338,6 +346,14 @@ void MultiPlayer::send_all(const octetStream& o,bool donthash) const } +void Player::receive_all(vector& os) const +{ + for (int j = 0; j < num_players(); j++) + if (j != my_num()) + receive_player(j, os[j], true); +} + + void Player::receive_player(int i,octetStream& o,bool donthash) const { #ifdef VERBOSE_COMM @@ -345,6 +361,7 @@ void Player::receive_player(int i,octetStream& o,bool donthash) const #endif TimeScope ts(timer); receive_player_no_stats(i, o); + comm_stats["Receiving directly"].add(o, ts); if (!donthash) { blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); } } @@ -627,12 +644,14 @@ void RealTwoPartyPlayer::receive(octetStream& o) const TimeScope ts(timer); o.reset_write_head(); o.Receive(socket); + comm_stats["Receiving one-to-one"].add(o, ts); } void VirtualTwoPartyPlayer::receive(octetStream& o) const { TimeScope ts(timer); P.receive_player_no_stats(other_player, o); + comm_stats["Receiving one-to-one"].add(o, ts); } void RealTwoPartyPlayer::send_receive_player(vector& o) const @@ -688,6 +707,8 @@ void TwoPartyPlayer::Broadcast_Receive(vector& o, CommStats& CommStats::operator +=(const CommStats& other) { data += other.data; + rounds += other.rounds; + timer += other.timer; return *this; } @@ -698,6 +719,13 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other) return *this; } +NamedCommStats NamedCommStats::operator +(const NamedCommStats& other) const +{ + auto res = *this; + res += other; + return res; +} + CommStats& CommStats::operator -=(const CommStats& other) { data -= other.data; @@ -722,13 +750,15 @@ size_t NamedCommStats::total_data() return res; } -void NamedCommStats::print() +void NamedCommStats::print(bool newline) { for (auto it = begin(); it != end(); it++) if (it->second.data) cerr << it->first << " " << 1e-6 * it->second.data << " MB in " << it->second.rounds << " rounds, taking " << it->second.timer.elapsed() << " seconds" << endl; + if (size() and newline) + cerr << endl; } template class MultiPlayer; diff --git a/Networking/Player.h b/Networking/Player.h index d07fd92c..88e3293f 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -91,6 +91,7 @@ struct CommStats Timer timer; CommStats() : data(0), rounds(0) {} Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; } + void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } CommStats& operator+=(const CommStats& other); CommStats& operator-=(const CommStats& other); }; @@ -99,9 +100,10 @@ class NamedCommStats : public map { public: NamedCommStats& operator+=(const NamedCommStats& other); + NamedCommStats operator+(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const; size_t total_data(); - void print(); + void print(bool newline = false); #ifdef VERBOSE_COMM CommStats& operator[](const string& name) { @@ -160,6 +162,7 @@ public: virtual void send_all(const octetStream& o,bool donthash=false) const = 0; void send_to(int player,const octetStream& o,bool donthash=false) const; virtual void send_to_no_stats(int player,const octetStream& o) const = 0; + void receive_all(vector& os) const; void receive_player(int i,octetStream& o,bool donthash=false) const; virtual void receive_player_no_stats(int i,octetStream& o) const = 0; virtual void receive_player(int i,FlexBuffer& buffer) const; diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index 87cc845f..8adbb098 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -83,84 +83,40 @@ ServerSocket::~ServerSocket() void ServerSocket::accept_clients() { - map unassigned_sockets; - while (true) { - fd_set readfds; - FD_ZERO(&readfds); - int nfds = main_socket; - FD_SET(main_socket, &readfds); - for (auto &socket : unassigned_sockets) - { - FD_SET(socket.first, &readfds); - nfds = max(socket.first, nfds); - } + struct sockaddr dest; + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + int socksize = sizeof(dest); + int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + if (consocket<0) { error("set_up_socket:accept"); } - select(nfds + 1, &readfds, 0, 0, 0); - - if (FD_ISSET(main_socket, &readfds)) - { - struct sockaddr dest; - memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ - int socksize = sizeof(dest); - int consocket = accept(main_socket, (struct sockaddr*) &dest, - (socklen_t*) &socksize); - if (consocket < 0) - error("set_up_socket:accept"); - unassigned_sockets[consocket] = dest; + int client_id; + try + { + receive(consocket, (unsigned char*)&client_id, sizeof(client_id)); + } + catch (closed_connection&) + { #ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &dest; - fprintf(stderr, "new client on %s:%d\n", inet_ntoa(conn.sin_addr), - ntohs(conn.sin_port)); + auto& conn = *(sockaddr_in*)&dest; + cerr << "client on " << inet_ntoa(conn.sin_addr) << ":" + << ntohs(conn.sin_port) << " left without identification" + << endl; #endif - } + } - vector processed_sockets; - for (auto &socket : unassigned_sockets) - { - int consocket = socket.first; - if (FD_ISSET(consocket, &readfds)) - { - try - { - int client_id; - receive(consocket, (unsigned char*) &client_id, - sizeof(client_id)); - - data_signal.lock(); - clients[client_id] = consocket; - data_signal.broadcast(); - data_signal.unlock(); - -#ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &socket.second; - fprintf(stderr, "client id %d on %s:%d\n", client_id, - inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); -#endif + data_signal.lock(); + clients[client_id] = consocket; + data_signal.broadcast(); + data_signal.unlock(); #ifdef __APPLE__ - int flags = fcntl(consocket, F_GETFL, 0); - int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags); - if (fl < 0) - error("set non-blocking"); + int flags = fcntl(consocket, F_GETFL, 0); + int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags); + if (fl < 0) + error("set non-blocking"); #endif - } - catch (closed_connection&) - { -#ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &socket.second; - cerr << "client on " << inet_ntoa(conn.sin_addr) << ":" - << ntohs(conn.sin_port) << " left without identification" - << endl; -#endif - close_client_socket(consocket); - } - processed_sockets.push_back(consocket); - } - } - for (int socket : processed_sockets) - unassigned_sockets.erase(socket); } } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index e93f35cb..99c5a3a6 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -116,7 +116,7 @@ void BaseOT::exec_base(bool new_receiver_inputs) { if (new_receiver_inputs) receiver_inputs[i + j] = G.get_uchar()&1; - cs[j] = receiver_inputs[i + j]; + cs[j] = receiver_inputs[i + j].get(); } receiver_rsgen(&receiver, Rs_pack[0], cs); os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0])); @@ -293,7 +293,7 @@ void FakeOT::exec_base(bool new_receiver_inputs) { for (int j = 0; j < 2; j++) bv[j].unpack(os[1]); - receiver_outputs[i] = bv[receiver_inputs[i]]; + receiver_outputs[i] = bv[receiver_inputs[i].get()]; } set_seeds(); diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 54be8a1d..1b314d38 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -30,7 +30,7 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol class BaseOT { public: - vector receiver_inputs; + BitVector receiver_inputs; vector< vector > sender_inputs; vector receiver_outputs; TwoPartyPlayer* P; @@ -63,7 +63,7 @@ public: int length() { return ot_length; } - void set_receiver_inputs(const vector& new_inputs) + void set_receiver_inputs(const BitVector& new_inputs) { if ((int)new_inputs.size() != nOT) throw invalid_length(); @@ -72,7 +72,7 @@ public: void set_receiver_inputs(int128 inputs) { - vector new_inputs(128); + BitVector new_inputs(128); for (int i = 0; i < 128; i++) new_inputs[i] = (inputs >> i).get_lower() & 1; set_receiver_inputs(new_inputs); @@ -81,6 +81,7 @@ public: // do the OTs -- generate fresh random choice bits by default virtual void exec_base(bool new_receiver_inputs=true); // use PRG to get the next ot_length bits + void set_seeds(); void extend_length(); void check(); @@ -90,8 +91,6 @@ protected: bool is_sender() { return (bool) (ot_role & SENDER); } bool is_receiver() { return (bool) (ot_role & RECEIVER); } - - void set_seeds(); }; class FakeOT : public BaseOT diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index bf3c32fa..95687bb0 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -702,4 +702,5 @@ BMS XXXX(Matrix, gf2n_short) XXXX(Matrix>, gf2n_long) XXXX(Matrix>, gfp1) +XXXX(Matrix>, gfp3) XXXX(Matrix, BitVec) diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 3b189b2e..9ddfde0c 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -142,7 +142,7 @@ public: BitMatrix() {} BitMatrix(int length); - __m128i operator[](int i) { return squares[i / 128].rows[i % 128]; } + __m128i& operator[](int i) { return squares[i / 128].rows[i % 128]; } void resize(int length); int size(); diff --git a/OT/MascotParams.cpp b/OT/MascotParams.cpp index c6375fbe..868ea874 100644 --- a/OT/MascotParams.cpp +++ b/OT/MascotParams.cpp @@ -26,81 +26,15 @@ MascotParams::MascotParams() generateMACs = true; amplify = true; check = true; + correlation_check = true; generateBits = false; + use_extension = true; + fewer_rounds = false; + fiat_shamir = false; timerclear(&start); } void MascotParams::set_passive() { - generateMACs = amplify = check = false; -} - -template<> gf2n_long MascotParams::get_mac_key() -{ - return mac_key2l; -} - -template<> gf2n_short MascotParams::get_mac_key() -{ - return mac_key2s; -} - -template<> gfp1 MascotParams::get_mac_key() -{ - return mac_keyp; -} - -template<> Z2<48> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<64> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<40> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<32> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> BitVec MascotParams::get_mac_key() -{ - return 0; -} - -template<> void MascotParams::set_mac_key(gf2n_long key) -{ - mac_key2l = key; -} - -template<> void MascotParams::set_mac_key(gf2n_short key) -{ - mac_key2s = key; -} - -template<> void MascotParams::set_mac_key(gfp1 key) -{ - mac_keyp = key; -} - -template<> void MascotParams::set_mac_key(Z2<64> key) -{ - mac_keyz = key; -} - -template<> void MascotParams::set_mac_key(Z2<48> key) -{ - mac_keyz = key; -} - -template<> void MascotParams::set_mac_key(Z2<40> key) -{ - mac_keyz = key; + generateMACs = amplify = check = correlation_check = false; } diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index b0e7f6c2..48411264 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -68,6 +68,8 @@ protected: SeededPRNG share_prg; + mac_key_type mac_key; + void start_progress(); void print_progress(int k); @@ -101,8 +103,11 @@ public: vector> preampTriples; vector> plainTriples; - OTTripleGenerator(OTTripleSetup& setup, const Names& names, + typename T::MAC_Check* MC; + + OTTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); ~OTTripleGenerator(); @@ -113,7 +118,10 @@ public: void run_multipliers(MultJob job); + mac_key_type get_mac_key() const { return mac_key; } + size_t data_sent(); + NamedCommStats comm_stats(); }; template @@ -130,8 +138,9 @@ public: vector< ShareTriple_ > uncheckedTriples; vector>> inputs; - NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, + NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); virtual ~NPartyTripleGenerator() {} @@ -159,8 +168,9 @@ class MascotTripleGenerator : public NPartyTripleGenerator public: vector bits; - MascotTripleGenerator(OTTripleSetup& setup, const Names& names, + MascotTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); }; @@ -181,8 +191,9 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator U& MC, PRNG& G); public: - Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names, + Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); void generateTriples(); @@ -199,4 +210,15 @@ size_t OTTripleGenerator::data_sent() return res; } +template +NamedCommStats OTTripleGenerator::comm_stats() +{ + NamedCommStats res; + if (parentPlayer != &globalPlayer) + res = globalPlayer.comm_stats; + for (auto& player : players) + res += player->comm_stats; + return res; +} + #endif diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index e92ccc94..9460106d 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -5,27 +5,14 @@ #include "OT/OTExtensionWithMatrix.h" #include "OT/OTMultiplier.h" -#include "Math/gfp.h" -#include "Protocols/Share.h" -#include "Protocols/SemiShare.h" -#include "Protocols/Semi2kShare.h" -#include "Protocols/Spdz2kShare.h" #include "Math/operators.h" #include "Tools/Subroutines.h" #include "Protocols/MAC_Check.h" -#include "Protocols/Spdz2kPrep.h" -#include "GC/SemiSecret.h" #include "OT/Triple.hpp" -#include "OT/Rectangle.hpp" #include "OT/OTMultiplier.hpp" #include "Protocols/MAC_Check.hpp" -#include "Protocols/SemiMC.h" -#include "Protocols/MascotPrep.hpp" -#include "Protocols/ReplicatedInput.hpp" #include "Protocols/SemiInput.hpp" -#include "Processor/Input.hpp" -#include "Math/Z2k.hpp" #include #include @@ -43,44 +30,46 @@ void* run_ot_thread(void* ptr) * N.B. setup must not be stored as it will be used by other threads */ template -NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, +NPartyTripleGenerator::NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : OTTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -MascotTripleGenerator::MascotTripleGenerator(OTTripleSetup& setup, +MascotTripleGenerator::MascotTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -Spdz2kTripleGenerator::Spdz2kTripleGenerator(OTTripleSetup& setup, +Spdz2kTripleGenerator::Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -OTTripleGenerator::OTTripleGenerator(OTTripleSetup& setup, +OTTripleGenerator::OTTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names, - thread_num * names.num_players() * names.num_players())), parentPlayer(parentPlayer), thread_num(thread_num), + mac_key(mac_key), my_num(setup.get_my_num()), nloops(nloops), nparties(setup.get_nparties()), - machine(machine) + machine(machine), + MC(0) { nTriplesPerLoop = DIV_CEIL(_nTriples, nloops); nTriples = nTriplesPerLoop * nloops; @@ -208,7 +197,6 @@ void NPartyTripleGenerator::generateInputs(int player) { typedef open_type T; - auto& machine = this->machine; auto& nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; @@ -235,7 +223,7 @@ void NPartyTripleGenerator::generateInputs(int player) GlobalPRNG G(globalPlayer); Share check_sum; inputs.resize(toCheck); - auto mac_key = machine.template get_mac_key(); + auto mac_key = this->get_mac_key(); SemiInput> input(0, globalPlayer); input.reset_all(globalPlayer); vector secrets(toCheck); @@ -289,7 +277,7 @@ void MascotTripleGenerator::generateBitsGf2n() bits.resize(nBitsToCheck); vector to_open(1); vector opened(1); - MAC_Check_ MC(this->machine.template get_mac_key()); + MAC_Check_ MC(this->get_mac_key()); this->start_progress(); @@ -313,7 +301,7 @@ void MascotTripleGenerator::generateBitsGf2n() typename T::clear r; for (int j = 0; j < nBitsToCheck; j++) { - auto mac_sum = valueBits[0].get_bit(j) ? MC.get_alphai() : 0; + auto mac_sum = valueBits[0].get_bit(j) ? this->get_mac_key() : 0; for (int i = 0; i < this->nparties-1; i++) mac_sum += this->ot_multipliers[i]->macs[0][j]; bits[j].set_share(valueBits[0].get_bit(j)); @@ -352,6 +340,13 @@ void MascotTripleGenerator>::generateBits() generateTriples(); } +template<> +inline +void MascotTripleGenerator>::generateBits() +{ + generateTriples(); +} + template void Spdz2kTripleGenerator::generateTriples() { @@ -360,7 +355,6 @@ void Spdz2kTripleGenerator::generateTriples() auto& uncheckedTriples = this->uncheckedTriples; auto& timers = this->timers; - auto& machine = this->machine; auto& nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; @@ -382,7 +376,7 @@ void Spdz2kTripleGenerator::generateTriples() vector< PlainTriple_, Z2, 2> > amplifiedTriples(nTriplesPerLoop); uncheckedTriples.resize(nTriplesPerLoop); MAC_Check_Z2k, Z2, Z2, Share> > MC( - machine.template get_mac_key >()); + this->get_mac_key()); this->start_progress(); @@ -455,7 +449,7 @@ void Spdz2kTripleGenerator::generateTriples() // get piggy-backed random value Z2 r_share = b_padded_bits.get_ptr_to_byte(nTriplesPerLoop, Z2::N_BYTES); Z2 r_mac; - r_mac.mul(r_share, this->machine.template get_mac_key>()); + r_mac.mul(r_share, this->get_mac_key()); for (int i = 0; i < this->nparties-1; i++) r_mac += (ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop); Share> r; @@ -563,16 +557,17 @@ void MascotTripleGenerator::generateTriples() valueBits[2*i].resize(field_size * nPreampTriplesPerLoop); valueBits[1].resize(field_size * nTriplesPerLoop); vector< PlainTriple > amplifiedTriples; - MAC_Check MC(machine.template get_mac_key()); + MAC_Check MC(this->get_mac_key()); if (machine.amplify) preampTriples.resize(nTriplesPerLoop); if (machine.generateMACs) { amplifiedTriples.resize(nTriplesPerLoop); - uncheckedTriples.resize(nTriplesPerLoop); } + uncheckedTriples.resize(nTriplesPerLoop); + this->start_progress(); for (int k = 0; k < nloops; k++) @@ -581,10 +576,15 @@ void MascotTripleGenerator::generateTriples() if (machine.amplify) { - octet seed[SEED_SIZE]; - Create_Random_Seed(seed, globalPlayer, SEED_SIZE); PRNG G; - G.SetSeed(seed); + if (machine.fiat_shamir and nparties == 2) + ot_multipliers[0]->otCorrelator.common_seed(G); + else + { + octet seed[SEED_SIZE]; + Create_Random_Seed(seed, globalPlayer, SEED_SIZE); + G.SetSeed(seed); + } for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) { PlainTriple triple; @@ -598,12 +598,16 @@ void MascotTripleGenerator::generateTriples() triple.output(outputFile); timers["Writing"].stop(); } + else + for (int i = 0; i < 3; i++) + uncheckedTriples[iTriple].byIndex(i, 0).set_share(triple.byIndex(i, 0)); } if (machine.generateMACs) { for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) - amplifiedTriples[iTriple].to(valueBits, iTriple); + amplifiedTriples[iTriple].to(valueBits, iTriple, + machine.check ? 2 : 1); for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push({}); @@ -625,7 +629,7 @@ void MascotTripleGenerator::generateTriples() if (machine.check) { - sacrifice(uncheckedTriples, MC, G); + sacrifice(uncheckedTriples, this->MC ? *this->MC : MC, G); } } } diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index e87d919c..f3634adf 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -259,7 +259,7 @@ void naive_transpose64(vector& output, const vector& input } -OTExtension::OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, +OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive) : player(player) { nbaseOTs = baseOT.nOT; diff --git a/OT/OTExtension.h b/OT/OTExtension.h index 05067e0f..df53eca4 100644 --- a/OT/OTExtension.h +++ b/OT/OTExtension.h @@ -30,7 +30,7 @@ public: vector receiverOutput; map times; - OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, bool passive); + OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive); OTExtension(int nbaseOTs, int baseLength, int nloops, int nsubloops, diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 1d6b5bcd..39be81ce 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -308,7 +308,8 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs) } template -void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput) +void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, + V& receiverOutput, bool correlated) { //cout << "Hashing... " << flush; octetStream os, h_os(HASH_SIZE); @@ -341,7 +342,11 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r for (int j = 0; j < 8; j++) { tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j]; - tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0); + if (correlated) + tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0); + else + tmp[1][j] = + senderOutputMatrices[1].squares[i_outer_input].rows[i_inner_input + j]; } for (int j = 0; j < 2; j++) mmo.hashBlocks( @@ -366,17 +371,39 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r template template -void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output) +void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output, int start) { - if (receiverOutputMatrix.squares.size() < nTriples) + if (receiverOutputMatrix.squares.size() < nTriples + start) throw invalid_length(); output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { - receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]); + receiverOutputMatrix.squares[j + start].sub( + senderOutputMatrices[0].squares[j + start]).to(output[j]); } } +template +void OTCorrelator::common_seed(PRNG& G) +{ + Slice t1Slice(t1, 0, t1.squares.size()); + Slice uSlice(u, 0, u.squares.size()); + + octetStream os; + if (player->my_num()) + { + t1Slice.pack(os); + uSlice.pack(os); + } + else + { + uSlice.pack(os); + t1Slice.pack(os); + } + auto hash = os.hash(); + G = PRNG(hash); +} + octet* OTExtensionWithMatrix::get_receiver_output(int i) { return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128]; @@ -515,34 +542,35 @@ template class OTCorrelator; #define Z(BM,GF) \ template class OTCorrelator; \ template void OTCorrelator::reduce_squares(unsigned int nTriples, \ - vector& output); + vector& output, int); #define ZZZZ(GF) \ template void OTExtensionWithMatrix::print_post_correlate( \ BitVector& newReceiverInput, int j, int offset, int sender); \ #define ZZZ(GF, M) Z(M, GF) \ -template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&); +template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&, bool); ZZZZ(gf2n_long) ZZZ(gf2n_short, Matrix) ZZZ(gf2n_long, Matrix>) ZZZ(gfp1, Matrix>) +ZZZ(gfp3, Matrix>) ZZZ(BitVec, Matrix) #undef XX #define XX(T,U,N,L) \ template class OTCorrelator, Z2 > > >; \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector& output); \ + vector& output, int); \ template void OTExtensionWithMatrix::hash_outputs(int, \ std::vector, Z2 > >, std::allocator, Z2 > > > >&, \ - Matrix, Z2 > >&); + Matrix, Z2 > >&, bool); #undef X #define X(N,L) \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ + vector >& output, int); \ XX(Z2,Z2,N,L) //X(96, 160) diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index 29af01f3..72007fa1 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -45,7 +45,9 @@ public: U& baseReceiverOutput); void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); template - void reduce_squares(unsigned int nTriples, vector& output); + void reduce_squares(unsigned int nTriples, vector& output, + int start = 0); + void common_seed(PRNG& G); }; class OTExtensionWithMatrix : public OTCorrelator @@ -80,7 +82,8 @@ public: void transpose(int start, int slice); void expand_transposed(); template - void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput); + void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput, + bool correlated = true); void print(BitVector& newReceiverInput, int i = 0); template diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index f40da615..f5ba18d9 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -7,18 +7,11 @@ #include "OT/OTMultiplier.h" #include "OT/NPartyTripleGenerator.h" -#include "OT/Rectangle.h" -#include "Math/Z2k.h" -#include "Math/BitVec.h" -#include "Protocols/SemiShare.h" -#include "Protocols/Semi2kShare.h" -#include "Protocols/Spdz2kShare.h" +#include "OT/BaseOT.h" #include "OT/OTVole.hpp" #include "OT/Row.hpp" #include "OT/Rectangle.hpp" -#include "Math/Z2k.hpp" -#include "Math/Square.hpp" #include @@ -31,7 +24,8 @@ OTMultiplier::OTMultiplier(OTTripleGenerator& generator, rot_ext(128, 128, 0, 1, generator.players[thread_num], generator.baseReceiverInput, generator.baseSenderInputs[thread_num], - generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check), + generator.baseReceiverOutputs[thread_num], BOTH, + !generator.machine.correlation_check), otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true) { this->thread = 0; @@ -89,7 +83,7 @@ OTMultiplier::~OTMultiplier() template void OTMultiplier::multiply() { - keyBits.set(generator.machine.template get_mac_key()); + keyBits.set(generator.get_mac_key()); rot_ext.extend(keyBits.size(), keyBits); this->outbox.push({}); senderOutput.resize(keyBits.size()); @@ -140,11 +134,6 @@ void OTMultiplier::multiplyForTriples() { typedef typename W::Rectangle X; - // dummy input for OT correlator - vector _; - vector< vector > __; - BitVector ___; - otCorrelator.resize(X::N_COLUMNS * generator.nPreampTriplesPerLoop); rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128); @@ -161,8 +150,26 @@ void OTMultiplier::multiplyForTriples() this->inbox.pop(job); BitVector aBits = generator.valueBits[0]; //timers["Extension"].start(); - rot_ext.extend_correlated(aBits); - rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + if (generator.machine.use_extension) + { + rot_ext.extend_correlated(aBits); + } + else + { + BaseOT bot(aBits.size(), -1, generator.players[thread_num]); + bot.set_receiver_inputs(aBits); + bot.exec_base(false); + for (size_t i = 0; i < aBits.size(); i++) + { + rot_ext.receiverOutputMatrix[i] = + bot.receiver_outputs[i].get_int128(0).a; + for (int j = 0; j < 2; j++) + rot_ext.senderOutputMatrices[j][i] = + bot.sender_inputs[i][j].get_int128(0).a; + } + } + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, + baseReceiverOutput, generator.machine.use_extension); //timers["Extension"].stop(); //timers["Correlation"].start(); @@ -215,8 +222,6 @@ void MascotMultiplier::after_correlation() { typedef typename U::open_type T; - this->auth_ot_ext.resize( - this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS); this->auth_ot_ext.set_role(BOTH); this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop, @@ -229,15 +234,45 @@ void MascotMultiplier::after_correlation() this->macs.resize(3); MultJob job; this->inbox.pop(job); + auto& generator = this->generator; + array n_vals; for (int j = 0; j < 3; j++) { - int nValues = this->generator.nTriplesPerLoop; + n_vals[j] = generator.nTriplesPerLoop; if (this->generator.machine.check && (j % 2 == 0)) - nValues *= 2; - this->auth_ot_ext.expand(0, nValues); - this->auth_ot_ext.correlate(0, nValues, - this->generator.valueBits[j], true); - this->auth_ot_ext.reduce_squares(nValues, this->macs[j]); + n_vals[j] *= 2; + } + if (generator.machine.fewer_rounds) + { + BitVector bits; + int total = 0; + for (int j = 0; j < 3; j++) + { + bits.append(generator.valueBits[j], + n_vals[j] * T::Square::N_COLUMNS); + total += n_vals[j]; + } + this->auth_ot_ext.resize(bits.size()); + this->auth_ot_ext.expand(0, total); + this->auth_ot_ext.correlate(0, total, bits, true); + total = 0; + for (int j = 0; j < 3; j++) + { + this->auth_ot_ext.reduce_squares(n_vals[j], this->macs[j], total); + total += n_vals[j]; + } + } + else + { + this->auth_ot_ext.resize(n_vals[0] * T::Square::N_COLUMNS); + for (int j = 0; j < 3; j++) + { + int nValues = n_vals[j]; + this->auth_ot_ext.expand(0, nValues); + this->auth_ot_ext.correlate(0, nValues, + this->generator.valueBits[j], true); + this->auth_ot_ext.reduce_squares(nValues, this->macs[j]); + } } this->outbox.push(job); } diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index bc0a01ab..7cabdf5a 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -41,3 +41,22 @@ void OTTripleSetup::close_connections() delete players[i]; } } + +OTTripleSetup OTTripleSetup::get_fresh() +{ + OTTripleSetup res = *this; + for (int i = 0; i < nparties - 1; i++) + { + BaseOT bot(nbase, 128, 0); + bot.sender_inputs = baseSenderInputs[i]; + bot.receiver_outputs = baseReceiverOutputs[i]; + bot.set_seeds(); + bot.extend_length(); + baseSenderInputs[i] = bot.sender_inputs; + baseReceiverOutputs[i] = bot.receiver_outputs; + bot.extend_length(); + res.baseSenderInputs[i] = bot.sender_inputs; + res.baseReceiverOutputs[i] = bot.receiver_outputs; + } + return res; +} diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 52ae7e07..a30b72bd 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -11,7 +11,7 @@ */ class OTTripleSetup { - vector base_receiver_inputs; + BitVector base_receiver_inputs; vector baseOTs; PRNG G; @@ -25,10 +25,10 @@ public: vector< vector< vector > > baseSenderInputs; vector< vector > baseReceiverOutputs; - int get_nparties() { return nparties; } - int get_nbase() { return nbase; } - int get_my_num() { return my_num; } - int get_base_receiver_input(int i) { return base_receiver_inputs[i]; } + int get_nparties() const { return nparties; } + int get_nbase() const { return nbase; } + int get_my_num() const { return my_num; } + int get_base_receiver_input(int i) const { return base_receiver_inputs[i]; } OTTripleSetup(Player& N, bool real_OTs) : nparties(N.num_players()), my_num(N.my_num()), nbase(128) @@ -78,6 +78,8 @@ public: //template //T get_mac_key(); + + OTTripleSetup get_fresh(); }; diff --git a/OT/Triple.hpp b/OT/Triple.hpp index 842f13b1..a7fd99c7 100644 --- a/OT/Triple.hpp +++ b/OT/Triple.hpp @@ -16,13 +16,16 @@ public: T b; T c[N]; - int repeat(int l) + int repeat(int l, bool check) { switch (l) { case 0: case 2: - return N; + if (check) + return N; + else + return 1; case 1: return 1; default: @@ -75,12 +78,12 @@ class PlainTriple : public Triple { public: // this assumes that valueBits[1] is still set to the bits of b - void to(vector& valueBits, int i) + void to(vector& valueBits, int i, int repeat = N) { for (int j = 0; j < N; j++) { - valueBits[0].set_portion(i * N + j, this->a[j]); - valueBits[2].set_portion(i * N + j, this->c[j]); + valueBits[0].set_portion(i * repeat + j, this->a[j]); + valueBits[2].set_portion(i * repeat + j, this->c[j]); } } }; @@ -123,12 +126,12 @@ public: { for (int l = 0; l < 3; l++) { - int repeat = this->repeat(l); + int repeat = this->repeat(l, generator.machine.check); for (int j = 0; j < repeat; j++) { T value = triple.byIndex(l,j); T mac; - mac.mul(value, generator.machine.template get_mac_key()); + mac.mul(value, generator.get_mac_key()); for (int i = 0; i < generator.nparties-1; i++) mac += generator.ot_multipliers[i]->macs.at(l).at(iTriple * repeat + j); Share& share = this->byIndex(l,j); diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index e49906d9..0dd9dde5 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -16,28 +16,21 @@ class GeneratorThread; class MascotParams : virtual public OfflineParams { -protected: - gf2n_short mac_key2s; - gf2n_long mac_key2l; - gfp1 mac_keyp; - Z2<128> mac_keyz; - public: string prep_data_dir; bool generateMACs; bool amplify; bool check; + bool correlation_check; bool generateBits; + bool use_extension; + bool fewer_rounds; + bool fiat_shamir; struct timeval start, stop; MascotParams(); void set_passive(); - - template - T get_mac_key(); - template - void set_mac_key(T key); }; class TripleMachine : public OfflineMachineBase, public MascotParams @@ -45,6 +38,10 @@ class TripleMachine : public OfflineMachineBase, public MascotParams Names N[2]; int nConnections; + gf2n mac_key2; + gfp1 mac_keyp; + Z2<128> mac_keyz; + public: int nloops; bool primeField; @@ -54,7 +51,8 @@ public: TripleMachine(int argc, const char** argv); template - GeneratorThread* new_generator(OTTripleSetup& setup, int i); + GeneratorThread* new_generator(OTTripleSetup& setup, int i, + typename T::mac_key_type mac_key); void run(); diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index a5e227d6..5adebec5 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -30,7 +30,7 @@ public: string progname; int nthreads; - vector> ot_setups; + vector ot_setups; static BaseMachine& s(); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 663dfecf..e5671d26 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -94,6 +94,7 @@ public: virtual void purge() {} virtual size_t data_sent() { return 0; } + virtual NamedCommStats comm_stats() { return {}; } virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0; virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0; @@ -112,6 +113,7 @@ public: virtual array get_triple(int n_bits); virtual void buffer_triples() {} + virtual void buffer_inverses() {} }; template diff --git a/Processor/Input.hpp b/Processor/Input.hpp index f2a22350..8bb376d3 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -13,6 +13,8 @@ #include "FixInput.h" #include "FloatInput.h" +#include "IntInput.hpp" + template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) @@ -295,7 +297,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, cout << "Please input " << U::NAME << "s:" << endl; \ prepare(Proc, player, &args[i + U::N_DEST + 1], size); \ break; - X(IntInput) X(FixInput) X(FloatInput) + X(IntInput) X(FixInput) X(FloatInput) #undef X default: throw runtime_error("unknown input type: " + to_string(type)); @@ -317,7 +319,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ finalize(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \ break; - X(IntInput) X(FixInput) X(FloatInput) + X(IntInput) X(FixInput) X(FloatInput) #undef X default: throw runtime_error("unknown input type: " + to_string(type)); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index f4772da2..6d22f2c7 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -61,6 +61,9 @@ enum USE_PREP = 0x1C, STARTGRIND = 0x1D, STOPGRIND = 0x1E, + NPLAYERS = 0xE2, + THRESHOLD = 0xE3, + PLAYERID = 0xE4, // Addition ADDC = 0x20, ADDS = 0x21, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 8fd8fdbb..a7d9a36f 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -177,6 +177,9 @@ void BaseInstruction::parse_operands(istream& s, int pos) case PRINTCHRINT: case PRINTSTRINT: case PRINTINT: + case NPLAYERS: + case THRESHOLD: + case PLAYERID: r[0]=get_int(s); break; // instructions with 3 registers + 1 integer operand @@ -442,6 +445,9 @@ int BaseInstruction::get_reg_type() const case CONVMODP: case GCONVGF2N: case RAND: + case NPLAYERS: + case THRESHOLD: + case PLAYERID: return INT; case PREP: case USE_PREP: @@ -1046,10 +1052,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.temp.ans2.output(Proc.private_output, false); break; case INPUT: - sint::Input::template input(Proc.Procp, start, size); + sint::Input::template input>(Proc.Procp, start, size); return; case GINPUT: - sgf2n::Input::template input(Proc.Proc2, start, size); + sgf2n::Input::template input>(Proc.Proc2, start, size); return; case INPUTFIX: sint::Input::template input(Proc.Procp, start, size); @@ -1404,6 +1410,15 @@ inline void Instruction::execute(Processor& Proc) const case STOPGRIND: CALLGRIND_STOP_INSTRUMENTATION; break; + case NPLAYERS: + Proc.write_Ci(r[0], Proc.P.num_players()); + break; + case THRESHOLD: + Proc.write_Ci(r[0], sint::threshold(Proc.P.num_players())); + break; + case PLAYERID: + Proc.write_Ci(r[0], Proc.P.my_num()); + break; // *** // TODO: read/write shared GF(2^n) data instructions // *** diff --git a/Processor/IntInput.cpp b/Processor/IntInput.cpp deleted file mode 100644 index 959745da..00000000 --- a/Processor/IntInput.cpp +++ /dev/null @@ -1,14 +0,0 @@ -/* - * IntInput.cpp - * - */ - -#include "IntInput.h" - -const char* IntInput::NAME = "integer"; - -void IntInput::read(std::istream& in, const int* params) -{ - (void) params; - in >> items[0]; -} diff --git a/Processor/IntInput.h b/Processor/IntInput.h index 2881e0c3..c550cf45 100644 --- a/Processor/IntInput.h +++ b/Processor/IntInput.h @@ -8,6 +8,7 @@ #include +template class IntInput { public: @@ -17,7 +18,7 @@ public: const static int TYPE = 0; - long items[N_DEST]; + T items[N_DEST]; void read(std::istream& in, const int* params); }; diff --git a/Processor/IntInput.hpp b/Processor/IntInput.hpp new file mode 100644 index 00000000..97bc7c0b --- /dev/null +++ b/Processor/IntInput.hpp @@ -0,0 +1,15 @@ +/* + * IntInput.cpp + * + */ + +#include "IntInput.h" + +template +const char* IntInput::NAME = "integer"; + +template +void IntInput::read(std::istream& in, const int*) +{ + in >> items[0]; +} diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index cbcfa941..bc4eef06 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -127,10 +127,8 @@ Machine::Machine(int my_number, Names& playerNames, P = new CryptoPlayer(playerNames, 0xF000); else P = new PlainPlayer(playerNames, 0xF000); - ot_setups.resize(nthreads); for (int i = 0; i < nthreads; i++) - for (int j = 0; j < 3; j++) - ot_setups.at(i).push_back({ *P, true }); + ot_setups.push_back({ *P, true }); delete P; } diff --git a/Processor/Processor.h b/Processor/Processor.h index 4baa9281..c3217c09 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -52,8 +52,6 @@ public: Preprocessing& DataF, Player& P); // Access to PO (via calls to POpen start/stop) - void POpen_Start(const vector& reg,const Player& P,int size); - void POpen_Stop(const vector& reg,const Player& P,int size); void POpen(const vector& reg,const Player& P,int size); void muls(const vector& reg, int size); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 0a96bd74..ba504a57 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -6,6 +6,7 @@ #include "Protocols/ReplicatedInput.hpp" #include "Protocols/ReplicatedPrivateOutput.hpp" +#include "Processor/ProcessorBase.hpp" #include #include @@ -406,15 +407,16 @@ void Processor::write_shares_to_file(const vector& data_regist } template -void SubProcessor::POpen_Start(const vector& reg,const Player& P,int size) +void SubProcessor::POpen(const vector& reg,const Player& P,int size) { - int sz=reg.size(); + assert(reg.size() % 2 == 0); + int sz=reg.size() / 2; Sh_PO.clear(); Sh_PO.reserve(sz*size); if (size>1) { - for (typename vector::const_iterator reg_it=reg.begin(); - reg_it!=reg.end(); reg_it++) + for (typename vector::const_iterator reg_it=reg.begin() + 1; + reg_it < reg.end(); reg_it += 2) { auto begin=S.begin()+*reg_it; Sh_PO.insert(Sh_PO.end(),begin,begin+size); @@ -423,24 +425,15 @@ void SubProcessor::POpen_Start(const vector& reg,const Player& P,int siz else { for (int i=0; i -void SubProcessor::POpen_Stop(const vector& reg,const Player& P,int size) -{ - int sz=reg.size(); - PO.resize(sz*size); - MC.POpen_End(PO,Sh_PO,P); + MC.POpen(PO,Sh_PO,P); if (size>1) { auto PO_it=PO.begin(); for (typename vector::const_iterator reg_it=reg.begin(); - reg_it!=reg.end(); reg_it++) + reg_it!=reg.end(); reg_it += 2) { for (auto C_it=C.begin()+*reg_it; C_it!=C.begin()+*reg_it+size; C_it++) @@ -452,36 +445,16 @@ void SubProcessor::POpen_Stop(const vector& reg,const Player& P,int size } else { - for (unsigned int i=0; i& dest, vector& source, const vector& reg) -{ - int n = reg.size() / 2; - source.resize(n); - dest.resize(n); - for (int i = 0; i < n; i++) - { - source[i] = reg[2 * i + 1]; - dest[i] = reg[2 * i]; - } -} - -template -void SubProcessor::POpen(const vector& reg, const Player& P, - int size) -{ - vector source, dest; - unzip_open(dest, source, reg); - POpen_Start(source, P, size); - POpen_Stop(dest, P, size); -} - template void SubProcessor::muls(const vector& reg, int size) { diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.hpp similarity index 87% rename from Processor/ProcessorBase.cpp rename to Processor/ProcessorBase.hpp index 03f40813..7d17b428 100644 --- a/Processor/ProcessorBase.cpp +++ b/Processor/ProcessorBase.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROCESSOR_PROCESSORBASE_HPP_ +#define PROCESSOR_PROCESSORBASE_HPP_ + #include "ProcessorBase.h" #include "IntInput.h" #include "FixInput.h" @@ -11,6 +14,7 @@ #include +inline void ProcessorBase::open_input_file(const string& name) { #ifdef DEBUG_FILES @@ -20,6 +24,7 @@ void ProcessorBase::open_input_file(const string& name) input_filename = name; } +inline void ProcessorBase::open_input_file(int my_num, int thread_num) { string input_file = "Player-Data/Input-P" + to_string(my_num) + "-" + to_string(thread_num); @@ -54,6 +59,4 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co return res; } -template IntInput ProcessorBase::get_input(bool, const int*); -template FixInput ProcessorBase::get_input(bool, const int*); -template FloatInput ProcessorBase::get_input(bool, const int*); +#endif diff --git a/Programs/Source/aes.mpc b/Programs/Source/aes.mpc index 191a15ee..aaa6f1d2 100644 --- a/Programs/Source/aes.mpc +++ b/Programs/Source/aes.mpc @@ -129,7 +129,7 @@ def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4): temp[2] = box.apply_sbox(temp[2]) temp[3] = box.apply_sbox(temp[3]) - temp[0] = temp[0] + ApplyEmbedding(rcon[int(i/Nk)]) + temp[0] = temp[0] + ApplyEmbedding(rcon[int(i//Nk)]) for j in range(4): round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j] @@ -233,7 +233,7 @@ def inverseMod(val): for idx in range(40): if idx % 5 == 0: - bd_val[idx] = raw_bit_dec[idx / 5] + bd_val[idx] = raw_bit_dec[idx // 5] bd_squared = bd_val squared_index = 2 diff --git a/Programs/Source/bankers_bonus_commsec.mpc b/Programs/Source/bankers_bonus_commsec.mpc index 99df66b0..45a47ef9 100644 --- a/Programs/Source/bankers_bonus_commsec.mpc +++ b/Programs/Source/bankers_bonus_commsec.mpc @@ -118,10 +118,10 @@ def main(): return True if n_rounds > 0: - print 'run %d rounds' % n_rounds + print('run %d rounds' % n_rounds) for_range(n_rounds)(game_loop) else: - print 'run forever' + print('run forever') do_while(game_loop) main() diff --git a/Programs/Source/gc_and.mpc b/Programs/Source/gc_and.mpc index be9fd8f0..3da453d6 100644 --- a/Programs/Source/gc_and.mpc +++ b/Programs/Source/gc_and.mpc @@ -13,7 +13,7 @@ if len(program.args) > 2: m = int(program.args[2]) pack = min(n, 50) -n = (n + pack - 1) / pack +n = (n + pack - 1) // pack a = sbit(1) b = sbit(1, n=pack) diff --git a/Programs/Source/htmac.mpc b/Programs/Source/htmac.mpc index c796469f..310e2076 100644 --- a/Programs/Source/htmac.mpc +++ b/Programs/Source/htmac.mpc @@ -50,9 +50,9 @@ test_decryption = True instructions_base.set_global_vector_size(n_parallel) if use_mimc_prf: - execfile('./Programs/Source/prf_mimc.mpc') + exec(compile(__builtins__['open']('./Programs/Source/prf_mimc.mpc').read(), './Programs/Source/prf_mimc.mpc', 'exec')) elif use_leg_prf: - execfile('./Programs/Source/prf_leg.mpc') + exec(compile(__builtins__['open']('./Programs/Source/prf_leg.mpc').read(), './Programs/Source/prf_leg.mpc', 'exec')) class HMAC(object): def __init__(self, _enc): @@ -97,7 +97,7 @@ class NonceEncryptMAC(object): def get_long_random(self, nbits): """ Returns random cint() % 2^{nbits} """ result = cint(0) - for i in range(nbits / 30): + for i in range(nbits // 30): result += cint(regint.get_random(30)) result <<= 30 @@ -178,7 +178,7 @@ def time_private_mac(n_total, n_parallel, nmessages): # Benchmark n_total HtMAC's while executing in parallel n_parallel start_timer(1) - @for_range(n_total / n_parallel) + @for_range(n_total // n_parallel) def block(index): # Re-use off-line data after n_parallel runs for benchmarking purposes. # If real system-use need to initialize num_calls with a larger constant. diff --git a/Programs/Source/regression.mpc b/Programs/Source/regression.mpc index 992a5db0..a599782c 100644 --- a/Programs/Source/regression.mpc +++ b/Programs/Source/regression.mpc @@ -8,7 +8,7 @@ ml.set_n_threads(8) debug = False if 'halfprec' in program.args: - print '8-bit precision after point' + print('8-bit precision after point') sfix.set_precision(8, 31) cfix.set_precision(8, 31) else: @@ -26,7 +26,7 @@ n_features = 12634 if len(program.args) > 2: if 'bc' in program.args: - print 'Compiling for BC-TCGA' + print('Compiling for BC-TCGA') n_examples = 472 n_normal = 49 n_features = 17814 @@ -41,7 +41,7 @@ try: except: pass -print 'Using %d threads' % ml.Layer.n_threads +print('Using %d threads' % ml.Layer.n_threads) n_fold = 5 test_share = 1. / n_fold @@ -63,8 +63,8 @@ else: n_test = sum(n_tests) indices = [regint.Array(n) for n in n_ex] -indices[0].assign(range(n_pos, n_pos + n_normal)) -indices[1].assign(range(n_pos)) +indices[0].assign(list(range(n_pos, n_pos + n_normal))) +indices[1].assign(list(range(n_pos))) test = regint.Array(n_test) @@ -97,7 +97,7 @@ for arg in program.args: m = re.match('tol=(.*)', arg) if m: sgd.tol = float(m.group(1)) - print 'Stop with tolerance', sgd.tol + print('Stop with tolerance', sgd.tol) sum_acc = cfix.MemValue(0) diff --git a/Programs/Source/test_sbitfix.mpc b/Programs/Source/test_sbitfix.mpc index b8dd6f4f..22be93c7 100644 --- a/Programs/Source/test_sbitfix.mpc +++ b/Programs/Source/test_sbitfix.mpc @@ -6,8 +6,8 @@ sbitfix.set_precision(16, 32) def test(a, b, value_type=None): try: b = int(round((b * (1 << a.f)))) - if b < 0: - b += 2 ** sbitfix.k + if b < 0: + b += 2 ** sbitfix.k a = a.v.reveal() except AttributeError: pass diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc index ba02c80a..2e7e0c7e 100644 --- a/Programs/Source/vickrey.mpc +++ b/Programs/Source/vickrey.mpc @@ -55,12 +55,12 @@ def f(_): def thread(): i = get_arg() - n_per_thread = n_inputs / n_threads + n_per_thread = n_inputs // n_threads if n_per_thread % 2 != 0: raise Exception('Number of inputs must be divisible by 2') start = i * n_per_thread tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \ - for j in range(n_per_thread / 2)] + for j in range(n_per_thread // 2)] first, second = util.tree_reduce(first_and_second, tuples) results[2*i], results[2*i+1] = first, second diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 9d234628..9e7c3f81 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -29,6 +29,8 @@ class Beaver : public ProtocolBase typename T::MAC_Check* MC; public: + static const bool uses_triples = true; + Player& P; Beaver(Player& P) : prep(0), MC(0), P(P) {} @@ -39,6 +41,9 @@ public: void exchange(); T finalize_mul(int n = -1); + void start_exchange(); + void stop_exchange(); + int get_n_relevant_players() { return P.num_players(); } }; diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 66d53f38..0ed322f2 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -50,6 +50,20 @@ void Beaver::exchange() triple = triples.begin(); } +template +void Beaver::start_exchange() +{ + MC->POpen_Begin(opened, shares, P); +} + +template +void Beaver::stop_exchange() +{ + MC->POpen_End(opened, shares, P); + it = opened.begin(); + triple = triples.begin(); +} + template T Beaver::finalize_mul(int n) { diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h new file mode 100644 index 00000000..fda19a0d --- /dev/null +++ b/Protocols/HemiPrep.h @@ -0,0 +1,39 @@ +/* + * HemiPrep.h + * + */ + +#ifndef PROTOCOLS_HEMIPREP_H_ +#define PROTOCOLS_HEMIPREP_H_ + +#include "ReplicatedPrep.h" +#include "FHEOffline/Multiplier.h" + +template +class HemiPrep : public SemiHonestRingPrep +{ + typedef typename T::clear::FD FD; + + static PairwiseMachine* pairwise_machine; + static Lock lock; + + vector*> multipliers; + + SeededPRNG G; + + map timers; + +public: + static void basic_setup(Player& P); + static void teardown(); + + HemiPrep(SubProcessor* proc, DataPositions& usage) : + RingPrep(proc, usage), SemiHonestRingPrep(proc, usage) + { + } + + void buffer_triples(); + void buffer_inverses(); +}; + +#endif /* PROTOCOLS_HEMIPREP_H_ */ diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp new file mode 100644 index 00000000..6621f744 --- /dev/null +++ b/Protocols/HemiPrep.hpp @@ -0,0 +1,87 @@ +/* + * HemiPrep.hpp + * + */ + +#ifndef PROTOCOLS_HEMIPREP_HPP_ +#define PROTOCOLS_HEMIPREP_HPP_ + +#include "HemiPrep.h" +#include "FHEOffline/PairwiseMachine.h" +#include "Tools/Bundle.h" + +template +PairwiseMachine* HemiPrep::pairwise_machine = 0; + +template +Lock HemiPrep::lock; + +template +void HemiPrep::teardown() +{ + if (pairwise_machine) + delete pairwise_machine; +} + +template +void HemiPrep::basic_setup(Player& P) +{ + assert(pairwise_machine == 0); + pairwise_machine = new PairwiseMachine(P); + auto& machine = *pairwise_machine; + auto& setup = machine.setup(); + setup.secure_init(P, machine, T::clear::length(), 40); +} + +template +void HemiPrep::buffer_triples() +{ + assert(this->proc != 0); + auto& P = this->proc->P; + + lock.lock(); + if (pairwise_machine == 0 or pairwise_machine->enc_alphas.empty()) + { + PlainPlayer P(this->proc->P.N, T::clear::type_char()); + if (pairwise_machine == 0) + basic_setup(P); + pairwise_machine->setup().covert_key_generation(P, + *pairwise_machine, 1); + pairwise_machine->enc_alphas.resize(1, pairwise_machine->pk); + } + lock.unlock(); + + if (multipliers.empty()) + for (int i = 1; i < P.num_players(); i++) + multipliers.push_back( + new Multiplier(i, *pairwise_machine, P, timers)); + + auto& FieldD = pairwise_machine->setup().FieldD; + Plaintext_ a(FieldD), b(FieldD), c(FieldD); + a.randomize(G); + b.randomize(G); + c.mul(a, b); + Bundle bundle(P); + pairwise_machine->pk.encrypt(a).pack(bundle.mine); + P.Broadcast_Receive(bundle, true); + Ciphertext C(pairwise_machine->pk); + for (auto m : multipliers) + { + C.unpack(bundle[P.get_player(-m->get_offset())]); + m->multiply_and_add(c, C, b); + } + assert(b.num_slots() == a.num_slots()); + assert(c.num_slots() == a.num_slots()); + for (unsigned i = 0; i < a.num_slots(); i++) + this->triples.push_back( + {{ a.element(i), b.element(i), c.element(i) }}); +} + +template +void HemiPrep::buffer_inverses() +{ + assert(this->proc != 0); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); +} + +#endif diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h new file mode 100644 index 00000000..e51c0f61 --- /dev/null +++ b/Protocols/HemiShare.h @@ -0,0 +1,40 @@ +/* + * HemiShare.h + * + */ + +#ifndef PROTOCOLS_HEMISHARE_H_ +#define PROTOCOLS_HEMISHARE_H_ + +#include "SemiShare.h" + +template class HemiPrep; + +template +class HemiShare : public SemiShare +{ + typedef HemiShare This; + typedef SemiShare super; + +public: + typedef SemiMC MAC_Check; + typedef DirectSemiMC Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef HemiPrep LivePrep; + + static const bool needs_ot = false; + + HemiShare() + { + } + template + HemiShare(const U& other) : + super(other) + { + } + +}; + +#endif /* PROTOCOLS_HEMISHARE_H_ */ diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index d520504f..29f509f6 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -188,7 +188,7 @@ public: template -class Direct_MAC_Check: public Separate_MAC_Check +class Direct_MAC_Check: public MAC_Check_ { typedef typename T::open_type open_type; @@ -196,7 +196,9 @@ class Direct_MAC_Check: public Separate_MAC_Check vector oss; public: - Direct_MAC_Check(const typename T::mac_key_type& ai, Names& Nms, int thread_num); + // legacy interface + Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai, Names& Nms, int thread_num); + Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai); ~Direct_MAC_Check(); void POpen_Begin(vector& values,const vector& S,const Player& P); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index e0f93048..06b360d0 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -487,7 +487,16 @@ void Parallel_MAC_Check::POpen_End(vector& values, template -Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type& ai, Names& Nms, int num) : Separate_MAC_Check(ai, Nms, num) { +Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai, + Names&, int) : + Direct_MAC_Check(ai) +{ +} + +template +Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai) : + MAC_Check_(ai) +{ open_counter = 0; } @@ -532,9 +541,7 @@ void Direct_MAC_Check::POpen_End(vector& values,const vector& S this->timers[RECV].start(); - for (int j=0; jtimers[RECV].stop(); open_counter++; diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index d231cfac..1b2393f4 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -32,8 +32,10 @@ public: virtual void POpen_Begin(vector& values,const vector& S,const Player& P) = 0; virtual void POpen_End(vector& values,const vector& S,const Player& P) = 0; - void POpen(vector& values,const vector& S,const Player& P); + virtual void POpen(vector& values,const vector& S,const Player& P); typename T::open_type POpen(const T& secret, const Player& P); + // alternative name to avoid conflict + typename T::open_type open(const T& secret, const Player& P) { return POpen(secret, P); } virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 03c4ee1d..3457426d 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -15,6 +15,8 @@ protected: typedef ReplicatedMC super; public: + virtual void POpen(vector& values, + const vector& S, const Player& P); virtual void POpen_Begin(vector& values, const vector& S, const Player& P); virtual void POpen_End(vector& values, @@ -35,6 +37,8 @@ class HashMaliciousRepMC : public MaliciousRepMC void reset(); void update(); + void finalize(const vector& values); + public: // emulate MAC_Check HashMaliciousRepMC(const typename T::value_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() @@ -47,6 +51,7 @@ public: HashMaliciousRepMC(); ~HashMaliciousRepMC(); + void POpen(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); @@ -60,6 +65,8 @@ class CommMaliciousRepMC : public MaliciousRepMC vector os; public: + void POpen(vector& values, const vector& S, + const Player& P); void POpen_Begin(vector& values, const vector& S, const Player& P); void POpen_End(vector& values, const vector& S, diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 32d8af01..c9cc1850 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -30,6 +30,13 @@ void MaliciousRepMC::POpen_End(vector& values, throw runtime_error("use subclass"); } +template +void MaliciousRepMC::POpen(vector&, + const vector&, const Player&) +{ + throw runtime_error("use subclass"); +} + template void MaliciousRepMC::Check(const Player& P) { @@ -60,11 +67,25 @@ HashMaliciousRepMC::~HashMaliciousRepMC() free(hash_state); } +template +void HashMaliciousRepMC::POpen(vector& values, + const vector& S, const Player& P) +{ + ReplicatedMC::POpen(values, S, P); + finalize(values); +} + template void HashMaliciousRepMC::POpen_End(vector& values, const vector& S, const Player& P) { ReplicatedMC::POpen_End(values, S, P); + finalize(values); +} + +template +void HashMaliciousRepMC::finalize(const vector& values) +{ os.reset_write_head(); for (auto& value : values) value.pack(os); @@ -118,6 +139,14 @@ void HashMaliciousRepMC::Check(const Player& P) } } +template +void CommMaliciousRepMC::POpen(vector& values, + const vector& S, const Player& P) +{ + POpen_Begin(values, S, P); + POpen_End(values, S, P); +} + template void CommMaliciousRepMC::POpen_Begin(vector& values, const vector& S, const Player& P) diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index 999e77b3..b9ca6b15 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -13,6 +13,9 @@ class MaliciousShamirMC : public ShamirMC { vector> reconstructions; + void finalize(vector& values, const vector& S, + const Player& P); + public: MaliciousShamirMC(); @@ -28,6 +31,8 @@ public: { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values, const vector& S, + const Player& P); void POpen_End(vector& values, const vector& S, const Player& P); }; diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index c2bc0bf7..0236f4ea 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -12,11 +12,27 @@ MaliciousShamirMC::MaliciousShamirMC() this->threshold = 2 * ShamirMachine::s().threshold; } +template +void MaliciousShamirMC::POpen(vector& values, + const vector& S, const Player& P) +{ + this->prepare(S, P); + this->exchange(P); + finalize(values, S, P); +} + template void MaliciousShamirMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void) P; + P.receive_all(this->os); + finalize(values, S, P); +} + +template +void MaliciousShamirMC::finalize(vector& values, + const vector& S, const Player& P) +{ int threshold = ShamirMachine::s().threshold; if (reconstructions.empty()) { @@ -36,7 +52,10 @@ void MaliciousShamirMC::POpen_End(vector& values, for (size_t i = 0; i < values.size(); i++) { for (size_t j = 0; j < shares.size(); j++) - shares[j].unpack(this->os[j]); + if (int(j) == P.my_num()) + shares[j] = S[i]; + else + shares[j].unpack(this->os[j]); T value = 0; for (int j = 0; j < threshold + 1; j++) value += shares[j] * reconstructions[threshold + 1][j]; diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 721b1f61..19a4f7cb 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -13,10 +13,9 @@ template class OTPrep : public virtual RingPrep { -protected: +public: typename T::TripleGenerator* triple_generator; -public: MascotParams params; OTPrep(SubProcessor* proc, DataPositions& usage); @@ -25,6 +24,7 @@ public: void set_protocol(typename T::Protocol& protocol); size_t data_sent(); + NamedCommStats comm_stats(); }; template diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 3ddd06ed..9a94f484 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -17,7 +17,6 @@ template OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : RingPrep(proc, usage), triple_generator(0) { - this->buffer_size = OnlineOptions::singleton.batch_size; } template @@ -34,13 +33,11 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) SubProcessor* proc = this->proc; assert(proc != 0); auto& ot_setups = BaseMachine::s().ot_setups.at(proc->Proc.thread_num); - assert(not ot_setups.empty()); - OTTripleSetup setup = ot_setups.back(); - ot_setups.pop_back(); - params.set_mac_key(typename T::mac_key_type::next(proc->MC.get_alphai())); + OTTripleSetup setup = ot_setups.get_fresh(); triple_generator = new typename T::TripleGenerator(setup, - proc->P.N, proc->Proc.thread_num, this->buffer_size, 1, - params, &proc->P); + proc->P.N, proc->Proc.thread_num, + OnlineOptions::singleton.batch_size, 1, + params, proc->MC.get_alphai(), &proc->P); triple_generator->multi_threaded = false; } @@ -119,4 +116,13 @@ size_t OTPrep::data_sent() return 0; } +template +NamedCommStats OTPrep::comm_stats() +{ + if (triple_generator) + return triple_generator->comm_stats(); + else + return {}; +} + #endif diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 05b07147..ba0a3a73 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -42,6 +42,11 @@ public: return "replicated " + T::type_string(); } + static int threshold(int) + { + return 1; + } + static Rep3Share constant(T value, int my_num, const T& alphai = {}) { return Rep3Share(value, my_num, alphai); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 8612aa9d..f41bf98b 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; #include "Tools/octetStream.h" @@ -26,11 +27,14 @@ template class Preprocessing; class ReplicatedBase { public: - PRNG shared_prngs[2]; + array shared_prngs; Player& P; ReplicatedBase(Player& P); + ReplicatedBase(Player& P, array& prngs); + + ReplicatedBase branch(); int get_n_relevant_players() { return P.num_players() - 1; } }; @@ -62,6 +66,9 @@ public: virtual void trunc_pr(const vector& regs, int size, SubProcessor& proc) { (void) regs, (void) size; (void) proc; throw not_implemented(); } + + virtual void start_exchange() { exchange(); } + virtual void stop_exchange() {} }; template @@ -75,7 +82,10 @@ public: typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; + static const bool uses_triples = false; + Replicated(Player& P); + Replicated(const ReplicatedBase& other); static void assign(T& share, const typename T::clear& value, int my_num) { @@ -103,6 +113,9 @@ public: void trunc_pr(const vector& regs, int size, SubProcessor& proc); T get_random(); + + void start_exchange(); + void stop_exchange(); }; #endif /* PROTOCOLS_REPLICATED_H_ */ diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index aeafc670..c9e8271c 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -29,6 +29,12 @@ Replicated::Replicated(Player& P) : ReplicatedBase(P) assert(T::length == 2); } +template +Replicated::Replicated(const ReplicatedBase& other) : + ReplicatedBase(other) +{ +} + inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) { assert(P.num_players() == 3); @@ -43,6 +49,18 @@ inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) shared_prngs[1].SetSeed(os.get_data()); } +inline ReplicatedBase::ReplicatedBase(Player& P, array& prngs) : + P(P) +{ + for (int i = 0; i < 2; i++) + shared_prngs[i].SetSeed(prngs[i]); +} + +inline ReplicatedBase ReplicatedBase::branch() +{ + return {P, shared_prngs}; +} + template ProtocolBase::~ProtocolBase() { @@ -128,6 +146,18 @@ void Replicated::exchange() P.pass_around(os[0], os[1], 1); } +template +void Replicated::start_exchange() +{ + P.send_relative(1, os[0]); +} + +template +void Replicated::stop_exchange() +{ + P.receive_relative(-1, os[1]); +} + template inline T Replicated::finalize_mul(int n) { diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index 496e4ff0..6bc657bc 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -14,6 +14,9 @@ class ReplicatedMC : public MAC_Check_Base octetStream o; octetStream to_send; + void prepare(const vector& S); + void finalize(vector& values, const vector& S); + public: // emulate MAC_Check ReplicatedMC(const typename T::value_type& _ = {}, int __ = 0, int ___ = 0) @@ -23,6 +26,7 @@ public: ReplicatedMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values,const vector& S,const Player& P); void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 1f3c781d..f3924a64 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -9,23 +9,44 @@ #include "ReplicatedMC.h" template -void ReplicatedMC::POpen_Begin(vector& values, +void ReplicatedMC::POpen(vector& values, const vector& S, const Player& P) +{ + prepare(S); + P.pass_around(to_send, o, -1); + finalize(values, S); +} + +template +void ReplicatedMC::POpen_Begin(vector&, + const vector& S, const Player& P) +{ + prepare(S); + P.send_relative(-1, to_send); +} + +template +void ReplicatedMC::prepare(const vector& S) { assert(T::length == 2); - (void)values; o.reset_write_head(); to_send.reset_write_head(); for (auto& x : S) x[0].pack(to_send); - P.pass_around(to_send, o, -1); } template void ReplicatedMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void)P; + P.receive_relative(1, o); + finalize(values, S); +} + +template +void ReplicatedMC::finalize(vector& values, + const vector& S) +{ values.resize(S.size()); for (size_t i = 0; i < S.size(); i++) { diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 5203690c..f7b601a6 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -54,7 +54,8 @@ template void ReplicatedRingPrep::buffer_triples() { assert(this->protocol != 0); - typename T::Protocol protocol(this->protocol->P); + // independent instance to avoid conflicts + typename T::Protocol protocol(this->protocol->branch()); generate_triples(this->triples, OnlineOptions::singleton.batch_size, &protocol); } @@ -264,8 +265,7 @@ void RingPrep::buffer_bits_without_check() int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits(n_relevant_players, vector(buffer_size)); typename T::Input input(proc, P); - for (int i = 0; i < P.num_players(); i++) - input.reset(i); + input.reset_all(P); for (int i = 0; i < n_relevant_players; i++) { int input_player = (base_player + i) % P.num_players(); @@ -274,20 +274,15 @@ void RingPrep::buffer_bits_without_check() SeededPRNG G; for (int i = 0; i < buffer_size; i++) input.add_mine(G.get_bit()); - input.send_mine(); - for (auto& x : player_bits[i]) - x = input.finalize_mine(); } else - { for (int i = 0; i < buffer_size; i++) input.add_other(input_player); - octetStream os; - P.receive_player(input_player, os, true); - for (auto& x : player_bits[i]) - input.finalize_other(input_player, x, os); - } } + input.exchange(); + for (int i = 0; i < n_relevant_players; i++) + for (auto& x : player_bits[i]) + x = input.finalize((base_player + i) % P.num_players()); auto& prot = *protocol; XOR(bits, player_bits[0], player_bits[1], buffer_size, prot, proc); for (int i = 2; i < n_relevant_players; i++) diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 97fcc9c6..8eefc5a2 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -29,12 +29,12 @@ class DirectSemiMC : public SemiMC public: DirectSemiMC() {} // emulate Direct_MAC_Check - DirectSemiMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) - { (void)_; (void)__; (void)___; (void)____; } + DirectSemiMC(const typename T::mac_key_type&, const Names& = {}, int = 0, int = 0) {} void POpen_(vector& values,const vector& S,const PlayerBase& P); - void POpen_Begin(vector& values,const vector& S,const Player& P) + void POpen(vector& values,const vector& S,const Player& P) { POpen_(values, S, P); } + void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); void Check(const Player& P) { (void)P; } diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index 386b68da..48656133 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -42,10 +42,24 @@ void DirectSemiMC::POpen_(vector& values, } template -void DirectSemiMC::POpen_End(vector& values, +void DirectSemiMC::POpen_Begin(vector& values, const vector& S, const Player& P) { - (void) values, (void) S, (void) P; + values.clear(); + values.insert(values.begin(), S.begin(), S.end()); + octetStream os; + for (auto& x : values) + x.pack(os); + P.send_all(os, true); +} + +template +void DirectSemiMC::POpen_End(vector& values, + const vector&, const Player& P) +{ + Bundle oss(P); + P.receive_all(oss); + direct_add_openings(values, P, oss); } #endif diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 2a42f636..20d5e4ad 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -52,6 +52,11 @@ public: static string type_short() { return "D" + string(1, T::type_char()); } + static int threshold(int nplayers) + { + return nplayers - 1; + } + static SemiShare constant(const clear& other, int my_num, const T& alphai = {}) { return SemiShare(other, my_num, alphai); diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 83341f64..558509c9 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -38,6 +38,8 @@ class Shamir : public ProtocolBase> int n_mul_players; public: + static const bool uses_triples = false; + Player& P; static U get_rec_factor(int i, int n); @@ -45,6 +47,8 @@ public: Shamir(Player& P); ~Shamir(); + Shamir branch(); + int get_n_relevant_players(); void reset(); @@ -52,7 +56,11 @@ public: void init_mul(); void init_mul(SubProcessor* proc); U prepare_mul(const T& x, const T& y, int n = -1); + void exchange(); + void start_exchange(); + void stop_exchange(); + T finalize_mul(int n = -1); T finalize(int n_input_players); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d361db1a..6ecf8663 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -34,6 +34,12 @@ Shamir::~Shamir() delete resharing; } +template +Shamir Shamir::branch() +{ + return P; +} + template int Shamir::get_n_relevant_players() { @@ -100,6 +106,25 @@ void Shamir::exchange() } } +template +void Shamir::start_exchange() +{ + if (P.my_num() < n_mul_players) + for (int offset = 1; offset < P.num_players(); offset++) + P.send_relative(offset, resharing->os[P.get_player(offset)]); +} + +template +void Shamir::stop_exchange() +{ + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + if (receive_from < n_mul_players) + P.receive_player(receive_from, os[receive_from], true); + } +} + template ShamirShare Shamir::finalize_mul(int n) { diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index fd98fe00..4a4228ce 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -15,10 +15,17 @@ class ShamirMC : public MAC_Check_Base { vector reconstruction; + bool send; + + void finalize(vector& values, const vector& S); + protected: vector os; int threshold; + void prepare(const vector& S, const Player& P); + void exchange(const Player& P); + public: ShamirMC() : threshold(ShamirMachine::s().threshold) {} @@ -31,6 +38,7 @@ public: ShamirMC() { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values,const vector& S,const Player& P); void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 6d362757..63b5ba39 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -10,14 +10,35 @@ void ShamirMC::POpen_Begin(vector& values, const vector& S, const Player& P) { (void) values; + prepare(S, P); + P.send_all(os[P.my_num()], true); +} + +template +void ShamirMC::prepare(const vector& S, const Player& P) +{ os.clear(); os.resize(P.num_players()); - bool send = P.my_num() <= threshold; + send = P.my_num() <= threshold; if (send) { for (auto& share : S) share.pack(os[P.my_num()]); } +} + +template +void ShamirMC::POpen(vector& values, const vector& S, + const Player& P) +{ + prepare(S, P); + exchange(P); + finalize(values, S); +} + +template +void ShamirMC::exchange(const Player& P) +{ for (int offset = 1; offset < P.num_players(); offset++) { int send_to = P.get_player(offset); @@ -37,7 +58,14 @@ template void ShamirMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void) P; + P.receive_all(os); + finalize(values, S); +} + +template +void ShamirMC::finalize(vector& values, + const vector& S) +{ int n_relevant_players = ShamirMachine::s().threshold + 1; if (reconstruction.empty()) { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index b312db5a..eef06c15 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -40,6 +40,11 @@ public: return "Shamir " + T::type_string(); } + static int threshold(int) + { + return ShamirMachine::s().threshold; + } + static ShamirShare constant(T value, int my_num, const T& alphai = {}) { return ShamirShare(value, my_num, alphai); diff --git a/Protocols/Share.h b/Protocols/Share.h index 80394fce..b011ecf4 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -71,6 +71,9 @@ class Share static DataFieldType field_type() { return T::field_type(); } + static int threshold(int nplayers) + { return nplayers - 1; } + static Share constant(const clear& aa, int my_num, const typename T::Scalar& alphai) { return Share(aa, my_num, alphai); } diff --git a/README.md b/README.md index 3ab0d11b..865321df 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ us, but you can also write an email to mp-spdz@googlegroups.com #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) This requires either a Linux distribution originally released 2011 or -later (glibc 2.12) or macOS High Sierra or later as well as Python 2 +later (glibc 2.12) or macOS High Sierra or later as well as Python 3 and basic command-line utilities. Download and unpack the [distribution](https://github.com/n1analytics/MP-SPDZ/releases), @@ -72,7 +72,7 @@ The following table lists all protocols that are fully supported. | --- | --- | --- | --- | --- | | Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear](#secret-sharing) | N/A | N/A | N/A | -| Semi-honest, dishonest majority | [Semi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi / Hemi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | @@ -128,14 +128,14 @@ phase outputs the amount of offline material required, which allows to compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7) + - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7). We recommend clang because it performs better. - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.16 - OpenSSL, tested against and 1.0.2 and 1.1.0 - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.65 - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.65 - 64-bit CPU - - Python 2.x + - Python 3.5 or later - NTL library for CowGear and the SPDZ-2 and Overdrive offline phases (optional; tested with NTL 10.5) - If using macOS, Sierra or later @@ -239,6 +239,7 @@ The following table shows all programs for dishonest-majority computation using | `semi-party.x` | OT-based | Mod prime | Semi-honest | `semi.sh` | | `semi2k-party.x` | OT-based | Mod 2^k | Semi-honest | `semi2k.sh` | | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | +| `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -254,13 +255,17 @@ security. Tiny denotes the adaption of SPDZ2k to the binary setting. In particular, the SPDZ2k sacrifice does not work for bits, so we replace it by cut-and-choose according to [Furukawa et -al.](https://eprint.iacr.org/2016/944.pdf). +al.](https://eprint.iacr.org/2016/944) CowGear denotes a covertly secure version of LowGear. The reason for this is the key generation that only achieves covert security. It is possible however to run full LowGear for triple generation by using `-s` with the desired security parameter. +Hemi denotes the stripped version version of LowGear for semi-honest +security similar to Semi, that is, generating additively shared Beaver +triples using semi-homomorphic encryption. + We will use MASCOT to demonstrate the use, but the other protocols work similarly. diff --git a/Scripts/hemi.sh b/Scripts/hemi.sh new file mode 100755 index 00000000..f0be05bf --- /dev/null +++ b/Scripts/hemi.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player hemi-party.x $* || exit 1 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index b1b5e2eb..431186fd 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -9,6 +9,20 @@ gdb_screen() screen -S :$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; gdb $prog -ex \"run $*\"" } +lldb_screen() +{ + prog=$1 + shift + IFS= + name=${*/-/} + IFS=' ' + echo debug $prog with arguments $* + echo name: $name + tmp=/tmp/$RANDOM + echo run > $tmp + screen -S :$i -d -m bash -l -c "lldb -s $tmp $prog -- $*" +} + run_player() { port=$((RANDOM%10000+10000)) bin=$1 @@ -42,6 +56,7 @@ run_player() { { if test $i = 0; then tee $log; else cat > $log; fi; } & done last_player=$(($players - 1)) + i=$last_player >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 } diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index 78194c6b..57ba1155 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,8 +1,5 @@ #!/bin/bash -echo 'MOD = -DGFP_MOD_SZ=4' >> CONFIG.mine - -make clean make -j4 ecdsa Fake-ECDSA.x run() diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 4853d7ff..0a519ea2 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -12,18 +12,18 @@ function test fi } -./compile.py tutorial - -for i in rep-field mal-rep-field ps-rep-field shamir mal-shamir cowgear semi mascot; do - test $i -done - ./compile.py -R 64 tutorial for i in ring brain mal-rep-ring ps-rep-ring semi2k spdz2k; do test $i done +./compile.py tutorial + +for i in rep-field mal-rep-field ps-rep-field shamir mal-shamir hemi cowgear semi mascot; do + test $i +done + ./compile.py -B 16 tutorial for i in replicated mal-rep-bin semi-bin yao tiny rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 31b4e686..e4668ce1 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -52,6 +52,16 @@ bool BitVector::parity() const #endif } +void BitVector::append(const BitVector& other, size_t length) +{ + assert(nbits % 8 == 0); + assert(length % 8 == 0); + assert(length <= other.nbits); + auto old_nbytes = nbytes; + resize(nbits + length); + memcpy(bytes + old_nbytes, other.bytes, length / 8); +} + void BitVector::randomize(PRNG& G) { G.get_octets(bytes, nbytes); diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 54e319b3..3c174a18 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -239,6 +239,8 @@ class BitVector return true; } + void append(const BitVector& other, size_t length); + void randomize(PRNG& G); template void randomize_blocks(PRNG& G); diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index 9c410d3b..8c78ce3d 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -81,13 +81,13 @@ void MMO::hashBlocks(void* output, const void* input) encrypt_and_xor<1>(output, output, IV[0]); } -template <> -void MMO::hashBlocks(void* output, const void* input) +template +void MMO::hashEightGfp(void* output, const void* input) { - if (gfp1::get_ZpD().get_t() < 2) + if (gfp_::get_ZpD().get_t() < 2) throw not_implemented(); - gfp1* out = (gfp1*)output; - hashBlocks<8, gfp1::N_BYTES>(output, input, sizeof(gfp1)); + gfp_* out = (gfp_*)output; + hashBlocks<8, gfp_::N_BYTES>(output, input, sizeof(gfp_)); for (int i = 0; i < 8; i++) out[i].zero_overhang(); int left = 8; @@ -97,7 +97,7 @@ void MMO::hashBlocks(void* output, const void* input) int now_left = 0; for (int j = 0; j < left; j++) if (mpn_cmp((mp_limb_t*) out[indices[j]].get_ptr(), - gfp1::get_ZpD().get_prA(), gfp1::t()) >= 0) + gfp_::get_ZpD().get_prA(), gfp_::t()) >= 0) { indices[now_left] = indices[j]; now_left++; @@ -105,19 +105,31 @@ void MMO::hashBlocks(void* output, const void* input) left = now_left; int block_size = sizeof(__m128i); - int n_blocks = DIV_CEIL(gfp1::size(), block_size); + int n_blocks = DIV_CEIL(gfp_::size(), block_size); for (int i = 0; i < n_blocks; i++) for (int j = 0; j < left; j++) { __m128i* addr = (__m128i*) out[indices[j]].get_ptr() + i; __m128i* in = (__m128i*) out[indices[j]].get_ptr(); auto tmp = aes_128_encrypt(_mm_loadu_si128(in), IV[i]); - memcpy(addr, &tmp, min(block_size, gfp1::size() - i * block_size)); + memcpy(addr, &tmp, min(block_size, gfp_::size() - i * block_size)); out[indices[j]].zero_overhang(); } } } +template <> +void MMO::hashBlocks(void* output, const void* input) +{ + hashEightGfp<1, GFP_MOD_SZ>(output, input); +} + +template <> +void MMO::hashBlocks(void* output, const void* input) +{ + hashEightGfp<3, 4>(output, input); +} + #define ZZ(F,N) \ template void MMO::hashBlocks(void*, const void*); #define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8) diff --git a/Tools/MMO.h b/Tools/MMO.h index 2631640f..f2fd2996 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -32,6 +32,8 @@ public: void hashBlocks(void* output, const void* input, size_t alloc_size); template void hashBlocks(void* output, const void* input); + template + void hashEightGfp(void* output, const void* input); template void outputOneBlock(octet* output); Key hash(const Key& input); diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index 9dc0b800..37eb2832 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -23,7 +23,7 @@ double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_t } -long long timespec_diff(struct timespec *start_time, struct timespec *end_time) +long long timespec_diff(const struct timespec *start_time, const struct timespec *end_time) { long long sec =end_time->tv_sec -start_time->tv_sec ; long long nsec=end_time->tv_nsec-start_time->tv_nsec; @@ -72,3 +72,19 @@ Timer& Timer::operator -=(const Timer& other) elapsed_time -= other.elapsed_time; return *this; } + +Timer& Timer::operator +=(const Timer& other) +{ + assert(clock_id == other.clock_id); + assert(not running); + elapsed_time += other.elapsed_time + other.elapsed_since_last_start(); + return *this; +} + +Timer& Timer::operator +=(const TimeScope& other) +{ + assert(clock_id == other.timer.clock_id); + assert(not running); + elapsed_time += other.timer.elapsed_since_last_start(); + return *this; +} diff --git a/Tools/time-func.h b/Tools/time-func.h index 144f11ff..8381d6ed 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -10,7 +10,9 @@ long long timeval_diff(struct timeval *start_time, struct timeval *end_time); double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_time); -long long timespec_diff(struct timespec *start_time, struct timespec *end_time); +long long timespec_diff(const struct timespec *start_time, const struct timespec *end_time); + +class TimeScope; class Timer { @@ -26,6 +28,8 @@ class Timer double idle(); Timer& operator-=(const Timer& other); + Timer& operator+=(const Timer& other); + Timer& operator+=(const TimeScope& other); private: timespec startv; @@ -33,11 +37,12 @@ class Timer long long elapsed_time; clockid_t clock_id; - long long elapsed_since_last_start(); + long long elapsed_since_last_start() const; }; class TimeScope { + friend class Timer; Timer& timer; public: @@ -83,7 +88,7 @@ inline void Timer::reset() clock_gettime(clock_id, &startv); } -inline long long Timer::elapsed_since_last_start() +inline long long Timer::elapsed_since_last_start() const { timespec endv; clock_gettime(clock_id, &endv); diff --git a/compile.py b/compile.py index 48586cf3..77c2c577 100755 --- a/compile.py +++ b/compile.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # ===== Compiler usage instructions =====