diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d08a074..f54cf68e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.2 + +- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission +- Binary computation for dishonest majority using secret sharing +- Mathematical functions from [SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA) +- Fixed security bug: CowGear would reuse triples. + ## 0.1.1 (Aug 6, 2019) - ECDSA diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 121e0494..2ed6016b 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -101,12 +101,12 @@ def determine_scope(block, options): used_from_scope = set() def find_in_scope(reg, scope): - if scope is None: - return False - elif reg in scope.defined_registers: - return True - else: - return find_in_scope(reg, scope.scope) + while True: + if scope is None: + return False + elif reg in scope.defined_registers: + return True + scope = scope.scope def read(reg, n): if last_def[reg] == -1: @@ -386,7 +386,7 @@ class Merger: last_print_str = None last = defaultdict(lambda: defaultdict(lambda: None)) last_open = deque() - last_text_input = None + last_text_input = [None, None] depths = [0] * len(block.instructions) self.depths = depths @@ -474,10 +474,14 @@ class Merger: # will be merged if isinstance(instr, TextInputInstruction): - if last_text_input is not None and \ - type(block.instructions[last_text_input]) is not type(instr): - add_edge(last_text_input, n) - last_text_input = n + if last_text_input[0] is not None: + if instr.merge_id() != \ + block.instructions[last_text_input[0]].merge_id(): + add_edge(last_text_input[0], n) + last_text_input[1] = last_text_input[0] + elif last_text_input[1] is not None: + add_edge(last_text_input[1], n) + last_text_input[0] = n if isinstance(instr, merge_classes): open_nodes.add(n) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 9dd55292..80d54c6a 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -80,6 +80,12 @@ def LTZ(s, a, k, kappa): Trunc(t, a, k, k - 1, kappa, True) subsfi(s, t, 0) +def LessThanZero(a, k, kappa): + import types + res = types.sint() + LTZ(res, a, k, kappa) + return res + def Trunc(d, a, k, m, kappa, signed): """ d = a >> m @@ -153,6 +159,8 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): k: bit length of a m: compile-time integer """ + if m == 0: + return a if k == int(program.options.ring): # cannot work with bit length k+1 tmp = TruncRing(None, a, k, m - 1, signed) @@ -359,7 +367,7 @@ def CarryOutAux(d, a, kappa): movs(d, a[0][1]) # carry out with carry-in bit c -def CarryOut(res, a, b, c, kappa): +def CarryOut(res, a, b, c=0, kappa=None): """ res = last carry bit in addition of a and b @@ -368,8 +376,9 @@ def CarryOut(res, a, b, c, kappa): c: initial carry-in bit """ k = len(a) + import types d = [program.curr_block.new_reg('s') for i in range(k)] - t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)] + t = [[types.sint() for i in range(k)] for i in range(4)] s = [program.curr_block.new_reg('s') for i in range(3)] for i in range(k): mulm(t[0][i], b[i], a[i]) @@ -377,12 +386,19 @@ def CarryOut(res, a, b, c, kappa): addm(t[2][i], b[i], a[i]) subs(t[3][i], t[2][i], t[1][i]) d[i] = [t[3][i], t[0][i]] - mulsi(s[0], d[-1][0], c) - adds(s[1], d[-1][1], s[0]) + s[0] = d[-1][0] * c + s[1] = d[-1][1] + s[0] d[-1][1] = s[1] CarryOutAux(res, d[::-1], kappa) +def CarryOutLE(a, b, c=0): + """ Little-endian version """ + import types + res = types.sint() + CarryOut(res, a[::-1], b[::-1], c) + return res + def BitLTL(res, a, b, kappa): """ res = a > (n_shift + m + 1)) - overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1)) - res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m)) + if signed: + a += (1 << (k - 1)) + if program.Program.prog.use_trunc_pr: + res = sint() + trunc_pr(res, a, k, m) + else: + # extra bit to mask overflow + r_bits = [sint.get_random_bit() for i in range(k + 1)] + n_shift = n_ring - len(r_bits) + tmp = a + sint.bit_compose(r_bits) + masked = (tmp << n_shift).reveal() + shifted = (masked << 1 >> (n_shift + m + 1)) + overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1)) + res = shifted - sint.bit_compose(r_bits[m:k]) + \ + (overflow << (k - m)) + if signed: + res -= (1 << (k - m - 1)) return res def TruncPrField(a, k, m, kappa=None): + if m == 0: + return a if kappa is None: kappa = 40 @@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False): w = types.cint(int(2.9142 * two_power(l))) - 2 * b x = alpha - b * w y = a * w - y = y.round(2 * l + 1, l, kappa, round_nearest) + y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False) x2 = types.sint() comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False) x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True) for i in range(theta-1): - y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest) - y = y.round(2 * l + 1, l + 1, kappa, round_nearest) - x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest) - x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest) + y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, + round_nearest, + signed=False) + y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False) + x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest, + signed=False) + x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest, + signed=False) x2 = types.sint() comparison.Mod2m(x2, x, 2 * l, l, kappa, False) x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True) - y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest) + y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, + round_nearest, signed=False) y = y.round(2 * l + 1, l - 1, kappa, round_nearest) return y diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 568091ab..12f4f33e 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -894,6 +894,55 @@ class inputfloat(base.TextInputInstruction): req_node.increment((self.field_type, 'input', player), \ 4 * self.get_size()) +@base.vectorize +class inputmixed(base.TextInputInstruction): + __slots__ = [] + code = base.opcodes['INPUTMIXED'] + field_type = 'modp' + # the following has to match TYPE: (N_DEST, N_PARAM) + types = { + 0: (1, 0), + 1: (1, 1), + 2: (4, 1) + } + type_ids = { + 'int': 0, + 'fix': 1, + 'float': 2 + } + + def __init__(self, name, *args): + try: + type_id = self.type_ids[name] + except: + pass + super(inputmixed_class, self).__init__(type_id, *args) + + @property + def arg_format(self): + for i in self.bases(): + t = self.args[i] + yield 'int' + for j in range(self.types[t][0]): + yield 'sw' + for j in range(self.types[t][1]): + yield 'int' + yield 'p' + + def bases(self): + i = 0 + while i < len(self.args): + yield i + i += sum(self.types[self.args[i]]) + 2 + + def add_usage(self, req_node): + for i in self.bases(): + t = self.args[i] + player = self.args[i + sum(self.types[t]) + 1] + n_dest = self.types[t][0] + req_node.increment((self.field_type, 'input', player), \ + n_dest * self.get_size()) + @base.gf2n class startinput(base.RawInputInstruction): r""" Receive inputs from player $p$. """ @@ -957,6 +1006,11 @@ class print_reg_plain(base.IOInstruction): code = base.opcodes['PRINTREGPLAIN'] arg_format = ['c'] +class cond_print_plain(base.IOInstruction): + r""" Conditionally print the value of a register. """ + code = base.opcodes['CONDPRINTPLAIN'] + arg_format = ['c', 'c'] + class print_int(base.IOInstruction): r""" Print only the value of register \verb|ci| to stdout. """ __slots__ = [] @@ -1383,6 +1437,9 @@ class muls(base.VarArgsInstruction, base.DataInstruction): def merge_id(self): # can merge different sizes + # but not if large + if self.get_size() > 100: + return type(self), self.get_size() return type(self) # def expand(self): @@ -1468,6 +1525,14 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction): for reg in self.args[i + 2:i + self.args[i]]: yield reg +@base.vectorize +class trunc_pr(base.VarArgsInstruction): + """ Probalistic truncation for semi-honest computation """ + """ with honest majority """ + __slots__ = [] + code = base.opcodes['TRUNC_PR'] + arg_format = tools.cycle(['sw','s','int','int']) + ### ### CISC-style instructions ### diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index a4f7a290..d4baa25d 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -89,6 +89,7 @@ opcodes = dict( MULS = 0xA6, MULRS = 0xA7, DOTPRODS = 0xA8, + TRUNC_PR = 0xA9, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -102,6 +103,7 @@ opcodes = dict( INPUT = 0x60, INPUTFIX = 0xF0, INPUTFLOAT = 0xF1, + INPUTMIXED = 0xF2, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -168,6 +170,7 @@ opcodes = dict( READFILESHARE = 0xBE, CONDPRINTSTR = 0xBF, PRINTFLOATPREC = 0xE0, + CONDPRINTPLAIN = 0xE1, GBITDEC = 0x184, GBITCOM = 0x185, # Secure socket @@ -767,21 +770,6 @@ class ClearShiftInstruction(ClearImmediate): ### Jumps etc ### -class dummywrite(Instruction): - """ Dummy instruction to create source node in the dependency graph, - preventing read-before-write warnings. """ - __slots__ = [] - - def __init__(self, *args, **kwargs): - self.arg_format = [arg.reg_type + 'w' for arg in args] - super(dummywrite, self).__init__(*args, **kwargs) - - def execute(self): - pass - - def get_encoding(self): - return [] - class JumpInstruction(Instruction): __slots__ = ['jump_arg'] diff --git a/Compiler/library.py b/Compiler/library.py index 8f42f320..25cf861e 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -5,6 +5,7 @@ from Compiler import instructions,instructions_base,comparison,program,util import inspect,math import random import collections +import operator def get_program(): return instructions.program @@ -93,16 +94,25 @@ def print_ln(s='', *args): print_str(s, *args) print_char('\n') -def print_ln_if(cond, s): +def print_ln_if(cond, ss, *args): if util.is_constant(cond): if cond: - print_ln(s) + print_ln(ss, *args) else: - s += ' ' * ((len(s) + 3) % 4) - s += '\n' - while s: - cond.print_if(s[:4]) - s = s[4:] + subs = ss.split('%s') + assert len(subs) == len(args) + 1 + cond = cint.conv(cond) + for i, s in enumerate(subs): + if i != 0: + cond_print_plain(cond, cint.conv(args[i - 1])) + if i < len(args): + s += ' ' * ((-len(s)) % 4) + else: + s += ' ' * ((-len(s) + 3) % 4) + s += '\n' + while s: + cond.print_if(s[:4]) + s = s[4:] def print_float_precision(n): print_float_prec(n) @@ -798,19 +808,23 @@ def range_loop(loop_body, start, stop=None, step=None): lambda x: ((stop - start) / step) * x[0] def for_range(start, stop=None, step=None): + """ Execute loop bodies consecutively """ def decorator(loop_body): range_loop(loop_body, start, stop, step) return loop_body return decorator def for_range_parallel(n_parallel, n_loops): + """ Execute up to n_parallel loop bodies in parallel """ return map_reduce_single(n_parallel, n_loops) -def for_range_opt(n_loops): - return map_reduce_single(None, n_loops) +def for_range_opt(n_loops, budget=None): + """ Execute loop bodies in parallel up to an optimization budget """ + return map_reduce_single(None, n_loops, budget=budget) def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], - reducer=lambda *x: [], mem_state=None): + reducer=lambda *x: [], mem_state=None, budget=None): + budget = budget or get_program().budget if not (isinstance(n_parallel, int) or n_parallel is None): raise CompilerException('Number of parallel executions' \ 'must be constant') @@ -848,14 +862,16 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], r = reducer(mem_state, state) write_state_to_memory(r) else: - n_parallel_reg = MemValue(regint(0)) + if n_loops == 0: + return + regint.push(0) parent_block = get_block() - @while_do(lambda x: x + n_parallel_reg <= n_loops, regint(0)) + @while_do(lambda x: x + regint.pop() <= n_loops, regint(0)) def _(i): state = tuplify(initializer()) k = 0 block = get_block() - while k < n_loops and (len(get_block()) < get_program().budget \ + while k < n_loops and (len(get_block()) < budget \ or k == 0) \ and block is get_block(): j = i + k @@ -865,7 +881,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], write_state_to_memory(r) global n_opt_loops n_opt_loops = k - n_parallel_reg.write(k) + regint.push(k) return i + k my_n_parallel = n_opt_loops loop_rounds = n_loops / my_n_parallel @@ -915,12 +931,46 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], return decorator def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}): + """ + Execute loop bodies in up to n_threads threads, + up to n_parallel in parallel per thread + """ return map_reduce(n_threads, n_parallel, n_loops, \ lambda *x: [], lambda *x: [], thread_mem_req) +def for_range_opt_multithread(n_threads, n_loops): + """ + Execute loop bodies in up to n_threads threads, + in parallel up to an optimization budget per thread + """ + return for_range_multithread(n_threads, None, n_loops) + +def multithread(n_threads, n_items): + """ + Distribute the computation of n_items to n_threads threads, + but leave the in-thread repetition up to the user + """ + if n_threads == 1 or n_items == 1: + return lambda loop_body: loop_body(0, n_items) + return map_reduce(n_threads, None, n_items, initializer=lambda: [], + reducer=None, looping=False) + def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ - thread_mem_req={}): + thread_mem_req={}, looping=True): n_threads = n_threads or 1 + if isinstance(n_loops, list): + split = n_loops + n_loops = reduce(operator.mul, n_loops) + def decorator(loop_body): + def new_body(i): + indices = [] + for n in reversed(split): + indices.insert(0, i % n) + i /= n + return loop_body(*indices) + return new_body + new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req) + return lambda loop_body: new_dec(decorator(loop_body)) if n_threads == 1 or n_loops == 1: dec = map_reduce_single(n_parallel, n_loops, initializer, reducer) if thread_mem_req: @@ -937,12 +987,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') state = tuple(initializer()) def f(inc): + base = args[get_arg()][0] + if not looping: + return loop_body(base, thread_rounds + inc) if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) mem_state = Array(len(state), type(state[0]) \ if state else cint, args[get_arg()][1]) - base = args[get_arg()][0] @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): @@ -1014,8 +1066,9 @@ def while_loop(loop_body, condition, arg): pushint(arg if isinstance(arg,regint) else regint(arg)) def loop_fn(): result = loop_body(regint.pop()) + cont = condition(result) pushint(result) - return condition(result) + return cont if_statement(pre_condition, lambda: do_while(loop_fn)) regint.pop() @@ -1278,7 +1331,7 @@ def sint_cint_division(a, b, k, f, kappa): theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) sign_b = cint(1) - 2 * cint(b < 0) - sign_a = sint(1) - 2 * sint(a < 0) + sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa) absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) @@ -1326,7 +1379,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): y = a.extend(2 *k) * w y = y.round(2*k, f, kappa, nearest, signed=True) - for i in range(theta): + for i in range(theta - 1): x = x.extend(2 * k) y = y.extend(2 * k) * (alpha + x).extend(2 * k) x = x * x @@ -1358,7 +1411,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): # For simplex, we can get rid of computing abs(b) temp = None if simplex_flag == False: - temp = b.less_than(0, 2 * k) + temp = comparison.LessThanZero(b, 2 * k, kappa) elif simplex_flag == True: temp = cint(0) diff --git a/Compiler/ml.py b/Compiler/ml.py new file mode 100644 index 00000000..4c8e59b4 --- /dev/null +++ b/Compiler/ml.py @@ -0,0 +1,814 @@ +import mpc_math, math + +from Compiler.types import * +from Compiler.types import _unreduced_squant +from Compiler.library import * + +def log_e(x): + return mpc_math.log_fx(x, math.e) + +def exp(x): + return mpc_math.pow_fx(math.e, x) + +def sanitize(x, raw, lower, upper): + exp_limit = 2 ** (x.k - x.f - 1) + limit = math.log(exp_limit) + if get_program().options.ring: + res = raw + else: + res = (x > limit).if_else(upper, raw) + return (x < -limit).if_else(lower, res) + +def sigmoid(x): + return sigmoid_from_e_x(x, exp(-x)) + +def sigmoid_from_e_x(x, e_x): + return sanitize(x, 1 / (1 + e_x), 0, 1) + +def sigmoid_prime(x): + sx = sigmoid(x) + return sx * (1 - sx) + +def lse_0_from_e_x(x, e_x): + return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0) + +def lse_0(x): + return lse_0_from_e_x(x, exp(x)) + +def relu_prime(x): + return (0 <= x) + +def relu(x): + return (0 < x).if_else(x, 0) + +def progress(x): + return + print_ln(x) + time() + +def set_n_threads(n_threads): + Layer.n_threads = n_threads + Optimizer.n_threads = n_threads + +class Layer: + n_threads = 1 + +class Output(Layer): + def __init__(self, N, debug=False): + self.N = N + self.X = sfix.Array(N) + self.Y = sfix.Array(N) + self.nabla_X = sfix.Array(N) + self.l = MemValue(sfix(-1)) + self.e_x = sfix.Array(N) + self.debug = debug + self.weights = cint.Array(N) + self.weights.assign_all(1) + self.weight_total = N + + nablas = lambda self: () + thetas = lambda self: () + reset = lambda self: None + + def divisor(self, divisor, size): + return cfix(1.0 / divisor, size=size) + + def forward(self, N=None): + N = N or self.N + lse = sfix.Array(N) + @multithread(self.n_threads, N) + def _(base, size): + x = self.X.get_vector(base, size) + y = self.Y.get_vector(base, size) + e_x = exp(-x) + self.e_x.assign(e_x, base) + lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base) + e_x = self.e_x.get_vector(0, N) + self.l.write(sum(lse) * \ + self.divisor(self.N, 1)) + + def backward(self): + @multithread(self.n_threads, self.N) + def _(base, size): + diff = sigmoid_from_e_x(self.X.get_vector(base, size), + self.e_x.get_vector(base, size)) - \ + self.Y.get_vector(base, size) + assert sfix.f == cfix.f + diff *= self.weights.get_vector(base, size) + self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \ + base) + # @for_range_opt(len(diff)) + # def _(i): + # self.nabla_X[i] = self.nabla_X[i] * self.weights[i] + if self.debug: + a = cfix.Array(len(diff)) + a.assign(diff.reveal()) + @for_range(len(diff)) + def _(i): + x = a[i] + print_ln_if((x < -1.001) + (x > 1.001), 'sigmoid') + #print_ln('%s', x) + + def set_weights(self, weights): + self.weights.assign(weights) + self.weight_total = sum(weights) + +class DenseBase(Layer): + thetas = lambda self: (self.W, self.b) + nablas = lambda self: (self.nabla_W, self.nabla_b) + + def backward_params(self, f_schur_Y): + N = self.N + tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + + @for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out]) + def _(j, k): + assert self.d == 1 + a = [f_schur_Y[i][0][k] for i in range(N)] + b = [self.X[i][0][j] for i in range(N)] + tmp[j][k] = sfix.unreduced_dot_product(a, b) + + if self.d_in * self.d_out < 100000: + print 'reduce at once' + @multithread(self.n_threads, self.d_in * self.d_out) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + else: + @for_range_opt(self.d_in) + def _(i): + self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() + + self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N)) + for j in range(self.d)) for i in range(self.d_out)) + + progress('nabla W/b') + +class Dense(DenseBase): + def __init__(self, N, d_in, d_out, d=1, activation='id'): + self.activation = activation + if activation == 'id': + self.f = lambda x: x + elif activation == 'relu': + self.f = relu + self.f_prime = relu_prime + elif activation == 'sigmoid': + self.f = sigmoid + self.f_prime = sigmoid_prime + + self.N = N + self.d_in = d_in + self.d_out = d_out + self.d = d + + self.X = MultiArray([N, d, d_in], sfix) + self.Y = MultiArray([N, d, d_out], sfix) + self.W = sfix.Matrix(d_in, d_out) + self.b = sfix.Array(d_out) + + self.reset() + + self.nabla_Y = MultiArray([N, d, d_out], sfix) + self.nabla_X = MultiArray([N, d, d_in], sfix) + self.nabla_W = sfix.Matrix(d_in, d_out) + self.nabla_W.assign_all(0) + self.nabla_b = sfix.Array(d_out) + + self.f_input = MultiArray([N, d, d_out], sfix) + + def reset(self): + d_in = self.d_in + d_out = self.d_out + r = math.sqrt(6.0 / (d_in + d_out)) + @for_range(d_in) + def _(i): + @for_range(d_out) + def _(j): + self.W[i][j] = sfix.get_random(-r, r) + self.b.assign_all(0) + + def compute_f_input(self): + prod = MultiArray([self.N, self.d, self.d_out], sfix) + @for_range_opt_multithread(self.n_threads, self.N) + def _(i): + self.X[i].plain_mul(self.W, res=prod[i]) + + @for_range_opt_multithread(self.n_threads, self.N) + def _(i): + @for_range_opt(self.d) + def _(j): + v = prod[i][j].get_vector() + self.b.get_vector() + self.f_input[i][j].assign(v) + progress('f input') + + def forward(self): + self.compute_f_input() + self.Y.assign_vector(self.f(self.f_input.get_vector())) + + def backward(self, compute_nabla_X=True): + N = self.N + d = self.d + d_out = self.d_out + X = self.X + Y = self.Y + W = self.W + b = self.b + nabla_X = self.nabla_X + nabla_Y = self.nabla_Y + nabla_W = self.nabla_W + nabla_b = self.nabla_b + + if self.activation == 'id': + f_schur_Y = nabla_Y + else: + f_prime_bit = MultiArray([N, d, d_out], sint) + f_schur_Y = MultiArray([N, d, d_out], sfix) + + self.compute_f_input() + f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector())) + + progress('f prime') + + @for_range_opt(N) + def _(i): + f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i]) + + progress('f prime schur Y') + + if compute_nabla_X: + @for_range_opt(N) + def _(i): + if self.activation == 'id': + nabla_X[i] = nabla_Y[i].mul_trans(W) + else: + nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W) + + progress('nabla X') + + self.backward_params(f_schur_Y) + +class QuantizedDense(DenseBase): + def __init__(self, N, d_in, d_out): + self.N = N + self.d_in = d_in + self.d_out = d_out + self.d = 1 + self.H = math.sqrt(1.5 / (d_in + d_out)) + + self.W = sfix.Matrix(d_in, d_out) + self.nabla_W = self.W.same_shape() + self.T = sint.Matrix(d_in, d_out) + self.b = sfix.Array(d_out) + self.nabla_b = self.b.same_shape() + + self.X = MultiArray([N, 1, d_in], sfix) + self.Y = MultiArray([N, 1, d_out], sfix) + self.nabla_Y = self.Y.same_shape() + + def reset(self): + @for_range(self.d_in) + def _(i): + @for_range(self.d_out) + def _(j): + self.W[i][j] = sfix.get_random(-1, 1) + self.b.assign_all(0) + + def forward(self): + @for_range_opt(self.d_in) + def _(i): + @for_range_opt(self.d_out) + def _(j): + over = self.W[i][j] > 0.5 + under = self.W[i][j] < -0.5 + self.T[i][j] = over.if_else(1, under.if_else(-1, 0)) + over = self.W[i][j] > 1 + under = self.W[i][j] < -1 + self.W[i][j] = over.if_else(1, under.if_else(-1, self.W[i][j])) + @for_range_opt(self.N) + def _(i): + assert self.d_out == 1 + self.Y[i][0][0] = self.b[0] + self.H * sfix._new( + sint.dot_product([self.T[j][0] for j in range(self.d_in)], + [self.X[i][0][j].v for j in range(self.d_in)])) + + def backward(self, compute_nabla_X=False): + assert not compute_nabla_X + self.backward_params(self.nabla_Y) + +class Dropout: + def __init__(self, N, d1, d2=1): + self.N = N + self.d1 = d1 + self.d2 = d2 + self.X = MultiArray([N, d1, d2], sfix) + self.Y = MultiArray([N, d1, d2], sfix) + self.nabla_Y = MultiArray([N, d1, d2], sfix) + self.nabla_X = MultiArray([N, d1, d2], sfix) + self.alpha = 0.5 + self.B = MultiArray([N, d1, d2], sint) + + def forward(self): + assert self.alpha == 0.5 + @for_range(self.N) + def _(i): + @for_range(self.d1) + def _(j): + @for_range(self.d2) + def _(k): + self.B[i][j][k] = sint.get_random_bit() + self.Y = self.X.schur(self.B) + + def backward(self): + self.nabla_X = self.nabla_Y.schur(self.B) + +class QuantBase(object): + n_threads = 1 + + @staticmethod + def new_squant(): + class _(squant): + @classmethod + def get_input_from(cls, player, size=None): + return cls._new(sint.get_input_from(player, size=size)) + return _ + + def __init__(self, input_shape, output_shape): + self.input_shape = input_shape + self.output_shape = output_shape + + self.input_squant = self.new_squant() + self.output_squant = self.new_squant() + + self.X = MultiArray(input_shape, self.input_squant) + self.Y = MultiArray(output_shape, self.output_squant) + + def temp_shape(self): + return [0] + +class QuantConvBase(QuantBase): + fewer_rounds = True + temp_weights = None + temp_inputs = None + + @classmethod + def init_temp(cls, layers): + size = 0 + for layer in layers: + size = max(size, reduce(operator.mul, layer.temp_shape())) + cls.temp_weights = sfix.Array(size) + cls.temp_inputs = sfix.Array(size) + + def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride): + super(QuantConvBase, self).__init__(input_shape, output_shape) + + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.stride = stride + + self.weight_squant = self.new_squant() + self.bias_squant = self.new_squant() + + self.weights = MultiArray(weight_shape, self.weight_squant) + self.bias = Array(output_shape[-1], self.bias_squant) + + self.unreduced = MultiArray(self.output_shape, sint, + address=self.Y.address) + + assert(weight_shape[-1] == input_shape[-1]) + assert(bias_shape[0] == output_shape[-1]) + assert(len(bias_shape) == 1) + assert(len(input_shape) == 4) + assert(len(output_shape) == 4) + assert(len(weight_shape) == 4) + + def input_from(self, player): + for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant: + s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) + self.weights.input_from(player, budget=100000) + self.bias.input_from(player) + print 'WARNING: assuming that bias quantization parameters are correct' + + self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params) + + def dot_product(self, iv, wv, out_y, out_x, out_c): + bias = self.bias[out_c] + acc = squant.unreduced_dot_product(iv, wv) + acc.v += bias.v + acc.res_params = self.output_squant.params + #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() + self.unreduced[0][out_y][out_x][out_c] = acc.v + + def reduction(self): + unreduced = self.unreduced + n_summands = self.n_summands() + start_timer(2) + n_outputs = reduce(operator.mul, self.output_shape) + if n_outputs % self.n_threads == 0: + n_per_thread = n_outputs / self.n_threads + @for_range_opt_multithread(self.n_threads, self.n_threads) + def _(i): + res = _unreduced_squant( + sint.load_mem(unreduced.address + i * n_per_thread, + size=n_per_thread), + (self.input_squant.params, self.weight_squant.params), + self.output_squant.params, + n_summands).reduce_after_mul() + res.store_in_mem(self.Y.address + i * n_per_thread) + else: + @for_range_opt_multithread(self.n_threads, self.output_shape[1]) + def _(out_y): + self.Y[0][out_y].assign_vector(_unreduced_squant( + unreduced[0][out_y].get_vector(), + (self.input_squant.params, self.weight_squant.params), + self.output_squant.params, + n_summands).reduce_after_mul()) + stop_timer(2) + + def temp_shape(self): + return list(self.output_shape[1:]) + [self.n_summands()] + + def prepare_temp(self): + shape = self.temp_shape() + inputs = MultiArray(shape, self.input_squant, + address=self.temp_inputs) + weights = MultiArray(shape, self.weight_squant, + address=self.temp_weights) + return inputs, weights + +class QuantConv2d(QuantConvBase): + def n_summands(self): + _, weights_h, weights_w, _ = self.weight_shape + _, inputs_h, inputs_w, n_channels_in = self.input_shape + return weights_h * weights_w * n_channels_in + + def forward(self, N=1): + assert(N == 1) + assert(self.weight_shape[0] == self.output_shape[-1]) + + _, weights_h, weights_w, _ = self.weight_shape + _, inputs_h, inputs_w, n_channels_in = self.input_shape + _, output_h, output_w, n_channels_out = self.output_shape + + stride_h, stride_w = self.stride + padding_h, padding_w = (weights_h // 2, weights_w // 2) + + if self.fewer_rounds: + inputs, weights = self.prepare_temp() + + @for_range_opt_multithread(self.n_threads, + [output_h, output_w, n_channels_out]) + def _(out_y, out_x, out_c): + in_x_origin = (out_x * stride_w) - padding_w + in_y_origin = (out_y * stride_h) - padding_h + iv = [] + wv = [] + for filter_y in range(weights_h): + in_y = in_y_origin + filter_y + inside_y = (0 <= in_y) * (in_y < inputs_h) + for filter_x in range(weights_w): + in_x = in_x_origin + filter_x + inside_x = (0 <= in_x) * (in_x < inputs_w) + inside = inside_y * inside_x + if inside is 0: + continue + for in_c in range(n_channels_in): + iv += [self.X[0][in_y * inside_y] + [in_x * inside_x][in_c]] + wv += [self.weights[out_c][filter_y][filter_x][in_c]] + wv[-1] *= inside + if self.fewer_rounds: + inputs[out_y][out_x][out_c].assign(iv) + weights[out_y][out_x][out_c].assign(wv) + else: + self.dot_product(iv, wv, out_y, out_x, out_c) + + if self.fewer_rounds: + @for_range_opt_multithread(self.n_threads, + list(self.output_shape[1:])) + def _(out_y, out_x, out_c): + self.dot_product(inputs[out_y][out_x][out_c], + weights[out_y][out_x][out_c], + out_y, out_x, out_c) + + self.reduction() + +class QuantDepthwiseConv2d(QuantConvBase): + def n_summands(self): + _, weights_h, weights_w, _ = self.weight_shape + return weights_h * weights_w + + def forward(self, N=1): + assert(N == 1) + assert(self.weight_shape[-1] == self.output_shape[-1]) + assert(self.input_shape[-1] == self.output_shape[-1]) + + _, weights_h, weights_w, _ = self.weight_shape + _, inputs_h, inputs_w, n_channels_in = self.input_shape + _, output_h, output_w, n_channels_out = self.output_shape + + stride_h, stride_w = self.stride + padding_h, padding_w = (weights_h // 2, weights_w // 2) + + depth_multiplier = 1 + + if self.fewer_rounds: + inputs, weights = self.prepare_temp() + + @for_range_opt_multithread(self.n_threads, + [output_h, output_w, n_channels_in]) + def _(out_y, out_x, in_c): + for m in range(depth_multiplier): + oc = m + in_c * depth_multiplier + in_x_origin = (out_x * stride_w) - padding_w + in_y_origin = (out_y * stride_h) - padding_h + iv = [] + wv = [] + for filter_y in range(weights_h): + for filter_x in range(weights_w): + in_x = in_x_origin + filter_x + in_y = in_y_origin + filter_y + inside = (0 <= in_x) * (in_x < inputs_w) * \ + (0 <= in_y) * (in_y < inputs_h) + if inside is 0: + continue + iv += [self.X[0][in_y][in_x][in_c]] + wv += [self.weights[0][filter_y][filter_x][oc]] + wv[-1] *= inside + if self.fewer_rounds: + inputs[out_y][out_x][oc].assign(iv) + weights[out_y][out_x][oc].assign(wv) + else: + self.dot_product(iv, wv, out_y, out_x, oc) + + if self.fewer_rounds: + @for_range_opt_multithread(self.n_threads, + list(self.output_shape[1:])) + def _(out_y, out_x, out_c): + self.dot_product(inputs[out_y][out_x][out_c], + weights[out_y][out_x][out_c], + out_y, out_x, out_c) + + self.reduction() + +class QuantAveragePool2d(QuantBase): + def __init__(self, input_shape, output_shape, filter_size): + super(QuantAveragePool2d, self).__init__(input_shape, output_shape) + self.filter_size = filter_size + + def input_from(self, player): + print 'WARNING: assuming that input and output quantization parameters are the same' + for s in self.input_squant, self.output_squant: + s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) + + def forward(self, N=1): + assert(N == 1) + + _, input_h, input_w, n_channels_in = self.input_shape + _, output_h, output_w, n_channels_out = self.output_shape + + n = input_h * input_w + print 'divisor: ', n + + assert output_h == output_w == 1 + assert n_channels_in == n_channels_out + + padding_h, padding_w = (0, 0) + stride_h, stride_w = (2, 2) + filter_h, filter_w = self.filter_size + + @for_range_opt(output_h) + def _(out_y): + @for_range_opt(output_w) + def _(out_x): + @for_range_opt(n_channels_in) + def _(c): + in_x_origin = (out_x * stride_w) - padding_w + in_y_origin = (out_y * stride_h) - padding_h + fxs = (-in_x_origin).max(0) + #fxe = min(filter_w, input_w - in_x_origin) + fys = (-in_y_origin).max(0) + #fye = min(filter_h, input_h - in_y_origin) + acc = 0 + #fc = 0 + for i in range(filter_h): + filter_y = fys + i + for j in range(filter_w): + filter_x = fxs + j + in_x = in_x_origin + filter_x + in_y = in_y_origin + filter_y + acc += self.X[0][in_y][in_x][c].v + #fc += 1 + logn = int(math.log(n, 2)) + acc = (acc + n / 2) + if 2 ** logn == n: + acc = acc.round(self.output_squant.params.k + logn, + logn, nearest=True) + else: + acc = acc.int_div(sint(n), + self.output_squant.params.k + logn) + #acc = min(255, max(0, acc)) + self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) + +class QuantReshape(QuantBase): + def __init__(self, input_shape, _, output_shape): + super(QuantReshape, self).__init__(input_shape, output_shape) + + def input_from(self, player): + print 'WARNING: assuming that input and output quantization parameters are the same' + _ = self.new_squant() + for s in self.input_squant, _, self.output_squant: + s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) + for i in range(2): + sint.get_input_from(player) + + def forward(self, N=1): + assert(N == 1) + # reshaping is implicit + self.Y.assign(self.X) + +class QuantSoftmax(QuantBase): + def input_from(self, player): + print 'WARNING: assuming that input and output quantization parameters are the same' + for s in self.input_squant, self.output_squant: + s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) + + def forward(self, N=1): + assert(N == 1) + assert(len(self.input_shape) == 2) + + # just print the best + def comp(left, right): + c = left[1].v.greater_than(right[1].v, self.input_squant.params.k) + #print_ln('comp %s %s %s', c.reveal(), left[1].v.reveal(), right[1].v.reveal()) + return [c.if_else(x, y) for x, y in zip(left, right)] + print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal()) + +class Optimizer: + n_threads = Layer.n_threads + + def forward(self, N): + for j in range(len(self.layers) - 1): + self.layers[j].forward() + self.layers[j + 1].X.assign(self.layers[j].Y) + self.layers[-1].forward(N) + + def backward(self): + for j in range(1, len(self.layers)): + self.layers[-j].backward() + self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X) + self.layers[0].backward(compute_nabla_X=False) + + def run(self): + i = MemValue(0) + @do_while + def _(): + if self.X_by_label is not None: + N = self.layers[0].N + assert self.layers[-1].N == N + assert N % 2 == 0 + n = N / 2 + @for_range(n) + def _(i): + self.layers[-1].Y[i] = 0 + self.layers[-1].Y[i + n] = 1 + n_per_epoch = int(math.ceil(1. * max(len(X) for X in + self.X_by_label) / n)) + print '%d runs per epoch' % n_per_epoch + indices_by_label = [] + for label, X in enumerate(self.X_by_label): + indices = regint.Array(n * n_per_epoch) + indices_by_label.append(indices) + indices.assign(i % len(X) for i in range(len(indices))) + indices.shuffle() + @for_range(n_per_epoch) + def _(j): + j = MemValue(j) + for label, X in enumerate(self.X_by_label): + indices = indices_by_label[label] + @for_range_multithread(self.n_threads, 1, n) + def _(i): + idx = indices[i + j * n_per_epoch] + self.layers[0].X[i + label * n] = X[idx] + self.forward(None) + self.backward() + self.update(i) + else: + self.forward(None) + self.backward() + self.update(i) + loss = self.layers[-1].l + if self.report_loss: + print_ln('loss after epoch %s: %s', i, loss.reveal()) + else: + print_ln('done with epoch %s', i) + time() + i.iadd(1) + res = (i < self.n_epochs) + if self.tol > 0: + res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() + return res + print_ln('finished after %s epochs', i) + +class Adam(Optimizer): + def __init__(self, layers, n_epochs): + self.alpha = .001 + self.beta1 = 0.9 + self.beta2 = 0.999 + self.epsilon = 10 ** -8 + self.n_epochs = n_epochs + + self.layers = layers + self.ms = [] + self.vs = [] + self.gs = [] + self.thetas = [] + for layer in layers: + for nabla in layer.nablas(): + self.gs.append(nabla) + for x in self.ms, self.vs: + x.append(nabla.same_shape()) + for theta in layer.thetas(): + self.thetas.append(theta) + + self.mhat_factors = Array(n_epochs, sfix) + self.vhat_factors = Array(n_epochs, sfix) + + for i in range(n_epochs): + for factors, beta in ((self.mhat_factors, self.beta1), + (self.vhat_factors, self.beta2)): + factors[i] = 1. / (1 - beta ** (i + 1)) + + def update(self, i_epoch): + for m, v, g, theta in zip(self.ms, self.vs, self.gs, self.thetas): + @for_range_opt(len(m)) + def _(k): + m[k] = self.beta1 * m[k] + (1 - self.beta1) * g[k] + v[k] = self.beta2 * v[k] + (1 - self.beta2) * g[k] ** 2 + mhat = m[k] * self.mhat_factors[i_epoch] + vhat = v[k] * self.vhat_factors[i_epoch] + theta[k] = theta[k] - self.alpha * mhat / \ + mpc_math.sqrt(vhat) + self.epsilon + +class SGD(Optimizer): + def __init__(self, layers, n_epochs, debug=False, report_loss=False): + self.momentum = 0.9 + self.layers = layers + self.n_epochs = n_epochs + self.thetas = [] + self.nablas = [] + self.delta_thetas = [] + for layer in layers: + self.nablas.extend(layer.nablas()) + self.thetas.extend(layer.thetas()) + for theta in layer.thetas(): + self.delta_thetas.append(theta.same_shape()) + self.gamma = MemValue(sfix(0.01)) + self.debug = debug + self.report_loss = report_loss + self.tol = 0.000 + self.X_by_label = None + + def reset(self, X_by_label=None): + self.X_by_label = X_by_label + for y in self.delta_thetas: + y.assign_all(0) + for layer in self.layers: + layer.reset() + + def update(self, i_epoch): + for nabla, theta, delta_theta in zip(self.nablas, self.thetas, + self.delta_thetas): + @for_range_opt_multithread(self.n_threads, len(nabla)) + def _(k): + old = delta_theta[k] + if isinstance(old, Array): + old = old.get_vector() + red_old = self.momentum * old + new = self.gamma * nabla[k] + diff = red_old - new + delta_theta[k] = diff + theta[k] = theta[k] + delta_theta[k] + if self.debug: + for x, name in (old, 'old'), (red_old, 'red_old'), \ + (new, 'new'), (diff, 'diff'): + x = x.reveal() + print_ln_if((x > 1000) + (x < -1000), + name + ': %s %s %s %s', + *[y.v.reveal() for y in old, red_old, \ + new, diff]) + if self.debug: + d = delta_theta.get_vector().reveal() + a = cfix.Array(len(d.v)) + a.assign(d) + @for_range(len(a)) + def _(i): + x = a[i] + print_ln_if((x > 1000) + (x < -1000), + 'update len=%d' % len(nabla)) + a.assign(nabla.get_vector().reveal()) + @for_range(len(a)) + def _(i): + x = a[i] + print_ln_if((x > 1000) + (x < -1000), + 'nabla len=%d' % len(nabla)) + self.gamma.imul(1 - 10 ** - 6) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py new file mode 100644 index 00000000..23926f24 --- /dev/null +++ b/Compiler/mpc_math.py @@ -0,0 +1,752 @@ +## +# @file +# Arithmetic Module for Complex Math Operations +# +# Implements trigonometric and logarithmic functions. + +import math +from Compiler import floatingpoint +from Compiler import types +from Compiler import comparison +from Compiler import program +# polynomials as enumerated on Hart's book +## +# @private +p_3307 = [1.57079632679489000000000, -0.64596409750624600000000, + 0.07969262624616700000000, -0.00468175413531868000000, + 0.00016044118478735800000, -0.00000359884323520707000, + 0.00000005692172920657320, -0.00000000066880348849204, + 0.00000000000606691056085, -0.00000000000004375295071, + 0.00000000000000025002854] +## +# @private +p_3508 = [1.00000000000000000000, -0.50000000000000000000, + 0.04166666666666667129, -0.00138888888888888873, + 0.00002480158730158702, -0.00000027557319223933, + 0.00000000208767569817, -0.00000000001147074513, + 0.00000000000004779454, -0.00000000000000015612, + 0.00000000000000000040] +## +# @private +p_1045 = [1.000000077443021686, 0.693147180426163827795756, + 0.224022651071017064605384, 0.055504068620466379157744, + 0.009618341225880462374977, 0.001332730359281437819329, + 0.000155107460590052573978, 0.000014197847399765606711, + 0.000001863347724137967076] +## +# @private +p_2524 = [-2.05466671951, -8.8626599391, + +6.10585199015, +4.81147460989] +## +# @private +q_2524 = [+0.353553425277, +4.54517087629, + +6.42784209029, +1] +## +# @private +p_5102 = [+21514.05962602441933193254468, +73597.43380288444240814980706, + +100272.5618306302784970511863, +69439.29750032252337059765503, + +25858.09739719099025716567793, +5038.63918550126655793779119, + +460.1588804635351471161727227, +15.08767735870030987717455528, + +0.07523052818757628444510729539] +## +# @private +q_5102 = [+21514.05962602441933193298234, +80768.78701155924885176713209, + +122892.6789092784776298743322, +97323.20349053555680260434387, + +42868.57652046408093184006664, +10401.13491566890057005103878, + +1289.75056911611097141145955, +68.51937831018968013114024294, + +1] +## +# @private +p_4737 = [-9338.550897341021522505385079, +43722.68009378241623148489754, + -86008.12066370804865047446067, +92190.57592175496843898184959, + -58360.27724533928122075635101, +22081.61324178027161353562222, + -4805.541226761699661564427739, +542.2148323255220943742314911, + -24.94928894422502466205102672, 0.2222361619461131578797029272] +## +# @private +q_4737 =[-9338.550897341021522505384935, +45279.10524333925315190231067, + -92854.24688696401422824346529, +104687.2504366298224257408682, + -70581.74909396877350961227976, +28972.22947326672977624954443, + -7044.002024719172700685571406, +935.7104153502806086331621628, + -56.83369358538071475796209327, 1] +## +# @private +p_4754 = [-6.90859801, +12.85564644, -5.94939208] + +## +# @private +q_4754 = [-6.92529156, +14.20305096, -8.27925501, 1] + +# all inputs are calcualted in radians hence we need some conversion. +pi = math.radians(180) +pi_over_2 = math.radians(90) + +## +# truncates values regardless of the input type. (It always rounds down) +# @param x: coefficient to be truncated. +# +# @return truncated sint value of x +def trunc(x): + if type(x) is types.sfix: + return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True) + elif type(x) is types.sfloat: + v, p, z, s = floatingpoint.FLRound(x, 0) + #return types.sfloat(v, p, z, s, x.err) + return types.sfloat(v, p, z, s) + return x + + +## +# loads integer to fractional type (sint) +# @param x: coefficient to be truncated. +# +# @return returns sfix, sfloat loaded value +def load_sint(x, l_type): + if l_type is types.sfix: + return types.sfix.from_sint(x) + elif l_type is types.sfloat: + return x + return x + + +## +# evaluates a Polynomial to a given x in a privacy preserving manner. +# Inputs can be of any kind of register, secret or otherwise. +# +# @param p_c: Polynomial coefficients. (Array) +# +# @param x: Value to which the polynomial p_c is evaluated to.(register) +# +# @return the evaluation of the polynomial. return type depends on inputs. +def p_eval(p_c, x): + degree = len(p_c) - 1 + if type(x) is types.sfix: + # ignore coefficients smaller than precision + for c in reversed(p_c): + if c < 2 ** -(x.f + 1): + degree -= 1 + else: + break + pre_mults = floatingpoint.PreOpL(lambda a,b,_: a * b, + [x] * degree) + local_aggregation = 0 + # Evaluation of the Polynomial + for i, pre_mult in zip(p_c[1:], pre_mults): + local_aggregation += pre_mult.mul_no_reduce(x.coerce(i)) + return local_aggregation.reduce_after_mul() + p_c[0] + + +## +# reduces the input to [0,90) and returns whether the reduced value is +# greater than \Pi and greater than Pi over 2 +# @param x: value of any type to be reduced to the [0,90) interval +# +# @return w: reduced angle in either fixed or floating point . +# +# @return b1: \{0,1\} value. Returns one when reduction to 2*\pi +# is greater than \pi +# +# @return b2: \{0,1\} value. Returns one when reduction to +# \pi is greater than \pi/2. +def sTrigSub(x): + # reduction to 2* \pi + f = x * (1.0 / (2 * pi)) + f = load_sint(trunc(f), type(x)) + y = x - (f) * (2 * pi) + # reduction to \pi + b1 = y > pi + w = b1 * ((2 * pi - y) - y) + y + # reduction to \pi/2 + b2 = w > pi_over_2 + w = b2 * ((pi - w) - w) + w + # returns scaled angle and boolean flags + return w, b1, b2 + +# kernel method calls -- they are built in a generic way + + +## +# Kernel sin. Returns the sin of a given angle on the [0, \pi/2) interval and +# adjust the sign in case the angle was reduced on the [0,360) interval +# +# @param w: fractional value for an angle on the [0, \pi) interval. +# +# @return returns the sin of w. +def ssin(w, s): + # calculates the v of w for polynomial evaluation + v = w * (1.0 / pi_over_2) + v_2 = v ** 2 + # adjust sign according to the movement in the reduction + b = s * (-2) + 1 + # calculate the sin using polynomial evaluation + local_sin = b * v * p_eval(p_3307, v_2) + return local_sin + + +## +# Kernel cos. Returns the cos of a given angle on the [0.pi/2) +# interval and adjust +# the sign in case the angle was reduced on the [0,360) interval. +# +# @param w: fractional value for an angle on the [0,\pi) interval. +# +# @param s: \{0,1\} value. Corresponding to b2. Returns 1 if the angle +# was reduced from an angle in the [\pi/2,\pi) interval. +# +# @return returns the cos of w (sfix). +def scos(w, s): + # calculates the v of the w. + v = w + v_2 = v ** 2 + # adjust sign according to the movement in the reduction + b = s * (-2) + 1 + # calculate the cos using polynomial evaluation + local_cos = b * p_eval(p_3508, v_2) + return local_cos + + +# facade method calls --it is built in a generic way + +## +# Returns the sin of any given fractional value. +# +# @param x: fractional input (sfix, sfloat). +# +# @return returns the sin of x (sfix, sfloat). +def sin(x): + # reduces the angle to the [0,\pi/2) interval. + w, b1, b2 = sTrigSub(x) + # returns the sin with sign correction + return ssin(w, b1) + + +## +# Returns the sin of any given fractional value. +# +# @param x: fractional input (sfix, sfloat). +# +# @return returns the sin of x (sfix, sfloat). +def cos(x): + # reduces the angle to the [0,\pi/2) interval. + w, b1, b2 = sTrigSub(x) + + # returns the sin with sign correction + return scos(w, b2) + + +## +# Returns the tan (sfix, sfloat) of any given fractional value. +# +# @param x: fractional input (sfix, sfloat). +# +# @return returns the tan of x (sifx, sfloat). +def tan(x): + # reduces the angle to the [0,\pi/2) interval. + w, b1, b2 = sTrigSub(x) + # calculates the sin and the cos. + local_sin = ssin(w, b1) + local_cos = scos(w, b2) + # obtains the local tan + local_tan = local_sin/local_cos + return local_tan + + +## +# Returns the result of 2^a for any unbounded number +# @param a: exponent for 2^a +# +# @return returns the value of 2^a if it is within the range +@types.vectorize +def exp2_fx(a): + if types.program.options.ring: + sint = types.sint + intbitint = types.intbitint + # how many bits to use from integer part + n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) + n_bits = a.f + n_int_bits + n_shift = int(types.program.options.ring) - a.k + r_bits = [sint.get_random_bit() for i in range(a.k)] + shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal() + masked_bits = (shifted >> n_shift).bit_decompose(a.k) + lower_overflow = sint() + comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1], + r_bits[a.f-1::-1]) + lower_r = sint.bit_compose(r_bits[:a.f]) + lower_masked = sint.bit_compose(masked_bits[:a.f]) + lower = lower_r + lower_masked - (lower_overflow << (a.f)) + c = types.sfix._new(lower, k=a.k, f=a.f) + higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits], + r_bits[a.f:n_bits], + carry_in=lower_overflow, + get_carry=True) + d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]), + k=a.k, f=a.f) + e = p_eval(p_1045, c) + g = d * e + small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False, + nearest=types.sfix.round_nearest), + k=a.k, f=a.f) + carry = comparison.CarryOutLE(masked_bits[n_bits:-1], + r_bits[n_bits:-1], + higher_bits[-1]) + # should be for free + highest_bits = intbitint.ripple_carry_adder( + masked_bits[n_bits:-1], [0] * (a.k - n_bits), + carry_in=higher_bits[-1]) + bits_to_check = [x.bit_xor(y) + for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])] + t = floatingpoint.KMul(bits_to_check) + # sign + s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry) + return s.if_else(t.if_else(small_result, 0), g) + else: + # obtain absolute value of a + s = a < 0 + a = (s * (-2) + 1) * a + # isolates fractional part of number + b = trunc(a) + c = a - load_sint(b, type(a)) + # squares integer part of a + d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a)) + # evaluates fractional part of a in p_1045 + e = p_eval(p_1045, c) + g = d * e + return (1 - s) * g + s * ((types.sfix(1)) / g) + + +## +# Returns the result of log_2(x) for any unbounded number. This is +# achieved by changing x into f*2^n where f is bounded by[0.5, 1]. +# Then the polynomials are used to calculate the log_2 of f, +# which is then just added to n. +# +# @param x: input for log_2 (sfix, sint). +# +# @return returns (sfix) the value of log2(X) +@types.vectorize +def log2_fx(x): + if type(x) is types.sfix: + # transforms sfix to f*2^n, where f is [o.5,1] bounded + # obtain number bounded by [0,5 and 1] by transforming input to sfloat + v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa) + p -= x.f + vlen = x.f + else: + d = types.sfloat(x) + v, p, vlen = d.v, d.p, d.vlen + # isolates mantisa of d, now the n can be also substituted by the + # secret shared p from d in the expresion above. + v = load_sint(v, type(x)) + w = (1.0 / (2 ** (vlen))) + v = v * w + # polynomials for the log_2 evaluation of f are calculated + P = p_eval(p_2524, v) + Q = p_eval(q_2524, v) + # the log is returned by adding the result of the division plus p. + a = P / Q + load_sint(vlen + p, type(x)) + return a # *(1-(f.z))*(1-f.s)*(1-f.error) + + +## +# Returns the value of the expression x^y where both inputs +# are secret shared. It uses log2_fx together with +# exp2_fx to calcualte the expresion 2^{y*log2(x)}. +# +# @param x: (sfix) secret shared base. +# +# @param y: (sfix, clear types) secret shared exponent. +# +# @return returns the value of x^y +def pow_fx(x, y): + log2_x =0 + # obtains log2(x) + if (type(x) == int or type(x) == float): + log2_x = math.log(x,2) + else: + log2_x = log2_fx(x) + # obtains y * log2(x) + exp = y * log2_x + # returns 2^(y*log2(x)) + return exp2_fx(exp) + + +## +# Returns the value of the expression log_b(x) where x is +# secret shared. It uses log2_fx to calculate the expression +# logb(2)*log2(x). +# +# @param x:(sfix, sint) secret shared coefficient for log. +# +# @param b:(int) base for log operation. +# +# @return returns (sfix) the value of logb(x). +def log_fx(x, b): + # calculates logb(2) + logb_2 = math.log(2, b) + # returns logb(2) * log2(x) + return logb_2 * log2_fx(x) + + +## +# Returns the absolute value of a fix point number. +# The method is also applicable to sfloat, +# however, more efficient mechanisms can be devised. +# +# @param x: (sfix) +# +# @return (sfix) unsigned +def abs_fx(x): + s = x < 0 + return (1 - 2 * s) * x + + +## +# Floors the input and stores the value into a sflix register +# @param x: coefficient to be floored. +# +# @return floored sint value of x +def floor_fx(x): + return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x)) + + +### sqrt methods + + +## +# obtains the most significative bit (MSB) +# of a given input. The size of the vector +# is tuned to the needs of sqrt. +# @param b: number from which you obtain the +# most significative bit. +# @param k: number of bits for which +# an output of size (k+1) if even +# is going to be produced. +# @return z: index array for MSB of size +# k or K+1 if even. +def MSB(b, k): + # calculation of z + # x in order 0 - k + if (k > types.program.bit_length): + raise OverflowError("The supported bit \ + lenght of the application is smaller than k") + + x_order = b.bit_decompose(k) + x = [0] * k + # x i now inverted + for i in range(k - 1, -1, -1): + x[k - 1 - i] = x_order[i] + # y is inverted for PReOR and then restored + y_order = floatingpoint.PreOR(x) + + # y in order (restored in orginal order + y = [0] * k + for i in range(k - 1, -1, -1): + y[k - 1 - i] = y_order[i] + + # obtain z + z = [0] * (k + 1 - k % 2) + for i in range(k - 1): + z[i] = y[i] - y[i + 1] + z[k - 1] = y[k - 1] + + return z + + +## +# Similar to norm_SQ, saves rounds by not +# calculating v and c. +# +# @param b: sint input to be normalized. +# @param k: bitsize of the input, by definition +# its value is either sfix.k or program.bit_lengthh +# @return m_odd: the parity of most signficative bit index m +# @return m: index of most significative bit +# @return w: 2^m/2 or 2^ (m-1) /2 +def norm_simplified_SQ(b, k): + z = MSB(b, k) + # construct m + #m = types.sint(0) + m_odd = 0 + for i in range(k): + #m = m + (i + 1) * z[i] + # determine the parity of the input + if (i % 2 == 0): + m_odd = m_odd + z[i] + + # construct w, + k_over_2 = k / 2 + 1 + w_array = [0] * (k_over_2) + w_array[0] = z[0] + for i in range(1, k_over_2): + w_array[i] = z[2 * i - 1] + z[2 * i] + + # w aggregation + w = types.sint(0) + for i in range(k_over_2): + w += (2 ** i) * w_array[i] + + # return computed values + #return m_odd, m, w + return m_odd, None, w + + +## +# Obtains the sqrt using our custom mechanism +# for any sfix input value. +# no restrictions on the size of f. +# +# @param x: secret shared input from which the sqrt +# is calucalted, +# +# @return g: approximated sqrt +def sqrt_simplified_fx(x): + # fix theta (number of iterations) + theta = max(int(math.ceil(math.log(types.sfix.k))), 6) + + # process to use 2^(m/2) approximation + m_odd, m, w = norm_simplified_SQ(x.v, x.k) + # process to set up the precision and allocate correct 2**f + if x.f % 2 == 1: + m_odd = (1 - 2 * m_odd) + m_odd + w = (w * 2 - w) * (1-m_odd) + w + # map number to use sfix format and instantiate the number + w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2)) + # obtains correct 2 ** (m/2) + w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w + # produce x/ 2^(m/2) + y_0 = types.cfix(1.0) / w + + # from this point on it sufices to work sfix-wise + g_0 = (y_0 * x) + h_0 = y_0 * types.cfix(0.5) + gh_0 = g_0 * h_0 + + ## initialization + g = g_0 + h = h_0 + gh = gh_0 + + for i in range(1, theta - 2): + r = (3 / 2.0) - gh + g = g * r + h = h * r + gh = g * h + + # newton + r = (3 / 2.0) - gh + h = h * r + H = 4 * (h * h) + H = H * x + H = (3) - H + H = h * H + g = H * x + g = g + + return g + + +## +# Calculates the normSQ of a number +# @param x: number from which the norm is going to be extracted +# @param k: bitsize of x +# +# @return c: where c = x*v where c is bounded by 2^{k-1} and 2^k +# @return v: where v = 2^k-m +# @return m: where m = MSB +# @return w: where w = 2^{m/2} if m is oeven and 2^{m-1 / 2} otherwise +def norm_SQ(b, k): + # calculation of z + # x in order 0 - k + z = MSB(b,k) + # now reverse bits of z[i] to generate v + v = types.sint(0) + for i in range(k): + v += (2**(k - i - 1)) * z[i] + c = b * v + + # construct m + m = types.sint(0) + for i in range(k): + m = m + (i+1) * z[i] + + # construct w, changes from what is on the paper + # and the documentation + k_over_2= k/2+1#int(math.ceil((k/2.0)))+1 + w_array = [0]*(k_over_2 ) + w_array[0] = z[0] + for i in range(1, k_over_2): + w_array[i] = z[2 * i - 1] + z[2 * i] + + w = types.sint(0) + for i in range(k_over_2): + w += (2 ** i) * w_array[i] + + # return computed values + return c, v, m, w + + +## +# Given f and k, returns a linear approximation of 1/x^{1/2} +# escalated by s^f. +# Method only works for sfix inputs. It uses the normSQ. +# the method is an implementation of [Liedel2012] +# @param x: number from which the approximation is caluclated +# @param k: bitsize of x +# @param f: precision of the input f +# +# @return c: Some approximation of (1/x^{1/2} * 2^f) *K +# where K is close to 1 +def lin_app_SQ(b, k, f): + + alpha = types.cfix((-0.8099868542) * 2 ** (k)) + beta = types.cfix(1.787727479 * 2 ** (2 * k)) + + # obtain normSQ parameters + c, v, m, W = norm_SQ(types.sint(b), k) + + # c is now escalated + w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k + + + # m even or odd determination + m_bit = types.sint() + comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False) + m = load_sint(m_bit, types.sfix) + + # w times v this way both terms have 2^3k and can be symplified + w = w * v + factor = 1.0 / (2 ** (3.0 * k - 2 * f)) + w = w * factor # w escalated to 3k -2 * f + # normalization factor W* 1/2 ^{f/2} + w = w * W * types.cfix(1.0 / (2 ** (f / 2.0))) + # now we need to elminate an additional root of 2 in case m was odd + sqr_2 = types.cfix((2 ** (1 / 2.0))) + w = (1 - m) * w + sqr_2 * w * m + + return w + + +## +# Given bitsize k and precision f, it calulates the square root of x. +# @param x: number from which the norm is going to be extracted +# @param k: bitsize of x. +# @param f: precision of x. +# +# @return g: square root of de-scaled input x +def sqrt_fx(x_l, k, f): + factor = 1.0 / (2.0 ** f) + + x = load_sint(x_l, types.sfix) * factor + + theta = int(math.ceil(math.log(k/5.4))) + + y_0 = lin_app_SQ(x_l,k,f) #cfix(1.0/ (cx ** (1/2.0))) # lin_app_SQ(x_l,5,2) + + y_0 = y_0 * factor #*((1.0/(2.0 ** f))) + g_0 = y_0 * x + + + #g = mpc_math.load_sint(mpc_math.trunc(g_0),types.sfix) + h_0 = y_0 *(0.5) + gh_0 = g_0 * h_0 + + ##initialization + g= g_0 + h= h_0 + gh =gh_0 + + for i in range(1,theta-2): #to implement \in [1,\theta-2] + r = (3/2.0) - gh + g = g * r + h = h * r + gh = g * h + + # newton + r = (3/2.0) - gh + h = h * r + H = 4 * (h * h) + H = H * x + H = (3) - H + H = h * H + g = H * x + g = g #* (0.5) + + return g + +## +# Returns the sqrt (sfix) of any given fractional +# value as long as it can be rounded to a integral value +# to 2^f precision. +# +# Note that sqrt only works as long as this inequality is respected: +# 3*k - 2 *f < x.f (x.f by default is 20) +# @param x: fractional input (sfix). +# +# @return returns the aTan of x (sifx). +@types.vectorize +def sqrt(x, k = types.sfix.k, f = types.sfix.f): + + if (3 *k -2 * f >= types.sfix.f): + return sqrt_simplified_fx(x) + # raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ") + else: + param = trunc(x *(2 ** (f))) + return sqrt_fx(param ,k ,f) + + +## +# Returns the aTan (sfix) of any given fractional value. +# +# @param x: fractional input (sfix). +# +# @return returns the aTan of x (sifx). +def atan(x): + # obtain absolute value of x + s = x < 0 + x_abs = (s * (-2) + 1) * x + # angle isolation + b = x_abs > 1 + v = (types.cfix(1.0) / x_abs) + v = (1 - b) * (x_abs - v) + v + v_2 =v*v + + # range of polynomial coefficients + assert x.k - x.f >= 18 + P = p_eval(p_5102, v_2) + Q = p_eval(q_5102, v_2) + + # padding + y = v * (P / Q) + y_pi_over_two = pi_over_2 - y + + # sign correction + y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two + y = (1 - s) * (y - (-y)) + (-y) + + return y + + +## +# Returns the aSin (sfix) of any given fractional value. +# +# @param x: fractional input (sfix). valid interval is -1.0 <= x <= 1 +# +# @return returns the aSin of x (sfix). +def asin(x): + # Square x + x_2 = x*x + # trignometric identities + sqrt_l = sqrt(1- (x_2)) + x_sqrt_l =x / sqrt_l + return atan(x_sqrt_l) + + +## +# Returns the aCos (sfix) of any given fractional value. +# +# @param x: fractional input (sfix). -1.0 < x < 1 +# +# @return returns the aCos of x (sifx). +def acos(x): + y = asin(x) + return pi_over_2 - y diff --git a/Compiler/program.py b/Compiler/program.py index 1b09986e..76c3b00b 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -44,8 +44,10 @@ class Program(object): if (param != -1) + sum(x != 0 for x in(options.ring, options.field, options.binary)) > 1: raise CompilerError('can only use one out of -p, -B, -R, -F') - self.bit_length = int(options.ring) or int(options.binary) \ - or int(options.field) + if options.ring: + self.bit_length = int(options.ring) - 1 + else: + self.bit_length = int(options.binary) or int(options.field) if not self.bit_length: self.bit_length = BIT_LENGTHS[param] print 'Default bit length:', self.bit_length @@ -71,6 +73,7 @@ class Program(object): self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w') self.types = {} self.budget = int(self.options.budget) + self.verbose = False self.to_merge = [Compiler.instructions.asm_open_class, \ Compiler.instructions.gasm_open_class, \ Compiler.instructions.muls_class, \ @@ -82,10 +85,13 @@ class Program(object): Compiler.instructions.asm_input_class, \ Compiler.instructions.gasm_input_class, Compiler.instructions.inputfix_class, - Compiler.instructions.inputfloat_class] + Compiler.instructions.inputfloat_class, + Compiler.instructions.inputmixed_class, + Compiler.instructions.trunc_pr_class] import Compiler.GC.instructions as gc self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] + self.use_trunc_pr = False Program.prog = self self.reset_values() @@ -452,6 +458,8 @@ class Tape: else: self.alloc_pool = defaultdict(set) self.purged = False + self.n_rounds = 0 + self.n_to_merge = 0 def __len__(self): return len(self.instructions) @@ -506,6 +514,8 @@ class Tape: instructions = self.instructions for inst in instructions: inst.add_usage(req_node) + req_node.num['all', 'round'] = self.n_rounds + req_node.num['all', 'inv'] = self.n_to_merge def __str__(self): return self.name @@ -530,7 +540,7 @@ class Tape: self.basicblocks.append(sub) self.active_basicblock = sub self.req_node.add_block(sub) - print 'Compiling basic block', sub.name + #print 'Compiling basic block', sub.name def init_registers(self): self.reset_registers() @@ -601,6 +611,8 @@ class Tape: if len(block.instructions) > 10000: print 'Merging instructions...' numrounds = merger.longest_paths_merge() + block.n_rounds = numrounds + block.n_to_merge = len(merger.open_nodes) if numrounds > 0: print 'Program requires %d rounds of communication' % numrounds if merger.counter: @@ -633,10 +645,11 @@ class Tape: # allocate registers reg_counts = self.count_regs() if not options.noreallocate: - print 'Tape register usage:', dict(reg_counts) - print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) - print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) - print 'Re-allocating...' + if self.program.verbose: + print 'Tape register usage:', dict(reg_counts) + print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + print 'Re-allocating...' allocator = al.StraightlineAllocator(REG_MAX) def alloc_loop(block): for reg in sorted(block.used_from_scope, @@ -661,7 +674,7 @@ class Tape: self.req_num = self.req_tree.aggregate() print 'Tape requires', self.req_num for req,num in sorted(self.req_num.items()): - if num == float('inf'): + if num == float('inf') or num >= 2 ** 32: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( @@ -692,8 +705,9 @@ class Tape: self.basicblocks[-1].instructions.append( Compiler.instructions.reqbl(bl, add_to_prog=False)) - print 'Tape requires prime bit length', self.req_bit_length['p'] - print 'Tape requires galois bit length', self.req_bit_length['2'] + if self.program.verbose: + print 'Tape requires prime bit length', self.req_bit_length['p'] + print 'Tape requires galois bit length', self.req_bit_length['2'] @unpurged def _get_instructions(self): @@ -783,6 +797,8 @@ class Tape: return res __rmul__ = __mul__ def set_all(self, value): + if value == float('inf') and self['all', 'inv'] > 0: + print 'Going to unknown from %s' % self res = Tape.ReqNum() for i in self: res[i] = value @@ -832,6 +848,15 @@ class Tape: self.parent = parent def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) + try: + n_reps = self.aggregator([1]) + n_rounds = res['all', 'round'] + n_invs = res['all', 'inv'] + if (n_invs / n_rounds) * 1000 < n_reps: + print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ + '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs) + except: + pass return res def add_node(self, tape, name): new_node = Tape.ReqNode(name) diff --git a/Compiler/types.py b/Compiler/types.py index 68e51e32..3fa5992a 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -167,6 +167,12 @@ class _number(object): def pow2(self, bit_length=None, security=None): return 2**self + def min(self, other): + return (self < other).if_else(self, other) + + def max(self, other): + return (self < other).if_else(other, self) + class _int(object): def if_else(self, a, b): if hasattr(a, 'for_mux'): @@ -705,6 +711,10 @@ class regint(_register, _int): popint(res) return res + @vectorized_classmethod + def push(cls, value): + pushint(cls.conv(value)) + @vectorized_classmethod def get_random(cls, bit_length): """ Public insecure randomness """ @@ -781,10 +791,10 @@ class regint(_register, _int): @vectorize @read_mem_value def int_op(self, other, inst, reverse=False): - if isinstance(other, _secret): + try: + other = self.conv(other) + except: return NotImplemented - elif not isinstance(other, type(self)): - other = type(self)(other) res = regint() if reverse: inst(res, other, self) @@ -898,6 +908,8 @@ class regint(_register, _int): def print_reg_plain(self): print_int(self) + def print_if(self, string): + cint(self).print_if(string) class _secret(_register): __slots__ = [] @@ -1121,6 +1133,13 @@ class sint(_secret, _int): comparison.PRandInt(res, bits) return res + @vectorized_classmethod + def get_input_from(cls, player): + """ Secret input """ + res = cls() + inputmixed('int', res, player) + return res + @classmethod def get_raw_input_from(cls, player): res = cls() @@ -1196,8 +1215,8 @@ class sint(_secret, _int): @vectorize def __lt__(self, other, bit_length=None, security=None): res = sint() - comparison.LTZ(res, self - other, bit_length or program.bit_length + - (not (int(program.options.ring) == program.bit_length)), + comparison.LTZ(res, self - other, + (bit_length or program.bit_length) + 1, security or program.security) return res @@ -1205,8 +1224,8 @@ class sint(_secret, _int): @vectorize def __gt__(self, other, bit_length=None, security=None): res = sint() - comparison.LTZ(res, other - self, bit_length or program.bit_length + - (not (int(program.options.ring) == program.bit_length)), + comparison.LTZ(res, other - self, + (bit_length or program.bit_length) + 1, security or program.security) return res @@ -1304,13 +1323,14 @@ class sint(_secret, _int): return floatingpoint.BitDec(self, bit_length, bit_length, security) def TruncMul(self, other, k, m, kappa=None, nearest=False): - return (self * other).round(k, m, kappa, nearest) + return (self * other).round(k, m, kappa, nearest, signed=True) - def TruncPr(self, k, m, kappa=None): - return floatingpoint.TruncPr(self, k, m, kappa) + def TruncPr(self, k, m, kappa=None, signed=True): + return floatingpoint.TruncPr(self, k, m, kappa, signed=signed) @vectorize def round(self, k, m, kappa=None, nearest=False, signed=False): + kappa = kappa or program.security secret = isinstance(m, sint) if nearest: if secret: @@ -1320,7 +1340,7 @@ class sint(_secret, _int): else: if secret: return floatingpoint.Trunc(self, k, m, kappa) - return self.TruncPr(k, m, kappa) + return self.TruncPr(k, m, kappa, signed=signed) def Norm(self, k, f, kappa=None, simplex_flag=False): return library.Norm(self, k, f, kappa, simplex_flag) @@ -1461,23 +1481,25 @@ class _bitint(object): linear_rounds = False @classmethod - def bit_adder(cls, a, b): + def bit_adder(cls, a, b, carry_in=0, get_carry=False): a, b = list(a), list(b) a += [0] * (len(b) - len(a)) b += [0] * (len(a) - len(b)) - return cls.bit_adder_selection(a, b) + return cls.bit_adder_selection(a, b, carry_in=carry_in, + get_carry=get_carry) @classmethod - def bit_adder_selection(cls, a, b): + def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False): if cls.log_rounds: - return cls.carry_lookahead_adder(a, b) + return cls.carry_lookahead_adder(a, b, carry_in=carry_in) elif cls.linear_rounds: - return cls.ripple_carry_adder(a, b) + return cls.ripple_carry_adder(a, b, carry_in=carry_in) else: - return cls.carry_select_adder(a, b) + return cls.carry_select_adder(a, b, carry_in=carry_in) @classmethod - def carry_lookahead_adder(cls, a, b, fewer_inv=False): + def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, + get_carry=False): lower = [] for (ai,bi) in zip(a,b): if ai is 0 or bi is 0: @@ -1493,10 +1515,13 @@ class _bitint(object): else: pre_op = floatingpoint.PreOpL if d: - carries = (0,) + zip(*pre_op(carry, d))[1] + carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1] else: carries = [] - return lower + cls.sum_from_carries(a, b, carries) + res = lower + cls.sum_from_carries(a, b, carries) + if get_carry: + res += [carries[-1]] + return res @staticmethod def sum_from_carries(a, b, carries): @@ -1504,7 +1529,7 @@ class _bitint(object): for (ai, bi, carry) in zip(a, b, carries)] @classmethod - def carry_select_adder(cls, a, b, get_carry=False): + def carry_select_adder(cls, a, b, get_carry=False, carry_in=0): a += [0] * (len(b) - len(a)) b += [0] * (len(a) - len(b)) n = len(a) @@ -1524,7 +1549,7 @@ class _bitint(object): raise Exception('blocks not summing up: %s != %s' % \ (sum(blocks), n)) res = [] - carry = 0 + carry = carry_in cin_one = util.long_one(a + b) for m in blocks: aa = a[:m] @@ -1540,7 +1565,8 @@ class _bitint(object): return res @classmethod - def ripple_carry_adder(cls, a, b, carry=0): + def ripple_carry_adder(cls, a, b, carry_in=0): + carry = carry_in res = [] for aa, bb in zip(a, b): cc, carry = cls.full_adder(aa, bb, carry) @@ -1760,14 +1786,15 @@ class intbitint(_bitint, sint): for i in range(len(a))] @classmethod - def bit_adder_selection(cls, a, b): + def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False): if cls.linear_rounds: - return cls.ripple_carry_adder(a, b) + return cls.ripple_carry_adder(a, b, carry_in=carry_in) # experimental cut-off with dead code elimination elif len(a) < 122 or cls.log_rounds: - return cls.carry_lookahead_adder(a, b) + return cls.carry_lookahead_adder(a, b, carry_in=carry_in, + get_carry=get_carry) else: - return cls.carry_select_adder(a, b) + return cls.carry_select_adder(a, b, carry_in=carry_in) class sgf2nint(_bitint, sgf2n): bin_type = sgf2n @@ -1904,10 +1931,10 @@ class sgf2nfloat(sgf2n): sgf2nfloat.set_precision(24, 8) -def parse_type(other): +def parse_type(other, k=None, f=None): # converts type to cfix/sfix depending on the case if isinstance(other, cfix.scalars): - return cfix(other) + return cfix(other, k=k, f=f) elif isinstance(other, cint): tmp = cfix() tmp.load_int(other) @@ -1975,9 +2002,11 @@ class cfix(_number, _structure): return 1 @vectorize_init - def __init__(self, v=None, size=None): - f = self.f - k = self.k + def __init__(self, v=None, k=None, f=None, size=None): + f = f or self.f + k = k or self.k + self.f = f + self.k = k self.size = get_global_vector_size() if isinstance(v, cint): self.v = cint(v,size=self.size) @@ -2025,22 +2054,20 @@ class cfix(_number, _structure): other = parse_type(other) if isinstance(other, cfix): return cfix(self.v + other.v) - elif isinstance(other, sfix): - return sfix(self.v + other.v) else: - raise CompilerError('Invalid type %s for cfix.__add__' % type(other)) + return NotImplemented @vectorize def mul(self, other): other = parse_type(other) if isinstance(other, cfix): + assert self.f == other.f sgn = cint(1 - 2 * (self.v * other.v < 0)) absolute = self.v * other.v * sgn val = sgn * (absolute >> self.f) return cfix(val) elif isinstance(other, sfix): - res = sfix((self.v * other.v) >> self.f) - return res + return NotImplemented else: raise CompilerError('Invalid type %s for cfix.__mul__' % type(other)) @@ -2130,7 +2157,8 @@ class cfix(_number, _structure): if isinstance(other, cfix): return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f)) elif isinstance(other, sfix): - return sfix(library.FPDiv(self.v, other.v, self.k, self.f, other.kappa)) + return sfix(library.FPDiv(self.v, other.v, self.k, self.f, + other.kappa, nearest=sfix.round_nearest)) else: raise TypeError('Incompatible fixed point types in division') @@ -2169,6 +2197,7 @@ class _single(_number, _structure): return cls._new(cls.int_type.load_mem(address)) @classmethod + @read_mem_value def conv(cls, other): if isinstance(other, cls): return other @@ -2193,9 +2222,13 @@ class _single(_number, _structure): @classmethod def dot_product(cls, x, y, res_params=None): + return cls.unreduced_dot_product(x, y, res_params).reduce_after_mul() + + @classmethod + def unreduced_dot_product(cls, x, y, res_params=None): dp = cls.int_type.dot_product([xx.pre_mul() for xx in x], [yy.pre_mul() for yy in y]) - return x[0].unreduced(dp, y[0], res_params, len(x)).reduce_after_mul() + return x[0].unreduced(dp, y[0], res_params, len(x)) @classmethod def row_matrix_mul(cls, row, matrix, res_params=None): @@ -2300,20 +2333,25 @@ class _fix(_single): @classmethod def coerce(cls, other): - if isinstance(other, _fix): + if isinstance(other, (_fix, cfix)): return other else: return cls.conv(other) @classmethod - def from_sint(cls, other): + def from_sint(cls, other, k=None, f=None): res = cls() + res.f = f or cls.f + res.k = k or cls.k res.load_int(cls.int_type.conv(other)) return res @classmethod - def _new(cls, other): - return cls(other) + def _new(cls, other, k=None, f=None): + res = cls(other) + res.k = k or cls.k + res.f = f or cls.f + return res @vectorize_init def __init__(self, _v=None, size=None): @@ -2331,7 +2369,7 @@ class _fix(_single): self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size) elif isinstance(_v, self.float_type): p = (f + _v.p) - b = (p >= 0) + b = (p.greater_equal(0, _v.vlen)) a = b*(_v.v << (p)) + (1-b)*(_v.v >> (-p)) self.v = (1-2*_v.s)*a elif isinstance(_v, type(self)): @@ -2355,31 +2393,45 @@ class _fix(_single): def add(self, other): other = self.coerce(other) if isinstance(other, (_fix, cfix)): - return type(self)(self.v + other.v) + return self._new(self.v + other.v, k=self.k, f=self.f) elif isinstance(other, cfix.scalars): - tmp = cfix(other) + tmp = cfix(other, k=self.k, f=self.f) return self + tmp else: - raise CompilerError('Invalid type %s for _fix.__add__' % type(other)) + return NotImplemented @vectorize def mul(self, other): + if isinstance(other, (sint, cint, regint, int, long)): + return self._new(self.v * other, k=self.k, f=self.f) + elif isinstance(other, float): + if int(other) == other: + return self.mul(int(other)) + v = int(round(other * 2 ** self.f)) + if v == 0: + return 0 + f = self.f + while v % 2 == 0: + f -= 1 + v /= 2 + k = len(bin(abs(v))) - 1 + other = cfix(cint(v)) + other.f = f + other.k = k other = self.coerce(other) - if isinstance(other, _fix): - val = self.v.TruncMul(other.v, self.k * 2, self.f, self.kappa, + if isinstance(other, (_fix, cfix)): + val = self.v.TruncMul(other.v, self.k + other.k, other.f, + self.kappa, self.round_nearest) if self.size >= other.size: - return self._new(val) + return self._new(val, k=self.k, f=self.f) else: - return self.vec._new(val) - elif isinstance(other, cfix): - res = type(self)((self.v * other.v) >> self.f) - return res + return self.vec._new(val, k=self.k, f=self.f) elif isinstance(other, cfix.scalars): scalar_fix = cfix(other) return self * scalar_fix else: - raise CompilerError('Invalid type %s for _fix.__mul__' % type(other)) + return NotImplemented @vectorize def __neg__(self): @@ -2397,6 +2449,7 @@ class _fix(_single): else: raise TypeError('Incompatible fixed point types in division') + @vectorize def __rdiv__(self, other): return self.coerce(other) / self @@ -2418,14 +2471,24 @@ class sfix(_fix): @vectorized_classmethod def get_input_from(cls, player): v = cls.int_type() - inputfix(v, cls.f, player) + inputmixed('fix', v, cls.f, player) return cls._new(v) - @classmethod - def coerce(cls, other): - return parse_type(other) + @vectorized_classmethod + def get_random(cls, lower, upper): + """ Uniform random number around centre of bounds """ + """ Range can be smaller """ + log_range = int(math.log(upper - lower, 2)) + n_bits = log_range + cls.f + average = lower + 0.5 * (upper - lower) + lower = average - 0.5 * 2 ** log_range + return cls._new(cls.int_type.get_random_int(n_bits)) + lower + + def coerce(self, other): + return parse_type(other, k=self.k, f=self.f) def mul_no_reduce(self, other, res_params=None): + assert self.f == other.f return self.unreduced(self.v * other.v) def pre_mul(self): @@ -2434,16 +2497,21 @@ class sfix(_fix): def unreduced(self, v, other=None, res_params=None, n_summands=1): return unreduced_sfix(v, self.k * 2, self.f, self.kappa) -class unreduced_sfix(object): +class unreduced_sfix(_single): + int_type = sint + + @classmethod + def _new(cls, v): + return cls(v, 2 * sfix.k, sfix.f, sfix.kappa) + def __init__(self, v, k, m, kappa): self.v = v self.k = k self.m = m self.kappa = kappa - self.size = self.v.size def __add__(self, other): - if other in (0, 0L): + if other is 0 or other is 0L: return self assert self.k == other.k assert self.m == other.m @@ -2455,7 +2523,10 @@ class unreduced_sfix(object): @vectorize def reduce_after_mul(self): return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa, - nearest=sfix.round_nearest)) + nearest=sfix.round_nearest, + signed=True)) + +sfix.unreduced_type = unreduced_sfix # this is for 20 bit decimal precision # with 40 bitlength of entire number @@ -2503,7 +2574,7 @@ class squant(_single): raise CompilerError('%f not quantizable' % value) self.v = self.int_type(q) reset_global_vector_size() - elif isinstance(value, type(self)): + elif isinstance(value, squant) and value.params == self.params: self.v = value.v else: raise CompilerError('cannot convert %s to squant' % value) @@ -2538,7 +2609,7 @@ class squant(_single): return self.mul_no_reduce(other, res_params).reduce_after_mul() def mul_no_reduce(self, other, res_params=None): - if isinstance(other, sint): + if isinstance(other, (sint, cint, regint)): return self._new(other * (self.v - self.Z) + self.Z, params=self.get_params()) other = self.coerce(other) @@ -2572,7 +2643,7 @@ class _unreduced_squant(object): self.res_params = res_params or params[0] def __add__(self, other): - if other in (0, 0L): + if other is 0 or other is 0L: return self assert self.params == other.params assert self.res_params == other.res_params @@ -2650,7 +2721,8 @@ class squant_params(object): int_mult = util.expand(int_mult, size) tmp = unreduced.v * int_mult + shifted_Z shifted = tmp.round(self.max_length, n_shift, - squant.kappa, squant.round_nearest) + kappa=squant.kappa, nearest=squant.round_nearest, + signed=True) if squant.clamp: length = max(self.k, self.max_length - n_shift) + 1 top = (1 << self.k) - 1 @@ -2715,6 +2787,10 @@ class sfloat(_number, _structure): else: return cls(other) + @classmethod + def coerce(cls, other): + return cls.conv(other) + @staticmethod def convert_float(v, vlen, plen): if v < 0: @@ -2747,7 +2823,7 @@ class sfloat(_number, _structure): p = sint() z = sint() s = sint() - inputfloat(v, p, z, s, cls.vlen, player) + inputmixed('float', v, p, z, s, cls.vlen, player) return cls(v, p, z, s) @vectorize_init @@ -2935,7 +3011,7 @@ class sfloat(_number, _structure): return self + -other def __rsub__(self, other): - raise NotImplementedError() + return -self + other def __div__(self, other): other = self.conv(other) @@ -3144,17 +3220,18 @@ class Array(object): for i in range(self.length): yield self[i] - def assign(self, other): - if isinstance(other, Array): - def loop(i): - self[i] = other[i] - library.range_loop(loop, len(self)) - elif isinstance(other, Tape.Register): - if len(other) == self.length: - self[0] = other - else: - raise CompilerError('Length mismatch between array and vector') - else: + def same_shape(self): + return Array(self.length, self.value_type) + + def assign(self, other, base=0): + try: + other = other.get_vector() + except: + pass + try: + other.store_in_mem(self.get_address(base)) + assert len(self) >= other.size + base + except AttributeError: for i,j in enumerate(other): self[i] = j return self @@ -3169,22 +3246,42 @@ class Array(object): self[i] = mem_value return self - def get_vector(self): - return self.value_type.load_mem(self.address, size=self.length) + def get_vector(self, base=0, size=None): + size = size or self.length + return self.value_type.load_mem(self.get_address(base), size=size) def get_mem_value(self, index): return MemValue(self[index], self.get_address(index)) + def input_from(self, player, budget=None): + self.assign(self.value_type.get_input_from(player, size=len(self))) + def __add__(self, other): + if other is 0: + return self assert len(self) == len(other) - return Array.create_from(x + y for x, y in zip(self, other)) + return self.get_vector() + other def __sub__(self, other): assert len(self) == len(other) - return Array.create_from(x - y for x, y in zip(self, other)) + return self.get_vector() - other def __mul__(self, value): - return Array.create_from(x * value for x in self) + return self.get_vector() * value + + def __pow__(self, value): + return self.get_vector() ** value + + __radd__ = __add__ + __rmul__ = __mul__ + + def shuffle(self): + @library.for_range(len(self)) + def _(i): + j = regint.get_random(64) % (len(self) - i) + tmp = self[i] + self[i] = self[i + j] + self[i + j] = tmp def reveal(self): return Array.create_from(x.reveal() for x in self) @@ -3223,6 +3320,9 @@ class SubMultiArray(object): self.address, index, debug=self.debug) return self.sub_cache[key] + def __setitem__(self, index, other): + self[index].assign(other) + def __len__(self): return self.sizes[0] @@ -3235,35 +3335,60 @@ class SubMultiArray(object): def total_size(self): return reduce(operator.mul, self.sizes) * self.value_type.n_elements() - def get_vector(self): - return self.value_type.load_mem(self.address, size=self.total_size()) + def get_vector(self, base=0, size=None): + assert self.value_type.n_elements() == 1 + size = size or self.total_size() + return self.value_type.load_mem(self.address + base, size=size) - def assign_vector(self, vector): - assert vector.size == self.total_size() - vector.store_in_mem(self.address) + def assign_vector(self, vector, base=0): + assert self.value_type.n_elements() == 1 + assert vector.size <= self.total_size() + vector.store_in_mem(self.address + base) -class MultiArray(SubMultiArray): - def __init__(self, sizes, value_type, debug=None, address=None): - self.array = Array(reduce(operator.mul, sizes), \ - value_type, address=address) - SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \ - debug=debug) - if len(sizes) < 2: - raise CompilerError('Use Array') + def assign(self, other): + if self.value_type.n_elements() > 1: + assert self.sizes == other.sizes + self.assign_vector(other.get_vector()) -class Matrix(MultiArray): - def __init__(self, rows, columns, value_type, debug=None, address=None): - MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ - address=address) + def same_shape(self): + return MultiArray(self.sizes, self.value_type) - def __setitem__(self, index, other): - assert other.size == self.sizes[1] - other.store_in_mem(self[index].address) + def input_from(self, player, budget=None): + @library.for_range_opt(self.sizes[0], budget=budget) + def _(i): + self[i].input_from(player, budget=budget) + + def schur(self, other): + assert self.sizes == other.sizes + if len(self.sizes) == 2: + res = Matrix(self.sizes[0], self.sizes[1], self.value_type) + else: + res = MultiArray(self.sizes, self.value_type) + res.assign_vector(self.get_vector() * other.get_vector()) + return res + + def __add__(self, other): + if other is 0: + return self + assert self.sizes == other.sizes + if len(self.sizes) == 2: + res = Matrix(self.sizes[0], self.sizes[1], self.value_type) + else: + res = MultiArray(self.sizes, self.value_type) + res.assign_vector(self.get_vector() + other.get_vector()) + return res + + __radd__ = __add__ + + def iadd(self, other): + assert self.sizes == other.sizes + self.assign_vector(self.get_vector() + other.get_vector()) def __mul__(self, other): return self.mul(other) def mul(self, other, res_params=None): + assert len(self.sizes) == 2 if isinstance(other, Array): assert len(other) == self.sizes[1] if self.value_type.n_elements() == 1: @@ -3277,7 +3402,8 @@ class Matrix(MultiArray): matrix[i][0] = x res = self * matrix return Array.create_from(x[0] for x in res) - elif isinstance(other, Matrix): + elif isinstance(other, SubMultiArray): + assert len(other.sizes) == 2 assert other.sizes[0] == self.sizes[1] if res_params is not None: class t(self.value_type): @@ -3287,14 +3413,16 @@ class Matrix(MultiArray): t = self.value_type res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: + if max(res_matrix.sizes) > 1000: + raise AttributeError() A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( self.value_type.matrix_mul(A, B, self.sizes[1], res_params)) - except AttributeError: + except (AttributeError, AssertionError): # fallback for sfloat etc. - @library.for_range(self.sizes[0]) + @library.for_range_opt(self.sizes[0]) def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( @@ -3311,6 +3439,78 @@ class Matrix(MultiArray): else: raise NotImplementedError + def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True, + res=None): + assert len(self.sizes) == 2 + assert len(other.sizes) == 2 + if res is None: + if reduce: + res_matrix = Matrix(n_rows, n_columns, self.value_type) + else: + res_matrix = Matrix(n_rows, n_columns, \ + self.value_type.unreduced_type) + else: + res_matrix = res + @library.for_range_opt(n_rows) + def _(i): + @library.for_range_opt(n_columns) + def _(j): + col = column(other, j) + r = row(self, i) + if reduce: + res_matrix[i][j] = self.value_type.dot_product(r, col) + else: + entry = self.value_type.unreduced_dot_product(r, col) + res_matrix[i][j] = entry + return res_matrix + + def plain_mul(self, other, res=None): + assert other.sizes[0] == self.sizes[1] + return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \ + other.sizes[1], \ + lambda x, j: [x[k][j] for k in range(len(x))], + res=res) + + def mul_trans(self, other): + assert other.sizes[1] == self.sizes[1] + return self.budget_mul(other, self.sizes[0], lambda x, i: x[i], \ + other.sizes[0], lambda x, j: x[j]) + + def trans_mul(self, other, reduce=True, res=None): + assert other.sizes[0] == self.sizes[0] + return self.budget_mul(other, self.sizes[1], \ + lambda x, j: [x[k][j] for k in range(len(x))], \ + other.sizes[1], \ + lambda x, j: [x[k][j] for k in range(len(x))], + reduce=reduce, res=res) + + def transpose(self): + assert len(self.sizes) == 2 + res = Matrix(self.sizes[1], self.sizes[0], self.value_type) + @library.for_range_opt(self.sizes[1]) + def _(i): + @library.for_range_opt(self.sizes[0]) + def _(j): + res[i][j] = self[j][i] + return res + +class MultiArray(SubMultiArray): + def __init__(self, sizes, value_type, debug=None, address=None): + if isinstance(address, Array): + self.array = address + else: + self.array = Array(reduce(operator.mul, sizes), \ + value_type, address=address) + SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \ + debug=debug) + if len(sizes) < 2: + raise CompilerError('Use Array') + +class Matrix(MultiArray): + def __init__(self, rows, columns, value_type, debug=None, address=None): + MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ + address=address) + class VectorArray(object): def __init__(self, length, value_type, vector_size, address=None): self.array = Array(length * vector_size, value_type, address) diff --git a/ECDSA/README.md b/ECDSA/README.md index d81699af..d7db6191 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -1,7 +1,7 @@ This directory contains the code used for the benchmarks by [Dalskov et al.](https://eprint.iacr.org/2019/889) `*-ecdsa-party.cpp` contains the high-level programs while the two phases are implemented -`preprocessing.hpp` and `sign.hpp`, respectively. +in `preprocessing.hpp` and `sign.hpp`, respectively. #### Compilation diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 58e22c4c..7863edf0 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -104,7 +104,7 @@ class invalid_program: public exception class file_error: public exception { string filename, ans; public: - file_error(string m="") : filename(m) + file_error(string m) : filename(m) { ans="File Error : "; ans+=filename; diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index b8449ed8..c8c90270 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -36,7 +36,8 @@ bigint SemiHomomorphicNoiseBounds::min_p0() double SemiHomomorphicNoiseBounds::min_phi_m(int log_q) { - return 33.1 * (log_q - log2(3.2)); + // the constant was updated using Martin Albrecht's LWE estimator in Sep 2019 + return 37.8 * (log_q - log2(3.2)); } diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp index 94f2eb26..d3d43a0e 100644 --- a/FHEOffline/Sacrificing.cpp +++ b/FHEOffline/Sacrificing.cpp @@ -26,7 +26,7 @@ inline FileSacriFactory::FileSacriFactory(const char* type, const Player& P, if (output_thread) file1 << "-" << output_thread; this->inpf.open(file1.str().c_str(),ios::in | ios::binary); - if (this->inpf.fail()) { throw file_error(); } + if (this->inpf.fail()) { throw file_error(file1.str()); } } template @@ -221,12 +221,12 @@ void Triple_Checking(const Player& P,MAC_Check& MC,int nm) /* Open file for reading in the initial triples */ stringstream file1; file1 << PREP_DIR "Initial-Triples-" << file_completion(dummy) << "-P" << P.my_num(); ifstream inpf(file1.str().c_str(),ios::in | ios::binary); - if (inpf.fail()) { throw file_error(); } + if (inpf.fail()) { throw file_error(file1.str()); } /* Open file for writing out the final triples */ stringstream file3; file3 << PREP_DIR "Triples-" << file_completion(dummy) << "-P" << P.my_num(); ofstream outf(file3.str().c_str(),ios::out | ios::binary); - if (outf.fail()) { throw file_error(); } + if (outf.fail()) { throw file_error(file3.str()); } gf2n_short te,t; Create_Random(t,P); @@ -444,12 +444,12 @@ void Square_Checking(const Player& P,MAC_Check& MC,int ns) /* Open files for reading in the initial data */ stringstream file1; file1 << PREP_DIR "Initial-Squares-" << file_completion(dummy) << "-P" << P.my_num(); ifstream inpf_s(file1.str().c_str(),ios::in | ios::binary); - if (inpf_s.fail()) { throw file_error(); } + if (inpf_s.fail()) { throw file_error(file1.str()); } /* Open files for writing out the final data */ stringstream file3; file3 << PREP_DIR "Squares-" << file_completion(dummy) << "-P" << P.my_num(); ofstream outf_s(file3.str().c_str(),ios::out | ios::binary); - if (outf_s.fail()) { throw file_error(); } + if (outf_s.fail()) { throw file_error(file3.str()); } gf2n_short te,t,t2; Create_Random(t,P); diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index 64cc04e5..bd25a437 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -153,7 +153,7 @@ size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& cipherte others_ciphertexts.resize(this->sec, pk.get_params()); for (int i = 1; i < P.num_players(); i++) { -#ifdef VERBOSE +#ifdef VERBOSE_HE cerr << "Sending proof with " << 1e-9 * ciphertexts.get_length() << "+" << 1e-9 * cleartexts.get_length() << " GB" << endl; #endif @@ -164,7 +164,7 @@ size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& cipherte #ifndef LESS_ALLOC_MORE_MEM Verifier verifier(proof); #endif -#ifdef VERBOSE +#ifdef VERBOSE_HE cerr << "Checking proof of player " << i << endl; #endif timers["Verifying"].start(); diff --git a/GC/FakeSecret.cpp b/GC/FakeSecret.cpp index 0cf91e98..981ef22a 100644 --- a/GC/FakeSecret.cpp +++ b/GC/FakeSecret.cpp @@ -7,6 +7,8 @@ #include "GC/Processor.h" #include "GC/square64.h" +#include "GC/Processor.hpp" + namespace GC { @@ -14,7 +16,7 @@ int FakeSecret::default_length = 128; ostream& FakeSecret::out = cout; -void FakeSecret::load(int n, const Integer& x) +void FakeSecret::load_clear(int n, const Integer& x) { if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n)) throw out_of_range("public value too long"); diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index a44cf62a..f6a6d7b9 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -79,7 +79,7 @@ public: __uint128_t operator^=(const FakeSecret& other) { return a ^= other.a; } - void load(int n, const Integer& x); + void load_clear(int n, const Integer& x); template void load(int n, const Memory& mem, size_t address) { load(n, mem[address]); } template diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index 9e8f297d..add5c7b7 100644 --- a/GC/Instruction.hpp +++ b/GC/Instruction.hpp @@ -93,8 +93,8 @@ unsigned GC::Instruction::get_max_reg(int reg_type) const offset = 1; break; case INPUTB: - skip = 3; - offset = 2; + skip = 4; + offset = 3; break; case CONVCBIT: return BaseInstruction::get_max_reg(INT); diff --git a/GC/Machine.cpp b/GC/Machine.cpp index eed1c986..55d011bd 100644 --- a/GC/Machine.cpp +++ b/GC/Machine.cpp @@ -23,9 +23,6 @@ namespace GC { -extern template class ReplicatedSecret; -extern template class ReplicatedSecret; - #define GC_MACHINE(T) \ template class Instruction; \ template class Machine; \ @@ -34,8 +31,4 @@ extern template class ReplicatedSecret; template class Thread; \ template class ThreadMaster; \ -GC_MACHINE(FakeSecret); -GC_MACHINE(SemiHonestRepSecret); -GC_MACHINE(MaliciousRepSecret) - } /* namespace GC */ diff --git a/GC/MaliciousRepPrep.h b/GC/MaliciousRepPrep.h deleted file mode 100644 index 28e801a5..00000000 --- a/GC/MaliciousRepPrep.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * MaliciousRepPrep.h - * - */ - -#ifndef GC_MALICIOUSREPPREP_H_ -#define GC_MALICIOUSREPPREP_H_ - -#include "MaliciousRepSecret.h" -#include "Protocols/ReplicatedPrep.h" - -namespace GC -{ - -class MaliciousRepPrep : public BufferPrep -{ - ReplicatedBase* protocol; - -public: - MaliciousRepPrep(DataPositions& usage); - ~MaliciousRepPrep(); - - void set_protocol(MaliciousRepSecret::Protocol& protocol); - - void buffer_triples(); - void buffer_bits(); - - void buffer_squares() { throw not_implemented(); } - void buffer_inverses() { throw not_implemented(); } -}; - -} /* namespace GC */ - -#endif /* GC_MALICIOUSREPPREP_H_ */ diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index f56de415..5b4dd0b3 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -6,7 +6,7 @@ #ifndef GC_MALICIOUSREPSECRET_H_ #define GC_MALICIOUSREPSECRET_H_ -#include "ReplicatedSecret.h" +#include "ShareSecret.h" #include "Machine.h" #include "Protocols/Beaver.h" #include "Protocols/MaliciousRepMC.h" @@ -17,7 +17,8 @@ template class MaliciousRepMC; namespace GC { -class MaliciousRepThread; +template class ShareThread; +template class RepPrep; class MaliciousRepSecret : public ReplicatedSecret { @@ -30,7 +31,8 @@ public: typedef MC MAC_Check; typedef Beaver Protocol; - typedef NotImplementedInput Input; + typedef ReplicatedInput Input; + typedef RepPrep LivePrep; static MC* new_mc(Machine& machine) { diff --git a/GC/MaliciousRepThread.cpp b/GC/MaliciousRepThread.cpp deleted file mode 100644 index 7951b739..00000000 --- a/GC/MaliciousRepThread.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/* - * MalicousRepParty.cpp - * - */ - -#include "Protocols/MaliciousRepMC.h" -#include "MaliciousRepThread.h" -#include "Math/Setup.h" - -#include "Protocols/MaliciousRepMC.hpp" -#include "Protocols/MAC_Check_Base.hpp" -#include "Protocols/Beaver.hpp" -#include "Processor/Data_Files.hpp" - -namespace GC -{ - -thread_local MaliciousRepThread* MaliciousRepThread::singleton = 0; - -MaliciousRepThread::MaliciousRepThread(int i, - ThreadMaster& master) : - Thread(i, master), DataF(usage) -{ -} - -void MaliciousRepThread::pre_run() -{ - if (singleton) - throw runtime_error("there can only be one"); - singleton = this; - DataF.set_protocol(*protocol); -} - -void MaliciousRepThread::post_run() -{ -#ifndef INSECURE - cerr << "Removing used pre-processed data" << endl; - DataF.prune(); -#endif -} - -void MaliciousRepThread::and_(Processor& processor, - const vector& args, bool repeat) -{ - assert(P->num_players() == 3); - processor.check_args(args, 4); - protocol->init_mul(DataF, *MC); - for (size_t i = 0; i < args.size(); i += 4) - { - int n_bits = args[i]; - int left = args[i + 2]; - int right = args[i + 3]; - MaliciousRepSecret y_ext; - if (repeat) - y_ext = processor.S[right].extend_bit(); - else - y_ext = processor.S[right]; - protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits)); - } - - protocol->exchange(); - - for (size_t i = 0; i < args.size(); i += 4) - { - int n_bits = args[i]; - int out = args[i + 1]; - processor.S[out] = protocol->finalize_mul().mask(n_bits); - } -} - -} /* namespace GC */ diff --git a/GC/MaliciousRepThread.h b/GC/MaliciousRepThread.h deleted file mode 100644 index 69774b7a..00000000 --- a/GC/MaliciousRepThread.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * MalicousRepParty.h - * - */ - -#ifndef GC_MALICIOUSREPTHREAD_H_ -#define GC_MALICIOUSREPTHREAD_H_ - -#include "Thread.h" -#include "MaliciousRepSecret.h" -#include "MaliciousRepPrep.h" -#include "Processor/Data_Files.h" - -#include - -namespace GC -{ - -class MaliciousRepThread : public Thread -{ - static thread_local MaliciousRepThread* singleton; - -public: - static MaliciousRepThread& s(); - - DataPositions usage; - MaliciousRepPrep DataF; - - MaliciousRepThread(int i, ThreadMaster& master); - virtual ~MaliciousRepThread() {} - - void pre_run(); - void post_run(); - - void and_(Processor& processor, const vector& args, bool repeat); -}; - -inline MaliciousRepThread& MaliciousRepThread::s() -{ - if (singleton) - return *singleton; - else - throw runtime_error("no singleton"); -} - -} /* namespace GC */ - -#endif /* GC_MALICIOUSREPTHREAD_H_ */ diff --git a/GC/RepPrep.h b/GC/RepPrep.h new file mode 100644 index 00000000..5f1e0382 --- /dev/null +++ b/GC/RepPrep.h @@ -0,0 +1,46 @@ +/* + * MaliciousRepPrep.h + * + */ + +#ifndef GC_REPPREP_H_ +#define GC_REPPREP_H_ + +#include "MaliciousRepSecret.h" +#include "ShiftableTripleBuffer.h" +#include "Protocols/ReplicatedPrep.h" + +namespace GC +{ + +template +class RepPrep : public BufferPrep, ShiftableTripleBuffer +{ + ReplicatedBase* protocol; + +public: + RepPrep(DataPositions& usage, Thread& thread); + ~RepPrep(); + + void set_protocol(typename T::Protocol& protocol); + + void buffer_triples(); + void buffer_bits(); + + void buffer_squares() { throw not_implemented(); } + void buffer_inverses() { throw not_implemented(); } + + void get(Dtype type, T* data) + { + BufferPrep::get(type, data); + } + + array get_triple(int n_bits) + { + return ShiftableTripleBuffer::get_triple(n_bits); + } +}; + +} /* namespace GC */ + +#endif /* GC_REPPREP_H_ */ diff --git a/GC/MaliciousRepPrep.cpp b/GC/RepPrep.hpp similarity index 54% rename from GC/MaliciousRepPrep.cpp rename to GC/RepPrep.hpp index 26d1ee46..2c35b8e9 100644 --- a/GC/MaliciousRepPrep.cpp +++ b/GC/RepPrep.hpp @@ -3,8 +3,8 @@ * */ -#include "MaliciousRepPrep.h" -#include "MaliciousRepThread.h" +#include "RepPrep.h" +#include "ShareThread.h" #include "Processor/OnlineOptions.h" #include "Protocols/MalRepRingPrep.hpp" @@ -15,33 +15,40 @@ namespace GC { -MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage) : - BufferPrep(usage), protocol(0) +template +RepPrep::RepPrep(DataPositions& usage, Thread& thread) : + BufferPrep(usage), protocol(0) { + (void) thread; } -MaliciousRepPrep::~MaliciousRepPrep() +template +RepPrep::~RepPrep() { if (protocol) delete protocol; } -void MaliciousRepPrep::set_protocol(MaliciousRepSecret::Protocol& protocol) +template +void RepPrep::set_protocol(typename T::Protocol& protocol) { this->protocol = new ReplicatedBase(protocol.P); } -void MaliciousRepPrep::buffer_triples() +template +void RepPrep::buffer_triples() { assert(protocol != 0); - auto MC = MaliciousRepThread::s().new_mc(); - shuffle_triple_generation(triples, protocol->P, *MC, 64); + auto MC = ShareThread::s().new_mc(); + shuffle_triple_generation(this->triples, protocol->P, *MC, 64); delete MC; } -void MaliciousRepPrep::buffer_bits() +template +void RepPrep::buffer_bits() { assert(this->protocol != 0); + assert(this->protocol->P.num_players() == 3); for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) { this->bits.push_back({}); diff --git a/GC/ReplicatedParty.cpp b/GC/ReplicatedParty.cpp deleted file mode 100644 index 72b209c0..00000000 --- a/GC/ReplicatedParty.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/* - * ReplicatedParty.cpp - * - */ - -#include "ReplicatedParty.h" -#include "Thread.h" -#include "MaliciousRepThread.h" -#include "Networking/Server.h" -#include "Tools/ezOptionParser.h" -#include "Tools/benchmarking.h" - -namespace GC -{ - -template -ReplicatedParty::ReplicatedParty(int argc, const char** argv) : - ThreadMaster(online_opts), online_opts(opt, argc, argv) -{ - opt.add( - "localhost", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Host where party 0 is running (default: localhost)", // Help description. - "-h", // Flag token. - "--hostname" // Flag token. - ); - opt.add( - "5000", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Base port number (default: 5000).", // Help description. - "-pn", // Flag token. - "--portnum" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Unencrypted communication.", // Help description. - "-u", // Flag token. - "--unencrypted" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Check opening by communication instead of hashing.", // Help description. - "-c", // Flag token. - "--communication" // Flag token. - ); - online_opts.finalize(opt, argc, argv); - this->progname = online_opts.progname; - int my_num = online_opts.playerno; - int pnb; - string hostname; - opt.get("-pn")->getInt(pnb); - opt.get("-h")->getString(hostname); - this->machine.use_encryption = not opt.get("-u")->isSet; - this->machine.more_comm_less_comp = opt.get("-c")->isSet; - - T::out.activate(my_num == 0 or online_opts.interactive); - - if (not this->machine.use_encryption) - insecure("unencrypted communication"); - - Server* server = Server::start_networking(this->N, my_num, 3, hostname, pnb); - - this->run(); - - this->machine.write_memory(this->N.my_num()); - - if (server) - delete server; -} - -template<> -Thread* ReplicatedParty::new_thread(int i) -{ - return ThreadMaster::new_thread(i); -} - -template<> -Thread* ReplicatedParty::new_thread(int i) -{ - return new MaliciousRepThread(i, *this); -} - -template<> -void ReplicatedParty::post_run() -{ -} - -template<> -void ReplicatedParty::post_run() -{ - DataPositions usage; - for (auto thread : threads) - usage.increase(((MaliciousRepThread*)thread)->usage); - usage.print_cost(); -} - -extern template class ReplicatedSecret; -extern template class ReplicatedSecret; - -template class ReplicatedParty; -template class ReplicatedParty; - -} diff --git a/GC/ReplicatedParty.h b/GC/ReplicatedParty.h deleted file mode 100644 index a59793f9..00000000 --- a/GC/ReplicatedParty.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * ReplicatedParty.h - * - */ - -#ifndef GC_REPLICATEDPARTY_H_ -#define GC_REPLICATEDPARTY_H_ - -#include "Protocols/ReplicatedMC.h" -#include "Protocols/MaliciousRepMC.h" -#include "ReplicatedSecret.h" -#include "Processor.h" -#include "Program.h" -#include "Memory.h" -#include "ThreadMaster.h" - -namespace GC -{ - -template -class ReplicatedParty : public ThreadMaster -{ - ez::ezOptionParser opt; - OnlineOptions online_opts; - -public: - static Thread& s(); - - ReplicatedParty(int argc, const char** argv); - - Thread* new_thread(int i); - - void post_run(); -}; - -template -inline Thread& ReplicatedParty::s() -{ - return Thread::s(); -} - -} - -#endif /* GC_REPLICATEDPARTY_H_ */ diff --git a/GC/ReplicatedSecret.cpp b/GC/ReplicatedSecret.cpp deleted file mode 100644 index bae11162..00000000 --- a/GC/ReplicatedSecret.cpp +++ /dev/null @@ -1,254 +0,0 @@ -/* - * ReplicatedSecret.cpp - * - */ - -#include "ReplicatedSecret.h" -#include "ReplicatedParty.h" -#include "MaliciousRepSecret.h" -#include "Protocols/MaliciousRepMC.h" -#include "MaliciousRepThread.h" -#include "Thread.h" -#include "square64.h" - -#include "Protocols/Share.h" - -#include "Protocols/ReplicatedMC.hpp" -#include "Protocols/Replicated.hpp" - -namespace GC -{ - -template -int ReplicatedSecret::default_length = 8 * sizeof(ReplicatedSecret::value_type); - -template -SwitchableOutput ReplicatedSecret::out; - -template -void ReplicatedSecret::load(int n, const Integer& x) -{ - if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n)) - throw out_of_range("public value too long"); - *this = x; -} - -template -void ReplicatedSecret::bitcom(Memory& S, const vector& regs) -{ - *this = 0; - for (unsigned int i = 0; i < regs.size(); i++) - *this ^= (S[regs[i]] << i); -} - -template -void ReplicatedSecret::bitdec(Memory& S, const vector& regs) const -{ - for (unsigned int i = 0; i < regs.size(); i++) - S[regs[i]] = (*this >> i) & 1; -} - -template -void ReplicatedSecret::load(vector >& accesses, - const Memory& mem) -{ - for (auto access : accesses) - access.dest = mem[access.address]; -} - -template -void ReplicatedSecret::store(Memory& mem, - vector >& accesses) -{ - for (auto access : accesses) - mem[access.address] = access.source; -} - -template -void ReplicatedSecret::store_clear_in_dynamic(Memory& mem, - const vector& accesses) -{ - for (auto access : accesses) - mem[access.address] = access.value; -} - -template -void ReplicatedSecret::inputb(Processor& processor, - const vector& args) -{ - auto& party = ReplicatedParty::s(); - party.os.resize(2); - for (auto& o : party.os) - o.reset_write_head(); - - InputArgList a(args); - bool interactive = party.n_interactive_inputs_from_me(a) > 0; - - for (auto x : a) - { - if (x.from == party.P->my_num()) - { - auto& res = processor.S[x.dest]; - res.prepare_input(party.os, processor.get_input(x.params, interactive), x.n_bits, party.secure_prng); - } - } - - if (interactive) - cout << "Thank you" << endl; - - for (int i = 0; i < 2; i++) - party.P->pass_around(party.os[i], i + 1); - - for (auto x : a) - { - int from = x.from; - int n_bits = x.n_bits; - if (from != party.P->my_num()) - { - auto& res = processor.S[x.dest]; - res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits); - } - } -} - -template -U ReplicatedSecret::input(Processor& processor, const InputArgs& args) -{ - int from = args.from; - int n_bits = args.n_bits; - auto& party = ReplicatedParty::s(); - U res; - party.os.resize(2); - for (auto& o : party.os) - o.reset_write_head(); - if (from == party.P->my_num()) - { - res.prepare_input(party.os, processor.get_input(args.params), n_bits, party.secure_prng); - party.P->send_relative(party.os); - } - else - { - party.P->receive_player(from, party.os[0], true); - res.finalize_input(party, party.os[0], from, n_bits); - } - return res; -} - -template -void ReplicatedSecret::prepare_input(vector& os, long input, int n_bits, PRNG& secure_prng) -{ - randomize_to_sum(input, secure_prng); - *this &= get_mask(n_bits); - for (int i = 0; i < 2; i++) - BitVec(get_mask(n_bits) & (*this)[i]).pack(os[i], n_bits); -} - -template -void ReplicatedSecret::finalize_input(Thread& party, octetStream& o, int from, int n_bits) -{ - int j = party.P->get_offset(from) == 2; - (*this)[j] = BitVec::unpack_new(o, n_bits); - (*this)[1 - j] = 0; -} - -template -BitVec ReplicatedSecret::local_mul(const ReplicatedSecret& other) const -{ - return (*this)[0] * other.sum() + (*this)[1] * other[0]; -} - -template -void ReplicatedSecret::and_(int n, - const ReplicatedSecret& x, - const ReplicatedSecret& y, bool repeat) -{ - (void)n, (void)x, (void)y, (void)repeat; - throw runtime_error("use static method"); -} - -template<> -void ReplicatedSecret::and_(Processor& processor, - const vector& args, bool repeat) -{ - auto& party = Thread::s(); - assert(party.P->num_players() == 3); - processor.check_args(args, 4); - assert(party.protocol != 0); - auto& protocol = *party.protocol; - protocol.init_mul(); - for (size_t i = 0; i < args.size(); i += 4) - { - int n_bits = args[i]; - int left = args[i + 2]; - int right = args[i + 3]; - MaliciousRepSecret y_ext; - if (repeat) - y_ext = processor.S[right].extend_bit(); - else - y_ext = processor.S[right]; - protocol.prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits); - } - protocol.exchange(); - for (size_t i = 0; i < args.size(); i += 4) - processor.S[args[i + 1]] = protocol.finalize_mul(args[i]); -} - -template<> -void ReplicatedSecret::and_( - Processor& processor, const vector& args, - bool repeat) -{ - MaliciousRepThread::s().and_(processor, args, repeat); -} - -template -void ReplicatedSecret::trans(Processor& processor, - int n_outputs, const vector& args) -{ - assert(length == 2); - for (int k = 0; k < 2; k++) - { - square64 square; - for (size_t i = n_outputs; i < args.size(); i++) - square.rows[i - n_outputs] = processor.S[args[i]][k].get(); - square.transpose(args.size() - n_outputs, n_outputs); - for (int i = 0; i < n_outputs; i++) - processor.S[args[i]][k] = square.rows[i]; - } -} - -template -void ReplicatedSecret::reveal(size_t n_bits, Clear& x) -{ - (void) n_bits; - ReplicatedSecret share = *this; - vector opened; - auto& party = ReplicatedParty::s(); - party.MC->POpen_Begin(opened, {share}, *party.P); - party.MC->POpen_End(opened, {share}, *party.P); - x = IntBase(opened[0]); -} - -template<> -void ReplicatedSecret::random_bit() -{ - auto& party = ReplicatedParty::s(); - *this = party.secure_prng.get_bit(); - octetStream o; - (*this)[0].pack(o, 1); - party.P->pass_around(o, 1); - (*this)[1].unpack(o, 1); -} - -template<> -void ReplicatedSecret::random_bit() -{ - MaliciousRepSecret res; - MaliciousRepThread::s().DataF.get_one(DATA_BIT, res); - *this = res; -} - -template class ReplicatedSecret; -template class ReplicatedSecret; - -} diff --git a/GC/Secret.h b/GC/Secret.h index 024ecbd9..02def12d 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -83,6 +83,8 @@ public: static typename T::out_type out; + static const bool needs_ot = false; + static T& cast(T& reg) { return *reinterpret_cast(®); } static const T& cast(const T& reg) { return *reinterpret_cast(®); } @@ -98,34 +100,40 @@ public: static Secret carryless_mult(const Secret& x, const Secret& y); static void output(T& reg); - template - static void load(vector< ReadAccess< Secret > >& accesses, const U& mem); - template - static void store(U& mem, vector< WriteAccess< Secret > >& accesses); + template + static void load(vector< ReadAccess >& accesses, const U& mem); + template + static void store(U& mem, vector< WriteAccess >& accesses); - static void andrs(Processor< Secret >& processor, const vector& args) + template + static void andrs(Processor& processor, const vector& args) { T::andrs(processor, args); } - static void ands(Processor< Secret >& processor, const vector& args) + template + static void ands(Processor& processor, const vector& args) { T::ands(processor, args); } - static void inputb(Processor< Secret >& processor, const vector& args) + template + static void inputb(Processor& processor, const vector& args) { T::inputb(processor, args); } - static void trans(Processor >& processor, int n_inputs, const vector& args); + template + static void trans(Processor& processor, int n_inputs, const vector& args); static void convcbit(Integer& dest, const Clear& source) { T::convcbit(dest, source); } Secret(); Secret(const Integer& x) { *this = x; } - void load(int n, const Integer& x); - void operator=(const Integer& x) { load(default_length, x); } + void load_clear(int n, const Integer& x); + void operator=(const Integer& x) { load_clear(default_length, x); } void load(int n, const Memory& mem, size_t address); Secret operator<<(int i); Secret operator>>(int i); - void bitcom(Memory< Secret >& S, const vector& regs); - void bitdec(Memory< Secret >& S, const vector& regs) const; + template + void bitcom(Memory& S, const vector& regs); + template + void bitdec(Memory& S, const vector& regs) const; Secret operator+(const Secret x) const; Secret& operator+=(const Secret x) { *this = *this + x; return *this; } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 6782438f..90ba1ddb 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -102,9 +102,9 @@ void Secret::random_bit() } template -template +template void Secret::store(U& mem, - vector > >& accesses) + vector >& accesses) { T::store(mem, accesses); } @@ -194,7 +194,7 @@ T& GC::Secret::get_new_reg() } template -void Secret::load(int n, const Integer& x) +void Secret::load_clear(int n, const Integer& x) { if ((unsigned)n < 8 * sizeof(x) and abs(x.get()) > (1LL << n)) throw out_of_range("public value too long"); @@ -219,8 +219,8 @@ void Secret::load(int n, const Integer& x) } template -template -void Secret::load(vector > >& accesses, const U& mem) +template +void Secret::load(vector >& accesses, const U& mem) { for (auto&& access : accesses) { @@ -252,7 +252,8 @@ Secret Secret::operator>>(int i) } template -void Secret::bitcom(Memory& S, const vector& regs) +template +void Secret::bitcom(Memory& S, const vector& regs) { registers.clear(); for (unsigned int i = 0; i < regs.size(); i++) @@ -264,7 +265,8 @@ void Secret::bitcom(Memory& S, const vector& regs) } template -void Secret::bitdec(Memory& S, const vector& regs) const +template +void Secret::bitdec(Memory& S, const vector& regs) const { if (regs.size() > registers.size()) throw out_of_range( @@ -280,7 +282,8 @@ void Secret::bitdec(Memory& S, const vector& regs) const } template -void Secret::trans(Processor >& processor, int n_outputs, +template +void Secret::trans(Processor& processor, int n_outputs, const vector& args) { int n_inputs = args.size() - n_outputs; diff --git a/GC/SemiHonestRepPrep.cpp b/GC/SemiHonestRepPrep.cpp new file mode 100644 index 00000000..efb21d55 --- /dev/null +++ b/GC/SemiHonestRepPrep.cpp @@ -0,0 +1,11 @@ +/* + * ReplicatedPrep.cpp + * + */ + +#include + +namespace GC +{ + +} /* namespace GC */ diff --git a/GC/SemiHonestRepPrep.h b/GC/SemiHonestRepPrep.h new file mode 100644 index 00000000..678a436e --- /dev/null +++ b/GC/SemiHonestRepPrep.h @@ -0,0 +1,28 @@ +/* + * ReplicatedPrep.h + * + */ + +#ifndef GC_SEMIHONESTREPPREP_H_ +#define GC_SEMIHONESTREPPREP_H_ + +#include "RepPrep.h" +#include "ShareSecret.h" + +namespace GC +{ + +class SemiHonestRepPrep : public RepPrep +{ +public: + SemiHonestRepPrep(DataPositions& usage, Thread& thread) : + RepPrep(usage, thread) + { + } + + void buffer_triples() { throw not_implemented(); } +}; + +} /* namespace GC */ + +#endif /* GC_SEMIHONESTREPPREP_H_ */ diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp new file mode 100644 index 00000000..8b4de98b --- /dev/null +++ b/GC/SemiPrep.cpp @@ -0,0 +1,58 @@ +/* + * SemiPrep.cpp + * + */ + +#include "SemiPrep.h" +#include "ThreadMaster.h" +#include "OT/NPartyTripleGenerator.h" +#include "OT/BitDiagonal.h" + +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "OT/NPartyTripleGenerator.hpp" + +namespace GC +{ + +SemiPrep::SemiPrep(DataPositions& usage, Thread& thread) : + BufferPrep(usage), thread(thread), triple_generator(0) +{ +} + +void SemiPrep::set_protocol(Beaver& protocol) +{ + (void) protocol; + params.set_passive(); + triple_generator = new SemiSecret::TripleGenerator( + thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.master.N, thread.thread_num, thread.master.opts.batch_size, + 1, params, thread.P); + triple_generator->multi_threaded = false; +} + +void SemiPrep::buffer_triples() +{ + assert(this->triple_generator); + this->triple_generator->generatePlainTriples(); + for (auto& x : this->triple_generator->plainTriples) + { + this->triples.push_back({{x[0], x[1], x[2]}}); + } + this->triple_generator->unlock(); +} + +SemiPrep::~SemiPrep() +{ + if (triple_generator) + delete triple_generator; +} + +void SemiPrep::buffer_bits() +{ + word r = thread.secure_prng.get_word(); + for (size_t i = 0; i < sizeof(word) * 8; i++) + this->bits.push_back((r >> i) & 1); +} + +} /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h new file mode 100644 index 00000000..166e97e9 --- /dev/null +++ b/GC/SemiPrep.h @@ -0,0 +1,51 @@ +/* + * SemiPrep.h + * + */ + +#ifndef GC_SEMIPREP_H_ +#define GC_SEMIPREP_H_ + +#include "Protocols/ReplicatedPrep.h" +#include "OT/TripleMachine.h" +#include "SemiSecret.h" +#include "ShiftableTripleBuffer.h" + +template class Beaver; + +namespace GC +{ + +class SemiPrep : public BufferPrep, ShiftableTripleBuffer +{ + Thread& thread; + + SemiSecret::TripleGenerator* triple_generator; + MascotParams params; + +public: + SemiPrep(DataPositions& usage, Thread& thread); + ~SemiPrep(); + + void set_protocol(Beaver& protocol); + + void buffer_triples(); + void buffer_bits(); + + void buffer_squares() { throw not_implemented(); } + void buffer_inverses() { throw not_implemented(); } + + void get(Dtype type, SemiSecret* data) + { + BufferPrep::get(type, data); + } + + array get_triple(int n_bits) + { + return ShiftableTripleBuffer::get_triple(n_bits); + } +}; + +} /* namespace GC */ + +#endif /* GC_SEMIPREP_H_ */ diff --git a/GC/SemiSecret.cpp b/GC/SemiSecret.cpp new file mode 100644 index 00000000..56dc82cd --- /dev/null +++ b/GC/SemiSecret.cpp @@ -0,0 +1,52 @@ +/* + * SemiSecret.cpp + * + */ + +#include "GC/ShareParty.h" +#include "SemiSecret.h" + +#include "GC/ShareSecret.hpp" +#include "Protocols/MAC_Check_Base.hpp" + +namespace GC +{ + +void SemiSecret::trans(Processor& processor, int n_outputs, + const vector& args) +{ + square64 square; + for (size_t i = n_outputs; i < args.size(); i++) + square.rows[i - n_outputs] = processor.S[args[i]].get(); + square.transpose(args.size() - n_outputs, n_outputs); + for (int i = 0; i < n_outputs; i++) + processor.S[args[i]] = square.rows[i]; +} + +void SemiSecret::load_clear(int n, const Integer& x) +{ + check_length(n, x); + *this = constant(x, Thread::s().P->my_num()); +} + +void SemiSecret::bitcom(Memory& S, const vector& regs) +{ + *this = 0; + for (unsigned int i = 0; i < regs.size(); i++) + *this ^= (S[regs[i]] << i); +} + +void SemiSecret::bitdec(Memory& S, + const vector& regs) const +{ + for (unsigned int i = 0; i < regs.size(); i++) + S[regs[i]] = (*this >> i) & 1; +} + +void SemiSecret::reveal(size_t n_bits, Clear& x) +{ + auto& thread = Thread::s(); + x = thread.MC->POpen(*this, *thread.P).mask(n_bits); +} + +} /* namespace GC */ diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h new file mode 100644 index 00000000..a0f9e080 --- /dev/null +++ b/GC/SemiSecret.h @@ -0,0 +1,67 @@ +/* + * SemiSecret.h + * + */ + +#ifndef GC_SEMISECRET_H_ +#define GC_SEMISECRET_H_ + +#include "Protocols/SemiMC.h" +#include "Protocols/SemiShare.h" +#include "Processor/DummyProtocol.h" +#include "ShareSecret.h" + +template class Beaver; + +namespace GC +{ + +class SemiPrep; + +class SemiSecret : public SemiShare, public ShareSecret +{ +public: + typedef Memory DynamicMemory; + + typedef SemiMC MC; + typedef Beaver Protocol; + typedef MC MAC_Check; + typedef SemiPrep LivePrep; + typedef SemiInput Input; + + static const int default_length = sizeof(BitVec) * 8; + + static string type_string() { return "binary secret"; } + static string phase_name() { return "Binary computation"; } + + static MC* new_mc(Machine& _) { (void) _; return new MC; } + + static void trans(Processor& processor, int n_outputs, + const vector& args); + + SemiSecret() + { + } + SemiSecret(long other) : + SemiShare(other) + { + } + SemiSecret(const IntBase& other) : + SemiShare(other) + { + } + + void load_clear(int n, const Integer& x); + + void bitcom(Memory& S, const vector& regs); + void bitdec(Memory& S, const vector& regs) const; + + void xor_(int n, const SemiSecret& x, const SemiSecret& y) + { *this = BitVec(x ^ y).mask(n); } + + void reveal(size_t n_bits, Clear& x); +}; + +} /* namespace GC */ + +#endif /* GC_SEMISECRET_H_ */ diff --git a/GC/ShareParty.h b/GC/ShareParty.h new file mode 100644 index 00000000..9e7e3a56 --- /dev/null +++ b/GC/ShareParty.h @@ -0,0 +1,51 @@ +/* + * ReplicatedParty.h + * + */ + +#ifndef GC_SHAREPARTY_H_ +#define GC_SHAREPARTY_H_ + +#include "Protocols/ReplicatedMC.h" +#include "Protocols/MaliciousRepMC.h" +#include "ShareSecret.h" +#include "Processor.h" +#include "Program.h" +#include "Memory.h" +#include "ThreadMaster.h" + +namespace GC +{ + +template +class ShareParty : public ThreadMaster +{ + static ShareParty* singleton; + + ez::ezOptionParser opt; + OnlineOptions online_opts; + +public: + static ShareParty& s(); + + typename T::mac_key_type mac_key; + + ShareParty(int argc, const char** argv, int default_batch_size = 0); + + Thread* new_thread(int i); + + void post_run(); +}; + +template +inline ShareParty& ShareParty::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("no singleton"); +} + +} + +#endif /* GC_SHAREPARTY_H_ */ diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp new file mode 100644 index 00000000..2f96a2e7 --- /dev/null +++ b/GC/ShareParty.hpp @@ -0,0 +1,137 @@ +/* + * ReplicatedParty.cpp + * + */ + +#include "ShareParty.h" + +#include "Thread.h" +#include "ShareThread.h" +#include "SemiPrep.h" +#include "Networking/Server.h" +#include "Networking/CryptoPlayer.h" +#include "Tools/ezOptionParser.h" +#include "Tools/benchmarking.h" +#include "Tools/NetworkOptions.h" +#include "Protocols/fake-stuff.h" + +#include "ShareThread.hpp" +#include "RepPrep.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/fake-stuff.hpp" + +namespace GC +{ + +template +ShareParty* ShareParty::singleton = 0; + +template +ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : + ThreadMaster(online_opts), online_opts(opt, argc, argv, + default_batch_size) +{ + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; + + NetworkOptionsWithNumber network_opts(opt, argc, argv, + T::dishonest_majority ? 2 : 3, T::dishonest_majority); + if (T::dishonest_majority) + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use encrypted channels.", // Help description. + "-e", // Flag token. + "--encrypted" // Flag token. + ); + else + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Unencrypted communication.", // Help description. + "-u", // Flag token. + "--unencrypted" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Check opening by communication instead of hashing.", // Help description. + "-c", // Flag token. + "--communication" // Flag token. + ); + online_opts.finalize(opt, argc, argv); + this->progname = online_opts.progname; + int my_num = online_opts.playerno; + + if (T::dishonest_majority) + this->machine.use_encryption = opt.get("-e")->isSet; + else + this->machine.use_encryption = not opt.get("-u")->isSet; + + this->machine.more_comm_less_comp = opt.get("-c")->isSet; + + T::out.activate(my_num == 0 or online_opts.interactive); + + if (not this->machine.use_encryption and not T::dishonest_majority) + insecure("unencrypted communication"); + + Server* server = network_opts.start_networking(this->N, my_num); + + if (online_opts.live_prep) + if (T::needs_ot) + { + Player* P; + if (this->machine.use_encryption) + P = new CryptoPlayer(this->N, 0xFFFF); + else + P = new PlainPlayer(this->N, 0xFFFF); + for (int i = 0; i < this->machine.nthreads; i++) + this->machine.ot_setups.push_back({{{*P, true}}}); + delete P; + } + + try + { + gf2n _; + read_mac_keys(get_prep_dir(network_opts.nplayers, 128, 128), this->N, + this->mac_key, _); + } + catch (exception& e) + { + SeededPRNG G; + this->mac_key.randomize(G); + } + + this->run(); + + this->machine.write_memory(this->N.my_num()); + + if (server) + delete server; +} + +template +Thread* ShareParty::new_thread(int i) +{ + return new ShareThread(i, *this); +} + +template +void ShareParty::post_run() +{ + DataPositions usage; + for (auto thread : this->threads) + usage.increase(dynamic_cast*>(thread)->usage); + usage.print_cost(); +} + +} diff --git a/GC/ReplicatedSecret.h b/GC/ShareSecret.h similarity index 82% rename from GC/ReplicatedSecret.h rename to GC/ShareSecret.h index 7af08e77..dfa9ef7b 100644 --- a/GC/ReplicatedSecret.h +++ b/GC/ShareSecret.h @@ -3,8 +3,8 @@ * */ -#ifndef GC_REPLICATEDSECRET_H_ -#define GC_REPLICATEDSECRET_H_ +#ifndef GC_SHARESECRET_H_ +#define GC_SHARESECRET_H_ #include using namespace std; @@ -18,6 +18,7 @@ using namespace std; #include "Tools/SwitchableOutput.h" #include "Protocols/Replicated.h" #include "Protocols/ReplicatedMC.h" +#include "Processor/DummyProtocol.h" namespace GC { @@ -32,22 +33,9 @@ template class Machine; template -class ReplicatedSecret : public FixedVec +class ShareSecret { - typedef FixedVec super; - public: - typedef BitVec clear; - typedef BitVec open_type; - typedef BitVec mac_type; - typedef BitVec mac_key_type; - - typedef ReplicatedBase Protocol; - - static string type_string() { return "replicated secret"; } - static string phase_name() { return "Replicated computation"; } - - static int default_length; static SwitchableOutput out; static void store_clear_in_dynamic(Memory& mem, @@ -63,38 +51,59 @@ public: static void and_(Processor& processor, const vector& args, bool repeat); static void inputb(Processor& processor, const vector& args); - static void trans(Processor& processor, int n_outputs, - const vector& args); - static void convcbit(Integer& dest, const Clear& source) { dest = source; } static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } - static U input(Processor& processor, const InputArgs& args); - void prepare_input(vector& os, long input, int n_bits, PRNG& secure_prng); - void finalize_input(Thread& party, octetStream& o, int from, int n_bits); + void check_length(int n, const Integer& x); + + void random_bit(); +}; + +template +class ReplicatedSecret : public FixedVec, public ShareSecret +{ + typedef FixedVec super; + +public: + typedef BitVec clear; + typedef BitVec open_type; + typedef BitVec mac_type; + typedef BitVec mac_key_type; + + typedef ReplicatedBase Protocol; + + static const int N_BITS = clear::N_BITS; + + static const bool dishonest_majority = false; + static const bool needs_ot = false; + + static string type_string() { return "replicated secret"; } + static string phase_name() { return "Replicated computation"; } + + static int default_length; + + static void trans(Processor& processor, int n_outputs, + const vector& args); ReplicatedSecret() {} template ReplicatedSecret(const T& other) : super(other) {} - void load(int n, const Integer& x); + void load_clear(int n, const Integer& x); void bitcom(Memory& S, const vector& regs); void bitdec(Memory& S, const vector& regs) const; - void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y) - { *this = x ^ y; (void)n; } - void and_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y, bool repeat); - void andrs(int n, const ReplicatedSecret& x, const ReplicatedSecret& y); - BitVec local_mul(const ReplicatedSecret& other) const; - void reveal(size_t n_bits, Clear& x); + void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y) + { *this = x ^ y; (void)n; } - void random_bit(); + void reveal(size_t n_bits, Clear& x); }; +class SemiHonestRepPrep; class SemiHonestRepSecret : public ReplicatedSecret { @@ -106,6 +115,8 @@ public: typedef ReplicatedMC MC; typedef Replicated Protocol; typedef MC MAC_Check; + typedef SemiHonestRepPrep LivePrep; + typedef ReplicatedInput Input; static MC* new_mc(Machine& _) { (void) _; return new MC; } @@ -116,4 +127,4 @@ public: } -#endif /* GC_REPLICATEDSECRET_H_ */ +#endif /* GC_SHARESECRET_H_ */ diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp new file mode 100644 index 00000000..7a9c0b38 --- /dev/null +++ b/GC/ShareSecret.hpp @@ -0,0 +1,168 @@ +/* + * ReplicatedSecret.cpp + * + */ + +#include "ShareSecret.h" + +#include "MaliciousRepSecret.h" +#include "Protocols/MaliciousRepMC.h" +#include "ShareThread.h" +#include "Thread.h" +#include "square64.h" + +#include "Protocols/Share.h" + +#include "Protocols/ReplicatedMC.hpp" +#include "Protocols/Beaver.hpp" +#include "ShareParty.h" +#include "ShareThread.hpp" + +namespace GC +{ + +template +int ReplicatedSecret::default_length = 8 * sizeof(typename ReplicatedSecret::value_type); + +template +SwitchableOutput ShareSecret::out; + +template +void ShareSecret::check_length(int n, const Integer& x) +{ + if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n)) + throw out_of_range("public value too long"); +} + +template +void ReplicatedSecret::load_clear(int n, const Integer& x) +{ + this->check_length(n, x); + *this = x; +} + +template +void ReplicatedSecret::bitcom(Memory& S, const vector& regs) +{ + *this = 0; + for (unsigned int i = 0; i < regs.size(); i++) + *this ^= (S[regs[i]] << i); +} + +template +void ReplicatedSecret::bitdec(Memory& S, const vector& regs) const +{ + for (unsigned int i = 0; i < regs.size(); i++) + S[regs[i]] = (*this >> i) & 1; +} + +template +void ShareSecret::load(vector >& accesses, + const Memory& mem) +{ + for (auto access : accesses) + access.dest = mem[access.address]; +} + +template +void ShareSecret::store(Memory& mem, + vector >& accesses) +{ + for (auto access : accesses) + mem[access.address] = access.source; +} + +template +void ShareSecret::store_clear_in_dynamic(Memory& mem, + const vector& accesses) +{ + for (auto access : accesses) + mem[access.address] = access.value; +} + +template +void ShareSecret::inputb(Processor& processor, + const vector& args) +{ + auto& party = ShareThread::s(); + typename U::Input input(*party.MC, party.DataF, *party.P); + input.reset_all(*party.P); + + InputArgList a(args); + bool interactive = party.n_interactive_inputs_from_me(a) > 0; + + for (auto x : a) + { + if (x.from == party.P->my_num()) + { + input.add_mine(processor.get_input(x.params, interactive), x.n_bits); + } + else + input.add_other(x.from); + } + + if (interactive) + cout << "Thank you" << endl; + + input.exchange(); + + for (auto x : a) + { + int from = x.from; + int n_bits = x.n_bits; + auto& res = processor.S[x.dest]; + res = input.finalize(from, n_bits).mask(n_bits); + } +} + +template +BitVec ReplicatedSecret::local_mul(const ReplicatedSecret& other) const +{ + return (*this)[0] * other.sum() + (*this)[1] * other[0]; +} + +template +void ShareSecret::and_( + Processor& processor, const vector& args, + bool repeat) +{ + ShareThread::s().and_(processor, args, repeat); +} + +template +void ReplicatedSecret::trans(Processor& processor, + int n_outputs, const vector& args) +{ + assert(length == 2); + for (int k = 0; k < 2; k++) + { + square64 square; + for (size_t i = n_outputs; i < args.size(); i++) + square.rows[i - n_outputs] = processor.S[args[i]][k].get(); + square.transpose(args.size() - n_outputs, n_outputs); + for (int i = 0; i < n_outputs; i++) + processor.S[args[i]][k] = square.rows[i]; + } +} + +template +void ReplicatedSecret::reveal(size_t n_bits, Clear& x) +{ + (void) n_bits; + auto& share = *this; + vector opened; + auto& party = ShareThread::s(); + party.MC->POpen_Begin(opened, {share}, *party.P); + party.MC->POpen_End(opened, {share}, *party.P); + x = IntBase(opened[0]); +} + +template +void ShareSecret::random_bit() +{ + U res; + ShareThread::s().DataF.get_one(DATA_BIT, res); + *this = res; +} + +} diff --git a/GC/ShareThread.h b/GC/ShareThread.h new file mode 100644 index 00000000..b0eb12a2 --- /dev/null +++ b/GC/ShareThread.h @@ -0,0 +1,55 @@ +/* + * MalicousRepParty.h + * + */ + +#ifndef GC_SHARETHREAD_H_ +#define GC_SHARETHREAD_H_ + +#include "Thread.h" +#include "MaliciousRepSecret.h" +#include "RepPrep.h" +#include "SemiHonestRepPrep.h" +#include "Processor/Data_Files.h" +#include "Protocols/ReplicatedInput.h" + +#include + +namespace GC +{ + +template +class ShareThread : public Thread +{ + static thread_local ShareThread* singleton; + +public: + static ShareThread& s(); + + DataPositions usage; + Preprocessing& DataF; + + ShareThread(int i, ThreadMaster& master); + virtual ~ShareThread(); + + void pre_run(); + void post_run(); + + void and_(Processor& processor, const vector& args, bool repeat); +}; + +template +thread_local ShareThread* ShareThread::singleton = 0; + +template +inline ShareThread& ShareThread::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("no singleton"); +} + +} /* namespace GC */ + +#endif /* GC_SHARETHREAD_H_ */ diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp new file mode 100644 index 00000000..fa5c3915 --- /dev/null +++ b/GC/ShareThread.hpp @@ -0,0 +1,88 @@ +/* + * MalicousRepParty.cpp + * + */ + +#ifndef GC_SHARETHREAD_HPP_ +#define GC_SHARETHREAD_HPP_ + +#include +#include "Protocols/MaliciousRepMC.h" +#include "Math/Setup.h" + +#include "Processor/Data_Files.hpp" + +namespace GC +{ + +template +ShareThread::ShareThread(int i, + ThreadMaster& master) : + Thread(i, master), usage(master.N.num_players()), DataF( + master.opts.live_prep ? + *(Preprocessing*) new typename T::LivePrep(usage, + *this) : + *(Preprocessing*) new Sub_Data_Files(master.N, + get_prep_dir(master.N.num_players(), 128, 128), + usage)) +{ +} + +template +ShareThread::~ShareThread() +{ + delete &DataF; +} + +template +void ShareThread::pre_run() +{ + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; + assert(this->protocol != 0); + DataF.set_protocol(*this->protocol); +} + +template +void ShareThread::post_run() +{ +#ifndef INSECURE + cerr << "Removing used pre-processed data" << endl; + DataF.prune(); +#endif +} + +template +void ShareThread::and_(Processor& processor, + const vector& args, bool repeat) +{ + auto& protocol = this->protocol; + processor.check_args(args, 4); + protocol->init_mul(DataF, *this->MC); + for (size_t i = 0; i < args.size(); i += 4) + { + int n_bits = args[i]; + int left = args[i + 2]; + int right = args[i + 3]; + T y_ext; + if (repeat) + y_ext = processor.S[right].extend_bit(); + else + y_ext = processor.S[right]; + protocol->prepare_mul(processor.S[left].mask(n_bits), y_ext.mask(n_bits), n_bits); + } + + protocol->exchange(); + + for (size_t i = 0; i < args.size(); i += 4) + { + int n_bits = args[i]; + int out = args[i + 1]; + processor.S[out] = protocol->finalize_mul(n_bits); + } +} + +} /* namespace GC */ + +#endif diff --git a/GC/ShiftableTripleBuffer.h b/GC/ShiftableTripleBuffer.h new file mode 100644 index 00000000..ed56dc3e --- /dev/null +++ b/GC/ShiftableTripleBuffer.h @@ -0,0 +1,60 @@ +/* + * ShiftableTripleBuffer.h + * + */ + +#ifndef GC_SHIFTABLETRIPLEBUFFER_H_ +#define GC_SHIFTABLETRIPLEBUFFER_H_ + +#include "Math/FixedVec.h" + +#include + +namespace GC +{ + +template +class ShiftableTripleBuffer +{ + FixedVec triple_buffer; + int buffer_left; + + virtual void get(Dtype type, T* data) = 0; + +public: + ShiftableTripleBuffer() : + buffer_left(0) + { + } + + virtual ~ShiftableTripleBuffer() {} + + array get_triple(int n_bits) + { + int max_n_bits = T::N_BITS; + assert(n_bits <= max_n_bits); + assert(n_bits > 0); + array res; + + if (n_bits <= buffer_left) + { + res = triple_buffer.mask(n_bits).get(); + triple_buffer >>= n_bits; + buffer_left -= n_bits; + } + else + { + get(DATA_TRIPLE, res.data()); + FixedVec tmp = res; + res = tmp.mask(n_bits).get(); + triple_buffer += (tmp >> n_bits) << buffer_left; + buffer_left += max_n_bits - n_bits; + } + + return res; + } +}; + +} /* namespace GC */ + +#endif /* GC_SHIFTABLETRIPLEBUFFER_H_ */ diff --git a/GC/Thread.h b/GC/Thread.h index d0f35cd7..5546907f 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -39,7 +39,6 @@ public: Names& N; Player* P; PRNG secure_prng; - vector os; int thread_num; WaitQueue tape_schedule; diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 8ce35fc9..043d9355 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -56,6 +56,11 @@ void ThreadMaster::run() P = new PlainPlayer(N, 0xff << 24); machine.load_schedule(progname); + + if (T::needs_ot) + for (int i = 0; i < machine.nthreads; i++) + machine.ot_setups.push_back({{*P, true}, {*P, true}}); + for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); for (auto thread : threads) diff --git a/GC/TinyMC.cpp b/GC/TinyMC.cpp new file mode 100644 index 00000000..ff432007 --- /dev/null +++ b/GC/TinyMC.cpp @@ -0,0 +1,11 @@ +/* + * TinyMC.cpp + * + */ + +#include "TinyMC.h" + +namespace GC +{ + +} /* namespace GC */ diff --git a/GC/TinyMC.h b/GC/TinyMC.h new file mode 100644 index 00000000..d3d45d30 --- /dev/null +++ b/GC/TinyMC.h @@ -0,0 +1,67 @@ +/* + * TinyMC.h + * + */ + +#ifndef GC_TINYMC_H_ +#define GC_TINYMC_H_ + +#include "Protocols/MAC_Check_Base.h" + +namespace GC +{ + +template +class TinyMC : public MAC_Check_Base +{ + typename T::part_type::MAC_Check part_MC; + vector part_values; + vector part_shares; + +public: + TinyMC(typename T::mac_key_type mac_key) : + part_MC(mac_key) + { + this->alphai = mac_key; + } + + typename T::part_type::MAC_Check& get_part_MC() + { + return part_MC; + } + + void POpen_Begin(vector& values, const vector& S, + const Player& P) + { + values.clear(); + part_shares.clear(); + for (auto& share : S) + for (auto& part : share.get_regs()) + part_shares.push_back(part); + part_MC.POpen_Begin(part_values, part_shares, P); + } + + void POpen_End(vector& values, const vector& S, + const Player& P) + { + values.clear(); + part_MC.POpen_End(part_values, part_shares, P); + int i = 0; + for (auto& share : S) + { + typename T::open_type opened = 0; + for (size_t j = 0; j < share.get_regs().size(); j++) + opened += typename T::open_type(part_values[i++].get_bit(0)) << j; + values.push_back(opened); + } + } + + void Check(const Player& P) + { + part_MC.Check(P); + } +}; + +} /* namespace GC */ + +#endif /* GC_TINYMC_H_ */ diff --git a/GC/TinyPrep.h b/GC/TinyPrep.h new file mode 100644 index 00000000..31e3ecca --- /dev/null +++ b/GC/TinyPrep.h @@ -0,0 +1,52 @@ +/* + * TinyPrep.h + * + */ + +#ifndef GC_TINYPREP_H_ +#define GC_TINYPREP_H_ + +#include "Thread.h" +#include "OT/TripleMachine.h" +#include "Protocols/Beaver.h" +#include "Protocols/ReplicatedPrep.h" +#include "Protocols/RandomPrep.h" + +namespace GC +{ + +template +class TinyPrep : public BufferPrep, public RandomPrep +{ + typedef Share> res_type; + + Thread& thread; + + typename T::TripleGenerator* triple_generator; + typename T::part_type::TripleGenerator* input_generator; + MascotParams params; + + vector> triple_buffer; + +public: + TinyPrep(DataPositions& usage, Thread& thread); + ~TinyPrep(); + + void set_protocol(Beaver& protocol); + + void buffer_triples(); + void buffer_bits(); + + void buffer_inputs(int player); + + void buffer_squares() { throw not_implemented(); } + void buffer_inverses() { throw not_implemented(); } + + typename T::part_type::super get_random(); + + array get_triple(int n_bits); +}; + +} /* namespace GC */ + +#endif /* GC_TINYPREP_H_ */ diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp new file mode 100644 index 00000000..0ff4c2a0 --- /dev/null +++ b/GC/TinyPrep.hpp @@ -0,0 +1,174 @@ +/* + * TinyPrep.cpp + * + */ + +#include "TinyPrep.h" + +namespace GC +{ + +template +TinyPrep::TinyPrep(DataPositions& usage, Thread& thread) : + BufferPrep(usage), thread(thread), triple_generator(0), + input_generator(0) +{ +} + +template +TinyPrep::~TinyPrep() +{ + if (triple_generator) + delete triple_generator; + if (input_generator) + delete input_generator; +} + +template +void TinyPrep::set_protocol(Beaver& protocol) +{ + (void) protocol; + params.generateMACs = true; + params.amplify = false; + params.check = false; + params.set_mac_key(thread.MC->get_alphai()); + triple_generator = new typename T::TripleGenerator( + thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.master.N, thread.thread_num, + thread.master.opts.batch_size, + 1, params, thread.P); + triple_generator->multi_threaded = false; + input_generator = new typename T::part_type::TripleGenerator( + thread.processor.machine.ot_setups.at(thread.thread_num).at(1), + thread.master.N, thread.thread_num, + thread.master.opts.batch_size, + 1, params, thread.P); + input_generator->multi_threaded = false; + thread.MC->get_part_MC().set_prep(*this); +} + +template +void TinyPrep::buffer_triples() +{ + auto& triple_generator = this->triple_generator; + params.generateBits = false; + vector> triples; + ShuffleSacrifice sacrifice; + while (int(triples.size()) < sacrifice.minimum_n_inputs()) + { + triple_generator->generatePlainTriples(); + triple_generator->unlock(); + assert(triple_generator->plainTriples.size() != 0); + for (size_t i = 0; i < triple_generator->plainTriples.size(); i++) + triple_generator->valueBits[2].set_portion(i, + triple_generator->plainTriples[i][2]); + triple_generator->run_multipliers({}); + for (size_t i = 0; i < triple_generator->plainTriples.size(); i++) + { + for (int j = 0; j < T::default_length; j++) + { + triples.push_back({}); + for (int k = 0; k < 3; k++) + { + auto& share = triples.back()[k]; + share.set_share( + triple_generator->plainTriples.at(i).at(k).get_bit( + j)); + typename T::part_type::mac_type mac; + mac = thread.MC->get_alphai() * share.get_share(); + for (auto& multiplier : triple_generator->ot_multipliers) + mac += multiplier->macs.at(k).at(i * T::default_length + j); + share.set_mac(mac); + } + } + } + } + sacrifice.triple_sacrifice(triples, triples, + *thread.P, thread.MC->get_part_MC()); + for (size_t i = 0; i < triples.size() / T::default_length; i++) + { + this->triples.push_back({}); + auto& triple = this->triples.back(); + for (auto& x : triple) + x.resize_regs(T::default_length); + for (int j = 0; j < T::default_length; j++) + { + auto& source_triple = triples[j + i * T::default_length]; + for (int k = 0; k < 3; k++) + triple[k].get_reg(j) = source_triple[k]; + } + } +} + +template +void TinyPrep::buffer_bits() +{ + auto tmp = BufferPrep::get_random_from_inputs(thread.P->num_players()); + for (auto& bit : tmp.get_regs()) + { + this->bits.push_back({}); + this->bits.back().resize_regs(1); + this->bits.back().get_reg(0) = bit; + } +} + +template +void TinyPrep::buffer_inputs(int player) +{ + auto& inputs = this->inputs; + inputs.resize(thread.P->num_players()); + assert(this->input_generator); + this->input_generator->generateInputs(player); + for (size_t i = 0; i < this->input_generator->inputs.size() / T::default_length; i++) + { + inputs[player].push_back({}); + inputs[player].back().share.resize_regs(T::default_length); + for (int j = 0; j < T::default_length; j++) + { + auto& source_input = this->input_generator->inputs[j + + i * T::default_length]; + inputs[player].back().share.get_reg(j) = res_type(source_input.share); + inputs[player].back().value ^= typename T::open_type( + source_input.value.get_bit(0)) << j; + } + } +} + +template +typename T::part_type::super GC::TinyPrep::get_random() +{ + T tmp; + this->get_one(DATA_BIT, tmp); + return tmp.get_reg(0); +} + +template +array TinyPrep::get_triple(int n_bits) +{ + assert(n_bits > 0); + while ((unsigned)n_bits > triple_buffer.size()) + { + array tmp; + this->get(DATA_TRIPLE, tmp.data()); + for (size_t i = 0; i < tmp[0].get_regs().size(); i++) + { + triple_buffer.push_back( + { {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} }); + } + } + + array res; + for (int j = 0; j < 3; j++) + res[j].resize_regs(n_bits); + + for (int i = 0; i < n_bits; i++) + { + for (int j = 0; j < 3; j++) + res[j].get_reg(i) = triple_buffer.back()[j]; + triple_buffer.pop_back(); + } + + return res; +} + +} /* namespace GC */ diff --git a/GC/TinySecret.cpp b/GC/TinySecret.cpp new file mode 100644 index 00000000..a8f78241 --- /dev/null +++ b/GC/TinySecret.cpp @@ -0,0 +1,11 @@ +/* + * TinySecret.cpp + * + */ + +#include "TinySecret.h" + +namespace GC +{ + +} /* namespace GC */ diff --git a/GC/TinySecret.h b/GC/TinySecret.h new file mode 100644 index 00000000..54fa4ca2 --- /dev/null +++ b/GC/TinySecret.h @@ -0,0 +1,163 @@ +/* + * TinySecret.h + * + */ + +#ifndef GC_TINYSECRET_H_ +#define GC_TINYSECRET_H_ + +#include "Secret.h" +#include "TinyShare.h" +#include "ShareParty.h" +#include "OT/Rectangle.h" +#include "OT/BitDiagonal.h" + +template class NPartyTripleGenerator; +template class OTTripleGenerator; +template class TinyMultiplier; + +namespace GC +{ + +template class TinyPrep; +template class TinyMC; + +template +class TinySecret : public Secret> +{ + typedef TinySecret This; + +public: + typedef TinyShare part_type; + typedef Secret super; + + typedef typename part_type::mac_key_type mac_key_type; + + typedef BitVec open_type; + typedef BitVec clear; + + typedef TinyMC MC; + typedef MC MAC_Check; + typedef Beaver Protocol; + typedef ::Input Input; + typedef TinyPrep LivePrep; + typedef Memory DynamicMemory; + + typedef OTTripleGenerator TripleGenerator; + typedef TinyMultiplier Multiplier; + typedef typename part_type::sacri_type sacri_type; + typedef typename part_type::mac_type mac_type; + typedef BitDiagonal Rectangle; + + static const bool dishonest_majority = true; + static const bool needs_ot = true; + + static const int default_length = 64; + + static string type_short() + { + return "T"; + } + + static DataFieldType field_type() + { + return BitVec::field_type(); + } + + static int size() + { + return part_type::size() * default_length; + } + + static MC* new_mc(Machine& machine) + { + (void) machine; + return new MC(ShareParty::s().mac_key); + } + + static void store_clear_in_dynamic(Memory& mem, + const vector& accesses) + { + auto& party = ShareThread::s(); + for (auto access : accesses) + mem[access.address] = constant(access.value, party.P->my_num(), + {}); + } + + static This constant(BitVec other, int my_num, mac_key_type alphai) + { + This res; + res.resize_regs(other.length()); + for (int i = 0; i < other.length(); i++) + res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai); + return res; + } + + TinySecret() + { + } + TinySecret(const super& other) : + super(other) + { + } + + void assign(const char* buffer) + { + this->resize_regs(default_length); + for (int i = 0; i < default_length; i++) + this->get_reg(i).assign(buffer + i * part_type::size()); + } + + This operator-(const This& other) const + { + return *this + other; + } + + This operator*(const BitVec& other) const + { + This res = *this; + for (int i = 0; i < super::size(); i++) + if (not other.get_bit(i)) + res.get_reg(i) = {}; + return res; + } + + This extend_bit() const + { + This res; + res.get_regs().resize(BitVec::N_BITS, this->get_reg(0)); + return res; + } + + This mask(int n_bits) const + { + This res = *this; + res.get_regs().resize(n_bits); + return res; + } + + void reveal(size_t n_bits, Clear& x) + { + auto& to_open = *this; + to_open.resize_regs(n_bits); + auto& party = ShareThread::s(); + x = party.MC->POpen(to_open, *party.P); + } + + void output(ostream& s, bool human = true) const + { + assert(this->get_regs().size() == default_length); + for (auto& reg : this->get_regs()) + reg.output(s, human); + } +}; + +template +inline TinySecret operator*(const BitVec& clear, const TinySecret& share) +{ + return share * clear; +} + +} /* namespace GC */ + +#endif /* GC_TINYSECRET_H_ */ diff --git a/GC/TinyShare.cpp b/GC/TinyShare.cpp new file mode 100644 index 00000000..cdbd03b6 --- /dev/null +++ b/GC/TinyShare.cpp @@ -0,0 +1,11 @@ +/* + * TinyShare.cpp + * + */ + +#include "TinyShare.h" + +namespace GC +{ + +} /* namespace GC */ diff --git a/GC/TinyShare.h b/GC/TinyShare.h new file mode 100644 index 00000000..51d724e6 --- /dev/null +++ b/GC/TinyShare.h @@ -0,0 +1,80 @@ +/* + * TinyShare.h + * + */ + +#ifndef GC_TINYSHARE_H_ +#define GC_TINYSHARE_H_ + +#include "ShareSecret.h" +#include "ShareParty.h" +#include "Secret.h" +#include "Protocols/Spdz2kShare.h" +#include "Processor/NoLivePrep.h" + +namespace GC +{ + +template class TinySecret; +template class ShareThread; + +template +class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> +{ + typedef TinyShare This; + +public: + typedef Spdz2kShare<1, S> super; + + typedef void DynamicMemory; + + typedef NoLivePrep LivePrep; + + typedef SwitchableOutput out_type; + + static string name() + { + return "tiny share"; + } + + static ShareThread>& get_party() + { + return ShareThread>::s(); + } + + static This new_reg() + { + return {}; + } + + TinyShare() + { + } + TinyShare(const typename super::super& other) : + super(other) + { + } + + void XOR(const This& a, const This& b) + { + *this = a + b; + } + + void public_input(bool input) + { + auto& party = get_party(); + *this = super::constant(input, party.P->my_num(), + ShareParty < TinySecret < S >> ::s().mac_key); + } + + void random() + { + TinySecret tmp; + get_party().DataF.get_one(DATA_BIT, tmp); + *this = tmp.get_reg(0); + } +}; + +} /* namespace GC */ + +#endif /* GC_TINYSHARE_H_ */ diff --git a/GC/instructions.h b/GC/instructions.h index 28c1b1f6..18ee9edd 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -55,7 +55,7 @@ X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \ X(SHRCI, C0 = C1 >> IMM) \ X(SHLCI, C0 = C1 << IMM) \ - X(LDBITS, S0.load(R1, IMM)) \ + X(LDBITS, S0.load_clear(R1, IMM)) \ X(LDMS, S0 = MSD) \ X(STMS, MSD = S0) \ X(LDMSI, S0 = MSI) \ @@ -67,7 +67,7 @@ X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \ X(STMSDI, PROC.store_dynamic_indirect(EXTRA, MD)) \ X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA, MD)) \ - X(CONVSINT, S0.load(IMM, I1)) \ + X(CONVSINT, S0.load_clear(IMM, I1)) \ X(CONVCINT, C0 = I1) \ X(CONVCBIT, T::convcbit(I0, C1)) \ X(MOVS, S0 = PS1) \ diff --git a/License.txt b/License.txt index 5c7c4b13..5c7fce3d 100644 --- a/License.txt +++ b/License.txt @@ -25,6 +25,11 @@ Copyright (c) 2018, The University of Bristol, Bar-Ilan University Please contact mks.keller@gmail.com The same license as for SPDZ-2 applies. ___________________________________________________________________ +SCALE-MAMBA [https://github.com/KULeuven-COSIC/SCALE-MAMBA] +Copyright (c) 2019, The University of Bristol, COSIC-KU Leuven +Please contact nigel.smart@kuleuven.be +See below for the full license. +___________________________________________________________________ University of Bristol : Open Access Software Licence @@ -46,3 +51,27 @@ Any use of the software for scientific publications or commercial purposes shoul Enquiries about further applications and development opportunities are welcome. Please contact nigel@cs.bris.ac.uk +___________________________________________________________________ + + +This software incorporates components from the original SPDZ system, as well as various +extensions. It's copyright therefore rests with the following two institutions: + +Copyright (c) 2017, The University of Bristol, Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. +Copyright (c) 2018, COSIC-KU Leuven, Kasteelpark Arenberg 10, bus 2452, B-3001 Leuven-Heverlee, Belgium. + +All rights reserved + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +Any use of the software for commercial purposes should be reported to the nigel.smart@kuleuven.be +This is for impact and usage monitoring purposes only. + +Enquiries about further applications and development opportunities are welcome. Please contact nigel.smart@kuleuven.be diff --git a/OT/OTMachine.cpp b/Machines/OTMachine.cpp similarity index 100% rename from OT/OTMachine.cpp rename to Machines/OTMachine.cpp diff --git a/OT/OTMachine.h b/Machines/OTMachine.h similarity index 100% rename from OT/OTMachine.h rename to Machines/OTMachine.h diff --git a/OT/OText_main.cpp b/Machines/OText_main.cpp similarity index 100% rename from OT/OText_main.cpp rename to Machines/OText_main.cpp diff --git a/OT/OutputCheck.h b/Machines/OutputCheck.h similarity index 100% rename from OT/OutputCheck.h rename to Machines/OutputCheck.h diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 27b533a0..69a6d0b6 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -65,18 +65,6 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr "-ip", // Flag token. "--ip-file-name" // Flag token. ); - opt.add( - "empty", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Where to obtain memory, new|old|empty (default: empty)\n\t" - "new: copy from Player-Memory-P file\n\t" - "old: reuse previous memory in Memory-P\n\t" - "empty: create new empty memory", // Help description. - "-m", // Flag token. - "--memory" // Flag token. - ); opt.add( "", // Default. 0, // Required? @@ -143,14 +131,13 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr "--encrypted" // Flag token. ); - string memtype, hostname, ipFileName; + string hostname, ipFileName; int lg2, pnbase, opening_sum, max_broadcast; int my_port; online_opts.finalize(opt, argc, argv); opt.get("--portnumbase")->getInt(pnbase); opt.get("--lg2")->getInt(lg2); - opt.get("--memory")->getString(memtype); opt.get("--hostname")->getString(hostname); opt.get("--ip-file-name")->getString(ipFileName); opt.get("--opening-sum")->getInt(opening_sum); @@ -192,7 +179,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr try #endif { - Machine(playerno, playerNames, online_opts.progname, memtype, lg2, + Machine(playerno, playerNames, online_opts.progname, online_opts.memtype, lg2, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, opt.get("--threads")->isSet, max_broadcast, opt.get("--encrypted")->isSet, online_opts.live_prep, diff --git a/OT/TripleMachine.cpp b/Machines/TripleMachine.cpp similarity index 83% rename from OT/TripleMachine.cpp rename to Machines/TripleMachine.cpp index 97454fdb..79a23c14 100644 --- a/OT/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -5,7 +5,6 @@ #include #include "OT/NPartyTripleGenerator.h" -#include "OT/OTMachine.h" #include "OT/OTTripleSetup.h" #include "Math/gf2n.h" #include "Math/Setup.h" @@ -13,9 +12,11 @@ #include "Tools/ezOptionParser.h" #include "Math/Setup.h" #include "Protocols/fake-stuff.h" +#include "Math/BitVec.h" #include "Protocols/fake-stuff.hpp" #include "Math/Z2k.hpp" +#include "OT/NPartyTripleGenerator.hpp" #include #include @@ -23,24 +24,10 @@ using namespace std; void* run_ngenerator_thread(void* ptr) { - ((MascotGenerator*)ptr)->generate(); + ((GeneratorThread*)ptr)->generate(); return 0; } -MascotParams::MascotParams() -{ - generateMACs = true; - amplify = true; - check = true; - generateBits = false; - timerclear(&start); -} - -void MascotParams::set_passive() -{ - generateMACs = amplify = check = false; -} - TripleMachine::TripleMachine(int argc, const char** argv) : nConnections(1), bonding(0) { @@ -167,9 +154,10 @@ TripleMachine::TripleMachine(int argc, const char** argv) : } template -NPartyTripleGenerator* TripleMachine::new_generator(OTTripleSetup& setup, int i) +GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i) { - return new NPartyTripleGenerator(setup, N[i%nConnections], i, nTriplesPerThread, nloops, *this); + return new typename T::TripleGenerator(setup, N[i % nConnections], i, + nTriplesPerThread, nloops, *this); } void TripleMachine::run() @@ -186,13 +174,13 @@ void TripleMachine::run() PlainPlayer P(N[0], 0xF000); OTTripleSetup setup(P, true); - vector generators(nthreads); + vector generators(nthreads); vector threads(nthreads); for (int i = 0; i < nthreads; i++) { if (primeField) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i); else if (z2k) { if (z2k == 32 and z2s == 32) @@ -270,58 +258,3 @@ void TripleMachine::output_mac_keys() else write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s); } - -template<> gf2n_long MascotParams::get_mac_key() -{ - return mac_key2l; -} - -template<> gf2n_short MascotParams::get_mac_key() -{ - return mac_key2s; -} - -template<> gfp1 MascotParams::get_mac_key() -{ - return mac_keyp; -} - -template<> Z2<48> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<64> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<32> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> void MascotParams::set_mac_key(gf2n_long key) -{ - mac_key2l = key; -} - -template<> void MascotParams::set_mac_key(gf2n_short key) -{ - mac_key2s = key; -} - -template<> void MascotParams::set_mac_key(gfp1 key) -{ - mac_keyp = key; -} - -template<> void MascotParams::set_mac_key(Z2<64> key) -{ - mac_keyz = key; -} - -template<> void MascotParams::set_mac_key(Z2<48> key) -{ - mac_keyz = key; -} diff --git a/Machines/malicious-rep-bin-party.cpp b/Machines/malicious-rep-bin-party.cpp index 502112e0..a90f4b8e 100644 --- a/Machines/malicious-rep-bin-party.cpp +++ b/Machines/malicious-rep-bin-party.cpp @@ -3,10 +3,25 @@ * */ -#include "GC/ReplicatedParty.h" +#include "GC/ShareParty.h" +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" #include "GC/MaliciousRepSecret.h" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" + +#include "Processor/Machine.hpp" +#include "Processor/Instruction.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Beaver.hpp" + int main(int argc, const char** argv) { - GC::ReplicatedParty(argc, argv); + GC::ShareParty(argc, argv); } diff --git a/Machines/replicated-bin-party.cpp b/Machines/replicated-bin-party.cpp index d528e4db..39e41567 100644 --- a/Machines/replicated-bin-party.cpp +++ b/Machines/replicated-bin-party.cpp @@ -3,9 +3,24 @@ * */ -#include "GC/ReplicatedParty.h" +#include "GC/ShareParty.h" + +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" + +#include "Processor/Machine.hpp" +#include "Processor/Instruction.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Beaver.hpp" int main(int argc, const char** argv) { - GC::ReplicatedParty(argc, argv); + GC::ShareParty(argc, argv); } diff --git a/Machines/semi-bin-party.cpp b/Machines/semi-bin-party.cpp new file mode 100644 index 00000000..ba2c1f9a --- /dev/null +++ b/Machines/semi-bin-party.cpp @@ -0,0 +1,28 @@ +/* + * semi-bin-party.cpp + * + */ + +#include "GC/ShareParty.h" +#include "GC/SemiSecret.h" + +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" + +#include "GC/Machine.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "GC/Processor.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/ReplicatedInput.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Input.hpp" + +int main(int argc, const char** argv) +{ + GC::ShareParty(argc, argv); +} diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index e5b6f674..ad15c085 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -4,6 +4,7 @@ */ #include "Processor/Machine.h" +#include "Processor/RingOptions.h" #include "Protocols/Spdz2kShare.h" #include "Math/gf2n.h" #include "Networking/Server.h" @@ -27,13 +28,24 @@ int main(int argc, const char** argv) int s; opt.get("-S")->getInt(s); opt.resetArgs(); + RingOptions ring_options(opt, argc, argv); + int k = ring_options.R; #ifdef VERBOSE - cerr << "Using SPDZ2k with security parameter " << s << endl; + cerr << "Using SPDZ2k with ring length " << k << " and security parameter " + << s << endl; #endif - if (s == 64) - return spdz_main, Share>(argc, argv, opt); - else if (s == 48) - return spdz_main, Share>(argc, argv, opt); + +#undef Z +#define Z(K, S) \ + if (s == S and k == K) \ + return spdz_main, Share>(argc, argv, opt); + + Z(64, 64) + Z(64, 48) + Z(72, 64) + Z(72, 48) + else - throw runtime_error("not compiled for s=" + to_string(s)); + throw runtime_error( + "not compiled for k=" + to_string(k) + " and s=" + to_string(s)); } diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp new file mode 100644 index 00000000..aa7e277c --- /dev/null +++ b/Machines/tiny-party.cpp @@ -0,0 +1,31 @@ +/* + * tiny-party.cpp + * + */ + +#include "GC/TinySecret.h" +#include "GC/ShareParty.h" +#include "GC/TinyMC.h" + +#include "GC/ShareParty.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" + +#include "Processor/Machine.hpp" +#include "Processor/Instruction.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/MascotPrep.hpp" + +int main(int argc, const char** argv) +{ + GC::ShareParty>(argc, argv, 1000); +} diff --git a/Makefile b/Makefile index b5aedf7f..ed1bfe48 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ OT_EXE = ot.x ot-offline.x COMMON = $(MATH) $(TOOLS) $(NETWORK) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) -YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) BMR/Key.o +YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(OT) VM = $(PROCESSOR) $(COMMON) @@ -35,7 +35,7 @@ DEPS := $(wildcard */*.d) .SECONDARY: $(OBJS) -all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x mascot-party.x +all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x ifeq ($(USE_NTL),1) all: overdrive she-offline cowgear-party.x @@ -77,7 +77,7 @@ spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Off tldr: -echo ARCH = -march=native >> CONFIG.mine - $(MAKE) Player-Online.x + $(MAKE) mascot-party.x ifeq ($(OS), Darwin) tldr: mac-setup @@ -90,7 +90,7 @@ shamir: shamir-party.x malicious-shamir-party.x galois-degree.x ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) +$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Machines/S*.cpp)) $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC) $(AR) -csr $@ $^ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) @@ -104,47 +104,17 @@ static-dir: static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) -Fake-Offline.x: Fake-Offline.cpp $(COMMON) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) $(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) $(ECLIB) -Check-Offline.x: Check-Offline.o $(COMMON) $(PROCESSOR) - $(CXX) $(CFLAGS) -o Check-Offline.x $^ $(LDLIBS) +Check-Offline.x: $(PROCESSOR) -Check-Offline-Z2k.x: Check-Offline-Z2k.cpp $(COMMON) - $(CXX) $(CFLAGS) -o Check-Offline-Z2k.x $^ $(LDLIBS) - -Server.x: Server.cpp $(COMMON) - $(CXX) $(CFLAGS) Server.cpp -o Server.x $(COMMON) $(LDLIBS) - -Setup.x: Setup.cpp $(COMMON) - $(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS) - -ot.x: $(OT) $(COMMON) OT/OText_main.cpp $(LIBSIMPLEOT) +ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -ot-check.x: $(OT) $(COMMON) - $(CXX) $(CFLAGS) -o ot-check.x OT/OutputCheck.cpp $(COMMON) $(LDLIBS) +ot-offline.x: $(OT) $(LIBSIMPLEOT) Machines/TripleMachine.o -ot-bitmatrix.x: $(OT) $(COMMON) OT/BitMatrixTest.cpp - $(CXX) $(CFLAGS) -o ot-bitmatrix.x OT/BitMatrixTest.cpp OT/BitMatrix.o $(COMMON) $(LDLIBS) - -ot-offline.x: $(OT) $(COMMON) ot-offline.cpp $(LIBSIMPLEOT) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - -check-passive.x: $(COMMON) check-passive.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - -gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON) - $(CXX) $(CFLAGS) Scripts/gen_input_f2n.cpp -o gen_input_f2n.x $(COMMON) $(LDLIBS) - -gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON) - $(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS) - -gc-emulate.x: $(GC) $(COMMON) $(PROCESSOR) gc-emulate.cpp $(GC) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +gc-emulate.x: $(PROCESSOR) GC/FakeSecret.o GC/square64.o bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS) @@ -155,48 +125,41 @@ bmr-%.x: $(BMR) Machines/bmr-%.cpp $(LIBSIMPLEOT) bmr-clean: -rm BMR/*.o BMR/*/*.o GC/*.o -client-setup.x: client-setup.cpp $(COMMON) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -ifeq ($(USE_NTL),1) -simple-offline.x: $(COMMON) $(FHEOFFLINE) simple-offline.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - -pairwise-offline.x: $(COMMON) $(FHEOFFLINE) pairwise-offline.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - -cnc-offline.x: $(COMMON) $(FHEOFFLINE) cnc-offline.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) - -spdz2-offline.x: $(COMMON) $(FHEOFFLINE) spdz2-offline.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -endif +simple-offline.x: $(FHEOFFLINE) +pairwise-offline.x: $(FHEOFFLINE) +cnc-offline.x: $(FHEOFFLINE) +spdz2-offline.x: $(FHEOFFLINE) yao-party.x: $(YAO) yao-clean: -rm Yao/*.o -galois-degree.x: galois-degree.cpp +galois-degree.x: Utils/galois-degree.cpp $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) -default-prime-length.x: default-prime-length.cpp +default-prime-length.x: Utils/default-prime-length.cpp $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +%.x: Utils/%.o $(COMMON) + $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) + %.x: Machines/%.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) %-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM) $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) $(ECLIB) -replicated-bin-party.x: $(GC) -malicious-rep-bin-party.x: $(GC) +replicated-bin-party.x: GC/square64.o +malicious-rep-bin-party.x: GC/square64.o +semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o +tiny-party.x: $(OT) shamir-party.x: Machines/ShamirMachine.o malicious-shamir-party.x: Machines/ShamirMachine.o spdz2k-party.x: $(OT) diff --git a/Math/BitVec.h b/Math/BitVec.h index 1bfa8303..a0f5d3e9 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -8,12 +8,18 @@ #include "Integer.h" #include "field_types.h" +#include "Square.h" + +class BitDiagonal; class BitVec : public IntBase { public: typedef BitVec Scalar; + typedef BitVec next; + typedef BitDiagonal Square; + static const int n_bits = sizeof(a) * 8; static char type_char() { return 'B'; } @@ -32,10 +38,14 @@ public: BitVec operator/(const BitVec& other) const { (void) other; throw not_implemented(); } BitVec& operator+=(const BitVec& other) { *this ^= other; return *this; } + BitVec& operator-=(const BitVec& other) { *this ^= other; return *this; } BitVec extend_bit() const { return -(a & 1); } BitVec mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; } + template + void add(octetStream& os) { *this += os.get(); } + void mul(const BitVec& a, const BitVec& b) { *this = a * b; } void randomize(PRNG& G, int n = n_bits) { IntBase::randomize(G); *this = mask(n); } diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 28d8eb90..9a114d92 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -7,6 +7,7 @@ #define MATH_FIXEDVEC_H_ #include +#include using namespace std; #include "Tools/octetStream.h" @@ -21,7 +22,7 @@ template class Replicated; template class FixedVec { - T v[L]; + array v; public: typedef T value_type; @@ -71,6 +72,16 @@ public: v[i] = other[i]; } + FixedVec(const array& other) + { + v = other; + } + + const array& get() const + { + return v; + } + T& operator[](int i) { return v[i]; diff --git a/Math/Integer.h b/Math/Integer.h index ed97889f..8f8e175f 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -24,9 +24,12 @@ protected: long a; public: + static const int N_BYTES = sizeof(a); static const int N_BITS = 8 * sizeof(a); + static const int MAX_N_BITS = N_BITS; static int size() { return sizeof(a); } + static int length() { return N_BITS; } static string type_string() { return "integer"; } static void init_default(int lgp) { (void)lgp; } @@ -39,10 +42,12 @@ public: long get() const { return a; } bool get_bit(int i) const { return (a >> i) & 1; } + char* get_ptr() const { return (char*)&a; } + unsigned long debug() const { return a; } void assign(long x) { *this = x; } - void assign(const char* buffer) { avx_memcpy(&a, buffer, sizeof(a)); } + void assign(const void* buffer) { avx_memcpy(&a, buffer, sizeof(a)); } void assign_zero() { a = 0; } void assign_one() { a = 1; } @@ -50,8 +55,20 @@ public: bool is_one() const { return a == 1; } bool is_bit() const { return is_zero() or is_one(); } - long operator>>(const IntBase& other) const { return a >> other.a; } - long operator<<(const IntBase& other) const { return a << other.a; } + long operator>>(const IntBase& other) const + { + if (other.a < N_BITS) + return (unsigned long) a >> other.a; + else + return 0; + } + long operator<<(const IntBase& other) const + { + if (other.a < N_BITS) + return a << other.a; + else + return 0; + } long operator^(const IntBase& other) const { return a ^ other.a; } long operator&(const IntBase& other) const { return a & other.a; } diff --git a/Math/Setup.cpp b/Math/Setup.cpp index ea9d1523..e103353d 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -121,7 +121,7 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2 if (mkdir_p(ss.str().c_str()) == -1) { cerr << "mkdir_p(" << ss.str() << ") failed\n"; - throw file_error(); + throw file_error(ss.str()); } // Output the data diff --git a/Math/Square.cpp b/Math/Square.cpp index 31c760df..b7f3ece8 100644 --- a/Math/Square.cpp +++ b/Math/Square.cpp @@ -4,6 +4,7 @@ */ #include "Square.h" +#include "BitVec.h" template<> void Square::to(gf2n_short& result) @@ -34,3 +35,11 @@ void Square::to(gfp1& result) mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L); result.assign((void*) ans); } + +template<> +void Square::to(BitVec& result) +{ + result = 0; + for (int i = 0; i < N_ROWS; i++) + result ^= ((rows[i] >> i) & 1) << i; +} diff --git a/Math/Square.h b/Math/Square.h index 2fd144ce..b33d8134 100644 --- a/Math/Square.h +++ b/Math/Square.h @@ -12,6 +12,8 @@ template class Square { public: + typedef U RowType; + static const int N_ROWS = U::MAX_N_BITS; static const int N_ROWS_ALLOCATED = N_ROWS; static const int N_COLUMNS = N_ROWS; @@ -21,16 +23,11 @@ public: U rows[N_ROWS]; - template Square& sub(const Square& other); - template Square& rsub(const Square& other); - template Square& sub(const void* other); - template void randomize(int row, PRNG& G) { rows[row].randomize(G); } - template void conditional_add(BitVector& conditions, Square& other, int offset); void to(U& result); diff --git a/Math/Square.hpp b/Math/Square.hpp index 8e4fdc13..e02ee376 100644 --- a/Math/Square.hpp +++ b/Math/Square.hpp @@ -6,7 +6,6 @@ #include "Math/Square.h" template -template Square& Square::sub(const Square& other) { for (int i = 0; i < U::length(); i++) @@ -15,7 +14,6 @@ Square& Square::sub(const Square& other) } template -template Square& Square::rsub(const Square& other) { for (int i = 0; i < U::length(); i++) @@ -24,7 +22,6 @@ Square& Square::rsub(const Square& other) } template -template Square& Square::sub(const void* other) { U value; @@ -35,7 +32,6 @@ Square& Square::sub(const void* other) } template -template void Square::conditional_add(BitVector& conditions, Square& other, int offset) { diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index 2ab069ef..e931b209 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -16,6 +16,8 @@ public: static void init_default(int l) { (void) l; } static void read_setup(int nparties, int lg2p, int gf2ndegree); + + void normalize() {} }; #endif /* MATH_VALUEINTERFACE_H_ */ diff --git a/Math/Z2k.h b/Math/Z2k.h index 61d4229b..05530d8d 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -109,6 +109,7 @@ public: Z2 operator*(const Z2& other) const; Z2 operator*(bool other) const { return other ? *this : Z2(); } + Z2 operator*(int other) const { return *this * Z2(other); } Z2 operator/(const Z2& other) const { (void) other; throw not_implemented(); } diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index ed98b9e0..3610ada9 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -352,7 +352,7 @@ void gf2n_short::input(istream& s,bool human) if (s.peek() == EOF) { if (s.tellg() == 0) { cout << "IO problem. Empty file?" << endl; - throw file_error(); + throw file_error("gf2n_short input"); } throw end_of_file("gf2n_short"); } diff --git a/Math/gf2n.h b/Math/gf2n.h index cd157d82..05e414dd 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -64,6 +64,7 @@ class gf2n_short typedef gf2n_short Scalar; static const int MAX_N_BITS = 64; + static const int N_BYTES = sizeof(a); static void init_field(int nn); static int degree() { return n; } diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp index f16d2576..7d401b5a 100644 --- a/Math/gf2nlong.cpp +++ b/Math/gf2nlong.cpp @@ -257,7 +257,7 @@ void gf2n_long::input(istream& s,bool human) if (s.peek() == EOF) { if (s.tellg() == 0) { cout << "IO problem. Empty file?" << endl; - throw file_error(); + throw file_error("gf2n_long input"); } throw end_of_file("gf2n_long"); } diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 84131feb..22bb2f24 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -100,6 +100,7 @@ class gf2n_long typedef ::Square Square; const static int MAX_N_BITS = 128; + const static int N_BYTES = sizeof(a); typedef gf2n_long Scalar; diff --git a/Math/gfp.h b/Math/gfp.h index ca6f2645..de79d04f 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -56,6 +56,7 @@ class gfp_ static const int N_LIMBS = L; static const int MAX_N_BITS = 64 * L; + static const int N_BYTES = sizeof(a); template static void init(bool mont = true) diff --git a/Math/modp.hpp b/Math/modp.hpp index 4da61678..4f69eeda 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -256,7 +256,7 @@ void modp_::input(istream& s,const Zp_Data& ZpD,bool human) if (s.peek() == EOF) { if (s.tellg() == 0) { cout << "IO problem. Empty file?" << endl; - throw file_error(); + throw file_error("modp input"); } throw end_of_file("modp"); } diff --git a/Math/operators.h b/Math/operators.h index 3af84e76..b3714cd3 100644 --- a/Math/operators.h +++ b/Math/operators.h @@ -6,8 +6,8 @@ #ifndef MATH_OPERATORS_H_ #define MATH_OPERATORS_H_ -template -T operator*(const bool& x, const T& y) { return x ? y : T(); } +//template +//T operator*(const bool& x, const T& y) { return x ? y : T(); } //template //T operator*(const T& y, const bool& x) { return x ? y : T(); } template diff --git a/Networking/Player.cpp b/Networking/Player.cpp index d73bd823..2f15d75d 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -56,10 +56,23 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante nplayers = 0; portnum_base = pnb; string line; + ports.clear(); while (getline(hostsfile, line)) { if (line.length() > 0 && line.at(0) != '#') { - names.push_back(line); + auto pos = line.find(':'); + if (pos == string::npos) + { + names.push_back(line); + ports.push_back(default_port(nplayers)); + } + else + { + names.push_back(line.substr(0, pos)); + int port; + stringstream(line.substr(pos + 1)) >> port; + ports.push_back(port); + } nplayers++; if (nplayers_wanted > 0 and nplayers_wanted == nplayers) break; @@ -67,29 +80,18 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante } if (nplayers_wanted > 0 and nplayers_wanted != nplayers) throw runtime_error("not enought hosts in HOSTS"); - setup_ports(); #ifdef DEBUG_NETWORKING cerr << "Got list of " << nplayers << " players from file: " << endl; for (unsigned int i = 0; i < names.size(); i++) - cerr << " " << names[i] << endl; + cerr << " " << names[i] << ":" << ports[i] << endl; #endif setup_server(); } Names::Names(ez::ezOptionParser& opt, int argc, const char** argv, - int default_nplayers) : - Names() + int default_nplayers) : Names() { - NetworkOptions network_opts(opt, argc, argv); - opt.add( - to_string(default_nplayers).c_str(), // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Number of players", // Help description. - "-N", // Flag token. - "--nparties" // Flag token. - ); + NetworkOptionsWithNumber network_opts(opt, argc, argv, default_nplayers, true); opt.add( "", // Default. 1, // Required? @@ -101,9 +103,7 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv, ); opt.parse(argc, argv); opt.get("-p")->getInt(player_no); - opt.get("-N")->getInt(nplayers); - global_server = Server::start_networking(*this, player_no, nplayers, - network_opts.hostname, network_opts.portnum_base); + global_server = network_opts.start_networking(*this, player_no); } void Names::setup_ports() @@ -396,6 +396,9 @@ void MultiPlayer::exchange_no_stats(int other, const octetStream& o, octetStr void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const { +#ifdef VERBOSE_COMM + cerr << "Exchanging with " << other << endl; +#endif TimeScope ts(comm_stats["Exchanging"].add(o)); exchange_no_stats(other, o, to_receive); sent += o.get_length(); @@ -605,34 +608,34 @@ int RealTwoPartyPlayer::other_player_num() const return other_player; } -void RealTwoPartyPlayer::send(octetStream& o) +void RealTwoPartyPlayer::send(octetStream& o) const { TimeScope ts(comm_stats["Sending one-to-one"].add(o)); o.Send(socket); sent += o.get_length(); } -void VirtualTwoPartyPlayer::send(octetStream& o) +void VirtualTwoPartyPlayer::send(octetStream& o) const { TimeScope ts(comm_stats["Sending one-to-one"].add(o)); P.send_to_no_stats(other_player, o); sent += o.get_length(); } -void RealTwoPartyPlayer::receive(octetStream& o) +void RealTwoPartyPlayer::receive(octetStream& o) const { TimeScope ts(timer); o.reset_write_head(); o.Receive(socket); } -void VirtualTwoPartyPlayer::receive(octetStream& o) +void VirtualTwoPartyPlayer::receive(octetStream& o) const { TimeScope ts(timer); P.receive_player_no_stats(other_player, o); } -void RealTwoPartyPlayer::send_receive_player(vector& o) +void RealTwoPartyPlayer::send_receive_player(vector& o) const { { if (is_server) @@ -655,7 +658,7 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const o.exchange(socket, socket); } -void VirtualTwoPartyPlayer::send_receive_player(vector& o) +void VirtualTwoPartyPlayer::send_receive_player(vector& o) const { TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0])); sent += o[0].get_length(); @@ -667,11 +670,21 @@ VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) : { } -void OffsetPlayer::send_receive_player(vector& o) +void OffsetPlayer::send_receive_player(vector& o) const { P.exchange(P.get_player(offset), o[0], o[1]); } +void TwoPartyPlayer::Broadcast_Receive(vector& o, + bool donthash) const +{ + (void) donthash; + vector os(2); + os[0] = o[my_num()]; + send_receive_player(os); + o[1 - my_num()] = os[1]; +} + CommStats& CommStats::operator +=(const CommStats& other) { data += other.data; diff --git a/Networking/Player.h b/Networking/Player.h index cfadfef3..d07fd92c 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -126,9 +126,11 @@ public: virtual ~PlayerBase(); int my_real_num() const { return player_no; } + virtual int my_num() const = 0; virtual int num_players() const = 0; virtual void pass_around(octetStream& o, int offset = 1) const = 0; + virtual void Broadcast_Receive(vector& o,bool donthash=false) const = 0; }; class Player : public PlayerBase @@ -276,9 +278,10 @@ public: virtual int my_num() const = 0; virtual int other_player_num() const = 0; - virtual void send(octetStream& o) = 0; - virtual void receive(octetStream& o) = 0; - virtual void send_receive_player(vector& o) = 0; + virtual void send(octetStream& o) const = 0; + virtual void receive(octetStream& o) const = 0; + virtual void send_receive_player(vector& o) const = 0; + void Broadcast_Receive(vector& o, bool donthash=false) const; }; class RealTwoPartyPlayer : public TwoPartyPlayer @@ -295,8 +298,8 @@ public: RealTwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0); ~RealTwoPartyPlayer(); - void send(octetStream& o); - void receive(octetStream& o); + void send(octetStream& o) const; + void receive(octetStream& o) const; int other_player_num() const; int my_num() const { return is_server; } @@ -305,7 +308,7 @@ public: /* Send and receive to/from the other player * - o[0] contains my data, received data put in o[1] */ - void send_receive_player(vector& o); + void send_receive_player(vector& o) const; void exchange(octetStream& o) const; void exchange(int other, octetStream& o) const { (void)other; exchange(o); } @@ -326,9 +329,9 @@ public: int other_player_num() const { return other_player; } int num_players() const { return 2; } - void send(octetStream& o); - void receive(octetStream& o); - void send_receive_player(vector& o); + void send(octetStream& o) const; + void receive(octetStream& o) const; + void send_receive_player(vector& o) const; void pass_around(octetStream& o, int _ = 1) const { (void)_, (void) o; throw not_implemented(); } }; @@ -349,16 +352,15 @@ public: int num_players() const { return 2; } int get_offset() const { return offset; } - void send(octetStream& o) { P.send_to(P.get_player(offset), o, true); } - void reverse_send(octetStream& o) { P.send_to(P.get_player(-offset), o, true); } - void receive(octetStream& o) { P.receive_player(P.get_player(offset), o, true); } + void send(octetStream& o) const { P.send_to(P.get_player(offset), o, true); } + void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o, true); } + void receive(octetStream& o) const { P.receive_player(P.get_player(offset), o, true); } void reverse_receive(octetStream& o) { P.receive_player(P.get_player(-offset), o, true); } - void send_receive_player(vector& o); + void send_receive_player(vector& o) const; void reverse_exchange(octetStream& o) const { P.pass_around(o, P.num_players() - offset); } - void exchange(int other, octetStream& o) const { (void)other; P.pass_around(o, offset); } + void exchange(octetStream& o) const { P.exchange(P.get_player(offset), o); } void pass_around(octetStream& o, int _ = 1) const { (void)_; P.pass_around(o, offset); } - void Broadcast_Receive(vector& o,bool donthash=false) const; }; #endif diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 49b42366..e93f35cb 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -106,6 +106,8 @@ void BaseOT::exec_base(bool new_receiver_inputs) receiver_maketable(&receiver); } + os[0].reset_write_head(); + for (i = 0; i < nOT; i += 4) { if (ot_role & RECEIVER) @@ -117,12 +119,24 @@ void BaseOT::exec_base(bool new_receiver_inputs) cs[j] = receiver_inputs[i + j]; } receiver_rsgen(&receiver, Rs_pack[0], cs); - os[0].reset_write_head(); os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0])); receiver_keygen(&receiver, receiver_keys); + + // Copy keys to receiver_outputs + for (j = 0; j < 4; j++) + { + for (k = 0; k < AES_BLK_SIZE; k++) + { + receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); + } + } } - send_if_ot_receiver(P, os, ot_role); + } + + send_if_ot_receiver(P, os, ot_role); + for (i = 0; i < nOT; i += 4) + { if (ot_role & SENDER) { os[1].get_bytes((octet*) Rs_pack[1], len); @@ -143,18 +157,6 @@ void BaseOT::exec_base(bool new_receiver_inputs) } } } - - if (ot_role & RECEIVER) - { - // Copy keys to receiver_outputs - for (j = 0; j < 4; j++) - { - for (k = 0; k < AES_BLK_SIZE; k++) - { - receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); - } - } - } #ifdef BASE_OT_DEBUG for (j = 0; j < 4; j++) { diff --git a/OT/BitDiagonal.cpp b/OT/BitDiagonal.cpp new file mode 100644 index 00000000..e6c4293a --- /dev/null +++ b/OT/BitDiagonal.cpp @@ -0,0 +1,19 @@ +/* + * Diagonal.cpp + * + */ + +#include + +void BitDiagonal::pack(octetStream& os) const +{ + for (int i = 0; i < N_ROWS; i++) + os.store_int(rows[i].get_bit(i), 1); +} + +void BitDiagonal::unpack(octetStream& os) +{ + *this = {}; + for (int i = 0; i < N_ROWS; i++) + rows[i] = os.get_int(1) << i; +} diff --git a/OT/BitDiagonal.h b/OT/BitDiagonal.h new file mode 100644 index 00000000..92c0889d --- /dev/null +++ b/OT/BitDiagonal.h @@ -0,0 +1,24 @@ +/* + * Diagonal.h + * + */ + +#ifndef OT_BITDIAGONAL_H_ +#define OT_BITDIAGONAL_H_ + +#include "Math/Square.h" +#include "Math/BitVec.h" + +class BitDiagonal : public Square +{ +public: + static int size() + { + return 8 * BitVec::size(); + } + + void pack(octetStream& os) const; + void unpack(octetStream& os); +}; + +#endif /* OT_BITDIAGONAL_H_ */ diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 4ba949c1..7d78105c 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -9,9 +9,12 @@ #include "BitMatrix.h" #include "Rectangle.h" +#include "BitDiagonal.h" #include "Math/gf2n.h" #include "Math/gfp.h" #include "Math/Z2k.h" +#include "Math/BitVec.h" +#include "GC/TinySecret.h" #include "OT/Rectangle.hpp" #include "Math/Z2k.hpp" @@ -268,25 +271,22 @@ void square128::randomize(PRNG& G) G.get_octets((octet*)&rows, sizeof(rows)); } -template <> -void square128::randomize(int row, PRNG& G) +void square128::randomize(int row, PRNG& G) { rows[row] = G.get_doubleword(); } -template<> -void square128::conditional_add(BitVector& conditions, square128& other, int offset) +void square128::conditional_add(BitVector& conditions, square128& other, int offset) { for (int i = 0; i < 128; i++) if (conditions.get_bit(128 * offset + i)) rows[i] ^= other.rows[i]; } -template void square128::hash_row_wise(MMO& mmo, square128& input) { - mmo.hashBlockWise((octet*)rows, (octet*)input.rows); + mmo.hashBlockWise((octet*)rows, (octet*)input.rows); } template <> @@ -395,20 +395,17 @@ square128& square128::operator^=(square128& other) return *this; } -template<> -square128& square128::add(square128& other) +square128& square128::add(square128& other) { return *this ^= other; } -template<> -square128& square128::sub(square128& other) +square128& square128::sub(square128& other) { return *this ^= other; } -template<> -square128& square128::rsub(square128& other) +square128& square128::rsub(square128& other) { return *this ^= other; } @@ -421,8 +418,7 @@ square128& square128::operator^=(const __m128i* other) return *this; } -template <> -square128& square128::sub(const __m128i* other) +square128& square128::sub(const __m128i* other) { return *this ^= other; } @@ -500,7 +496,7 @@ template void Matrix::randomize(int row, PRNG& G) { for (size_t i = 0; i < squares.size(); i++) - squares[i].template randomize(row, G); + squares[i].randomize(row, G); } void BitMatrix::transpose() @@ -597,44 +593,40 @@ Slice::Slice(U& bm, size_t start, size_t size) : } template -template Slice& Slice::rsub(Slice& other) { if (bm.squares.size() < other.end) throw invalid_length(); for (size_t i = other.start; i < other.end; i++) - bm.squares[i].template rsub(other.bm.squares[i]); + bm.squares[i].rsub(other.bm.squares[i]); return *this; } template -template Slice& Slice::sub(BitVector& other, int repeat) { if (end * U::PartType::N_COLUMNS > other.size() * repeat) throw invalid_length(to_string(U::PartType::N_COLUMNS)); for (size_t i = start; i < end; i++) { - bm.squares[i].template sub(other.get_ptr_to_byte(i / repeat, + bm.squares[i].sub(other.get_ptr_to_byte(i / repeat, U::PartType::N_ROW_BYTES)); } return *this; } template -template void Slice::randomize(int row, PRNG& G) { for (size_t i = start; i < end; i++) - bm.squares[i].template randomize(row, G); + bm.squares[i].randomize(row, G); } template -template void Slice::conditional_add(BitVector& conditions, U& other, bool useOffset) { for (size_t i = start; i < end; i++) - bm.squares[i].template conditional_add(conditions, other.squares[i], useOffset * i); + bm.squares[i].conditional_add(conditions, other.squares[i], useOffset * i); } template <> @@ -651,7 +643,7 @@ void Slice::print() cout << "hex / value" << endl; for (int i = 0; i < 16; i++) { - cout << int128(bm.squares[0].rows[i]) << " " << T(bm.squares[0].rows[i]) << endl; + cout << T(bm.squares[0].rows[i]) << endl; } cout << endl; } @@ -671,24 +663,13 @@ void Slice::unpack(octetStream& os) bm.squares[i].unpack(os); } -#define M(N,L) Matrix, Z2 > > - #undef XXX #define XXX(T,N,L) \ template class Matrix, Z2 > >; \ template class Slice, Z2 > > >; \ -template Slice, Z2 > > >& Slice< \ - Matrix, Z2 > > >::rsub( \ - Slice, Z2 > > >& other); \ -template Slice, Z2 > > >& Slice< \ - Matrix, Z2 > > >::sub(BitVector& other, int repeat); \ -template void Slice, Z2 > > >::conditional_add< \ - T>(BitVector& conditions, \ - Matrix, Z2 > >& other, bool useOffset); \ #undef X #define X(N,L) \ -template void Slice, Z2 > > >::randomize >(int row, PRNG& G); \ XXX(Z2, N, L) //X(96, 160) @@ -700,6 +681,11 @@ Y(64, 48) Y(66, 64) Y(66, 48) Y(32, 32) +Y(1, 40) +Y(72, 48) +Y(74, 48) +Y(72, 64) +Y(74, 64) template class Matrix; @@ -710,19 +696,15 @@ template class Slice; \ XX(BM, gf2n_long) #define XX(BM, GF) \ -template void Slice::conditional_add(BitVector& conditions, BM& other, bool useOffset); \ -template Slice& Slice::rsub(Slice& other); \ -template Slice& Slice::sub(BitVector& other, int repeat); \ -template void Slice::randomize(int row, PRNG& G); \ //template void Slice::print(); BMS -template class Slice>; -XX(Matrix, gf2n_short) +#define XXXX(BM, GF) \ + template class Slice; \ + XX(BM, GF) -template class Slice>>; -XX(Matrix>, gf2n_long) - -template class Slice>>; -XX(Matrix>, gfp1) +XXXX(Matrix, gf2n_short) +XXXX(Matrix>, gf2n_long) +XXXX(Matrix>, gfp1) +XXXX(Matrix, BitVec) diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 6f4c7eda..7561b32d 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -19,7 +19,7 @@ using namespace std; union square128 { - typedef int128 RowType; + typedef gf2n_long RowType; const static int N_ROWS = 128; const static int N_ROWS_ALLOCATED = 128; @@ -46,24 +46,16 @@ union square128 { square128& operator^=(BitVector& other); bool operator==(square128& other); - template square128& add(square128& other); - template square128& sub(square128& other); - template square128& rsub(square128& other); - template square128& sub(const __m128i* other); - template - square128& sub(const void* other) { return sub((__m128i*)other); } + square128& sub(const void* other) { return sub((__m128i*)other); } void randomize(PRNG& G); - template void randomize(int row, PRNG& G); - template void conditional_add(BitVector& conditions, square128& other, int offset); void transpose(); - template void hash_row_wise(MMO& mmo, square128& input); template void to(T& result); @@ -173,14 +165,10 @@ class Slice public: Slice(U& bm, size_t start, size_t size); - template Slice& rsub(Slice& other); - template Slice& sub(BitVector& other, int repeat = 1); - template void randomize(int row, PRNG& G); - template void conditional_add(BitVector& conditions, U& other, bool useOffset = false); void transpose(); diff --git a/OT/MascotParams.cpp b/OT/MascotParams.cpp new file mode 100644 index 00000000..c6375fbe --- /dev/null +++ b/OT/MascotParams.cpp @@ -0,0 +1,106 @@ +/* + * TripleMachine.cpp + * + */ + +#include +#include "OT/NPartyTripleGenerator.h" +#include "OT/OTTripleSetup.h" +#include "Math/gf2n.h" +#include "Math/Setup.h" +#include "Protocols/Spdz2kShare.h" +#include "Tools/ezOptionParser.h" +#include "Math/Setup.h" +#include "Protocols/fake-stuff.h" +#include "Math/BitVec.h" + +#include "Protocols/fake-stuff.hpp" +#include "Math/Z2k.hpp" + +#include +#include +using namespace std; + +MascotParams::MascotParams() +{ + generateMACs = true; + amplify = true; + check = true; + generateBits = false; + timerclear(&start); +} + +void MascotParams::set_passive() +{ + generateMACs = amplify = check = false; +} + +template<> gf2n_long MascotParams::get_mac_key() +{ + return mac_key2l; +} + +template<> gf2n_short MascotParams::get_mac_key() +{ + return mac_key2s; +} + +template<> gfp1 MascotParams::get_mac_key() +{ + return mac_keyp; +} + +template<> Z2<48> MascotParams::get_mac_key() +{ + return mac_keyz; +} + +template<> Z2<64> MascotParams::get_mac_key() +{ + return mac_keyz; +} + +template<> Z2<40> MascotParams::get_mac_key() +{ + return mac_keyz; +} + +template<> Z2<32> MascotParams::get_mac_key() +{ + return mac_keyz; +} + +template<> BitVec MascotParams::get_mac_key() +{ + return 0; +} + +template<> void MascotParams::set_mac_key(gf2n_long key) +{ + mac_key2l = key; +} + +template<> void MascotParams::set_mac_key(gf2n_short key) +{ + mac_key2s = key; +} + +template<> void MascotParams::set_mac_key(gfp1 key) +{ + mac_keyp = key; +} + +template<> void MascotParams::set_mac_key(Z2<64> key) +{ + mac_keyz = key; +} + +template<> void MascotParams::set_mac_key(Z2<48> key) +{ + mac_keyz = key; +} + +template<> void MascotParams::set_mac_key(Z2<40> key) +{ + mac_keyz = key; +} diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 304cfcfb..7a89aa66 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -24,7 +24,7 @@ class PlainTriple; template using ShareTriple = ShareTriple_; -class MascotGenerator +class GeneratorThread { protected: pthread_mutex_t mutex; @@ -37,8 +37,8 @@ public: bool multi_threaded; - MascotGenerator() : nTriples(0), multi_threaded(true) {} - virtual ~MascotGenerator() {}; + GeneratorThread() : nTriples(0), multi_threaded(true) {} + virtual ~GeneratorThread() {}; virtual void generate() = 0; void lock(); @@ -48,7 +48,7 @@ public: }; template -class OTTripleGenerator : public MascotGenerator +class OTTripleGenerator : public GeneratorThread { typedef typename T::open_type open_type; typedef typename T::mac_key_type mac_key_type; @@ -79,7 +79,7 @@ protected: public: // TwoPartyPlayer's for OTs, n-party Player for sacrificing vector players; - vector*> ot_multipliers; + vector ot_multipliers; //vector machines; BitVector baseReceiverInput; // same for every set of OTs vector< vector< vector > > baseSenderInputs; @@ -111,6 +111,8 @@ public: void generatePlainTriples(); void plainTripleRound(int k = 0); + void run_multipliers(MultJob job); + size_t data_sent(); }; @@ -121,8 +123,28 @@ class NPartyTripleGenerator : public OTTripleGenerator typedef typename T::mac_key_type mac_key_type; typedef typename T::sacri_type sacri_type; - template - void generateTriplesZ2k(); + virtual void generateTriples() = 0; + virtual void generateBits() = 0; + +public: + vector< ShareTriple_ > uncheckedTriples; + vector>> inputs; + + NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, + int thread_num, int nTriples, int nloops, MascotParams& machine, + Player* parentPlayer = 0); + virtual ~NPartyTripleGenerator() {} + + void generate(); + void generateInputs(int player); +}; + +template +class MascotTripleGenerator : public NPartyTripleGenerator +{ + typedef typename T::open_type open_type; + typedef typename T::mac_key_type mac_key_type; + typedef typename T::sacri_type sacri_type; void generateTriples(); void generateBits(); @@ -132,21 +154,37 @@ class NPartyTripleGenerator : public OTTripleGenerator void sacrifice(vector >& uncheckedTriples, typename T::MAC_Check& MC, PRNG& G); + +public: + vector bits; + + MascotTripleGenerator(OTTripleSetup& setup, const Names& names, + int thread_num, int nTriples, int nloops, MascotParams& machine, + Player* parentPlayer = 0); +}; + +template +class Spdz2kTripleGenerator : public NPartyTripleGenerator +{ + typedef typename T::open_type open_type; + typedef typename T::mac_key_type mac_key_type; + typedef typename T::sacri_type sacri_type; + + void generateBits() { throw not_implemented(); } + template - void sacrificeZ2k(vector >& uncheckedTriples, + void sacrificeZ2k( + vector< + ShareTriple_ >& uncheckedTriples, U& MC, PRNG& G); public: - vector< ShareTriple_ > uncheckedTriples; - vector bits; - vector>> inputs; - - NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, + Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, Player* parentPlayer = 0); - void generate(); - void generateInputs(int player); + void generateTriples(); }; template diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.hpp similarity index 90% rename from OT/NPartyTripleGenerator.cpp rename to OT/NPartyTripleGenerator.hpp index 5f0cf87e..7b0fafe9 100644 --- a/OT/NPartyTripleGenerator.cpp +++ b/OT/NPartyTripleGenerator.hpp @@ -1,3 +1,6 @@ +#ifndef OT_NPARTYTRIPLGENERATOR_HPP_ +#define OT_NPARTYTRIPLGENERATOR_HPP_ + #include "NPartyTripleGenerator.h" #include "OT/OTExtensionWithMatrix.h" @@ -11,9 +14,11 @@ #include "Tools/Subroutines.h" #include "Protocols/MAC_Check.h" #include "Protocols/Spdz2kPrep.h" +#include "GC/SemiSecret.h" #include "OT/Triple.hpp" #include "OT/Rectangle.hpp" +#include "OT/OTMultiplier.hpp" #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.h" #include "Protocols/MascotPrep.hpp" @@ -46,6 +51,24 @@ NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, { } +template +MascotTripleGenerator::MascotTripleGenerator(OTTripleSetup& setup, + const Names& names, int thread_num, int _nTriples, int nloops, + MascotParams& machine, Player* parentPlayer) : + NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, + machine, parentPlayer) +{ +} + +template +Spdz2kTripleGenerator::Spdz2kTripleGenerator(OTTripleSetup& setup, + const Names& names, int thread_num, int _nTriples, int nloops, + MascotParams& machine, Player* parentPlayer) : + NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, + machine, parentPlayer) +{ +} + template OTTripleGenerator::OTTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, @@ -174,9 +197,9 @@ void NPartyTripleGenerator::generate() timers["Generator thread"].stop(); if (machine.output) cout << "Written " << nTriples << " " << T::type_string() << " outputs to " << ss.str() << endl; -#ifdef VERBOSE +#ifdef VERBOSE_OT else - cout << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl; + cerr << "Generated " << nTriples << " " << T::type_string() << " outputs" << endl; #endif } @@ -251,7 +274,7 @@ void NPartyTripleGenerator::generateInputs(int player) } template<> -void NPartyTripleGenerator>::generateBits() +void MascotTripleGenerator>::generateBits() { for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push(DATA_BIT); @@ -288,9 +311,9 @@ void NPartyTripleGenerator>::generateBits() gf2n r; for (int j = 0; j < nBitsToCheck; j++) { - gf2n mac_sum = bool(valueBits[0].get_bit(j)) * machine.get_mac_key(); + gf2n mac_sum = valueBits[0].get_bit(j) ? machine.get_mac_key() : 0; for (int i = 0; i < nparties-1; i++) - mac_sum += ((MascotMultiplier*)ot_multipliers[i])->macs[0][j]; + mac_sum += ot_multipliers[i]->macs[0][j]; bits[j].set_share(valueBits[0].get_bit(j)); bits[j].set_mac(mac_sum); r.randomize(G); @@ -310,7 +333,7 @@ void NPartyTripleGenerator>::generateBits() } template<> -void NPartyTripleGenerator>::generateBits() +void MascotTripleGenerator>::generateBits() { generateTriples(); } @@ -322,9 +345,12 @@ void NPartyTripleGenerator::generateBits() } template -template -void NPartyTripleGenerator::generateTriplesZ2k() +void Spdz2kTripleGenerator::generateTriples() { + const int K = T::k; + const int S = T::s; + auto& uncheckedTriples = this->uncheckedTriples; + auto& timers = this->timers; auto& machine = this->machine; auto& nTriplesPerLoop = this->nTriplesPerLoop; @@ -386,7 +412,7 @@ void NPartyTripleGenerator::generateTriplesZ2k() timers["Triple computation"].start(); for (int i = 0; i < nparties-1; i++) { - c += ((Spdz2kMultiplier*)ot_multipliers[i])->c_output[j]; + c += ot_multipliers[i]->c_output[j]; } #ifdef DEBUG_SPDZ2K @@ -433,36 +459,6 @@ void NPartyTripleGenerator::generateTriplesZ2k() } } -template<> -void NPartyTripleGenerator>::generateTriples() -{ - this->generateTriplesZ2k<32, 32>(); -} - -template<> -void NPartyTripleGenerator>::generateTriples() -{ - this->generateTriplesZ2k<64, 64>(); -} - -template<> -void NPartyTripleGenerator>::generateTriples() -{ - this->generateTriplesZ2k<64, 48>(); -} - -template<> -void NPartyTripleGenerator>::generateTriples() -{ - this->generateTriplesZ2k<66, 64>(); -} - -template<> -void NPartyTripleGenerator>::generateTriples() -{ - this->generateTriplesZ2k<66, 48>(); -} - template void OTTripleGenerator::generatePlainTriples() { @@ -500,13 +496,15 @@ void OTTripleGenerator::plainTripleRound(int k) for (int j = 0; j < nPreampTriplesPerLoop; j++) { - T a((char*)valueBits[0].get_ptr() + j * T::size()); - T b((char*)valueBits[1].get_ptr() + j / nAmplify * T::size()); + T a; + a.assign((char*)valueBits[0].get_ptr() + j * T::size()); + T b; + b.assign((char*)valueBits[1].get_ptr() + j / nAmplify * T::size()); T c = a * b; timers["Triple computation"].start(); for (int i = 0; i < nparties-1; i++) { - c += dynamic_cast(ot_multipliers[i])->c_output[j]; + c += ot_multipliers[i]->c_output[j]; } timers["Triple computation"].stop(); if (machine.amplify) @@ -531,7 +529,7 @@ void OTTripleGenerator::plainTripleRound(int k) } template -void NPartyTripleGenerator::generateTriples() +void MascotTripleGenerator::generateTriples() { typedef typename U::open_type T; @@ -547,6 +545,7 @@ void NPartyTripleGenerator::generateTriples() auto& outputFile = this->outputFile; auto& field_size = this->field_size; auto& nPreampTriplesPerLoop = this->nPreampTriplesPerLoop; + auto& uncheckedTriples = this->uncheckedTriples; for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push(DATA_TRIPLE); @@ -626,7 +625,7 @@ void NPartyTripleGenerator::generateTriples() } template -void NPartyTripleGenerator::sacrifice( +void MascotTripleGenerator::sacrifice( vector >& uncheckedTriples, typename T::MAC_Check& MC, PRNG& G) { auto& machine = this->machine; @@ -663,7 +662,7 @@ void NPartyTripleGenerator::sacrifice( template template -void NPartyTripleGenerator::sacrificeZ2k( +void Spdz2kTripleGenerator::sacrificeZ2k( vector >& uncheckedTriples, U& MC, PRNG& G) { typedef sacri_type T; @@ -707,7 +706,7 @@ void NPartyTripleGenerator::sacrificeZ2k( } if (machine.generateBits) - generateBitsFromTriples(uncheckedTriples, MC, outputFile); + throw not_implemented(); else if (machine.output) for (int j = 0; j < nTriplesPerLoop; j++) @@ -716,7 +715,7 @@ void NPartyTripleGenerator::sacrificeZ2k( template<> template -void NPartyTripleGenerator>::generateBitsFromTriples( +void MascotTripleGenerator>::generateBitsFromTriples( vector< ShareTriple_ >& triples, W& MC, ofstream& outputFile) { vector< Share > a_plus_b(nTriplesPerLoop), a_squared(nTriplesPerLoop); @@ -746,7 +745,7 @@ void NPartyTripleGenerator>::generateBitsFromTriples( template template -void NPartyTripleGenerator::generateBitsFromTriples( +void MascotTripleGenerator::generateBitsFromTriples( vector< ShareTriple_ >& triples, W& MC, ofstream& outputFile) { throw how_would_that_work(); @@ -786,22 +785,22 @@ void OTTripleGenerator::print_progress(int k) } } -void MascotGenerator::lock() +void GeneratorThread::lock() { pthread_mutex_lock(&mutex); } -void MascotGenerator::unlock() +void GeneratorThread::unlock() { pthread_mutex_unlock(&mutex); } -void MascotGenerator::signal() +void GeneratorThread::signal() { pthread_cond_signal(&ready); } -void MascotGenerator::wait() +void GeneratorThread::wait() { if (multi_threaded) pthread_cond_wait(&ready, &mutex); @@ -821,18 +820,11 @@ void OTTripleGenerator::wait_for_multipliers() ot_multipliers[i]->outbox.pop(); } +template +void OTTripleGenerator::run_multipliers(MultJob job) +{ + signal_multipliers(job); + wait_for_multipliers(); +} -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; - -template class OTTripleGenerator>; -template class OTTripleGenerator>; -template class OTTripleGenerator>; -template class OTTripleGenerator>; - -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; -template class NPartyTripleGenerator>; +#endif diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index b4fd9113..1d6b5bcd 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -8,6 +8,8 @@ #include "Math/gfp.h" #include "Math/Z2k.h" #include "Math/gf2nlong.h" +#include "Math/BitVec.h" +#include "GC/TinySecret.h" #include "OT/Rectangle.hpp" #include "Math/Z2k.hpp" @@ -71,7 +73,7 @@ void OTExtensionWithMatrix::transfer(int nOTs, for (int loop = 0; loop < nloops; loop++) { - extend(nOTs, newReceiverInput); + extend(nOTs, newReceiverInput); #ifdef OTEXT_TIMER gettimeofday(&totalendv, NULL); double elapsed = timeval_diff(&totalstartv, &totalendv); @@ -97,24 +99,25 @@ void OTCorrelator::resize(int nOTs) } // the template is used to denote the field of the hash output -template void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput) { extend_correlated(nOTs_requested, newReceiverInput); - hash_outputs(nOTs_requested); + hash_outputs(nOTs_requested); } -void OTExtensionWithMatrix::extend_correlated(BitVector& newReceiverInput) +void OTExtensionWithMatrix::extend_correlated(const BitVector& newReceiverInput) { extend_correlated(newReceiverInput.size(), newReceiverInput); } -void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& newReceiverInput) +void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, const BitVector& newReceiverBits) { // if (nOTs % nbaseOTs != 0) // throw invalid_length(); //"nOTs must be a multiple of nbaseOTs\n"); if (nOTs_requested == 0) return; + // local copy + auto newReceiverInput = newReceiverBits; if ((ot_role & RECEIVER) and (size_t)nOTs_requested != newReceiverInput.size()) throw runtime_error("wrong number of choice bits"); int nOTs_requested_rounded = (nOTs_requested + 127) / 128 * 128; @@ -133,8 +136,8 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new // subloop for first part to interleave communication with computation for (int start = 0; start < nOTs / 128; start += slice) { - expand(start, slice); - this->correlate(start, slice, newReceiverInput, true); + expand(start, slice); + this->correlate(start, slice, newReceiverInput, true); transpose(start, slice); } @@ -164,7 +167,6 @@ void OTExtensionWithMatrix::extend_correlated(int nOTs_requested, BitVector& new } template -template void OTCorrelator::expand(int start, int slice) { (void)start, (void)slice; @@ -180,8 +182,8 @@ void OTCorrelator::expand(int start, int slice) { for (int i = 0; i < nbaseOTs; i++) { - receiverOutputSlice.template randomize(i, G_sender[i][0]); - t1Slice.template randomize(i, G_sender[i][1]); + receiverOutputSlice.randomize(i, G_sender[i][0]); + t1Slice.randomize(i, G_sender[i][1]); } } @@ -189,23 +191,22 @@ void OTCorrelator::expand(int start, int slice) { for (int i = 0; i < nbaseOTs; i++) // randomize base receiver output - senderOutputSlices[0].template randomize(i, G_receiver[i]); + senderOutputSlices[0].randomize(i, G_receiver[i]); } } -template void OTExtensionWithMatrix::expand_transposed() { for (int i = 0; i < nbaseOTs; i++) { if (ot_role & RECEIVER) { - receiverOutputMatrix.squares[i/128].randomize(i % 128, G_sender[i][0]); - t1.squares[i/128].randomize(i % 128, G_sender[i][1]); + receiverOutputMatrix.squares[i/128].randomize(i % 128, G_sender[i][0]); + t1.squares[i/128].randomize(i % 128, G_sender[i][1]); } if (ot_role & SENDER) { - senderOutputMatrices[0].squares[i/128].randomize(i % 128, G_receiver[i]); + senderOutputMatrices[0].squares[i/128].randomize(i % 128, G_receiver[i]); } } } @@ -224,7 +225,6 @@ void OTCorrelator::setup_for_correlation(BitVector& baseReceiverInput, } template -template void OTCorrelator::correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat) { @@ -240,8 +240,8 @@ void OTCorrelator::correlate(int start, int slice, // create correlation if (ot_role & RECEIVER) { - t1Slice.template rsub(receiverOutputSlice); - t1Slice.template sub(newReceiverInput, repeat); + t1Slice.rsub(receiverOutputSlice); + t1Slice.sub(newReceiverInput, repeat); t1Slice.pack(os[0]); // t1 = receiverOutputMatrix; @@ -260,7 +260,7 @@ void OTCorrelator::correlate(int start, int slice, { // u = t0 + t1 + x uSlice.unpack(os[1]); - senderOutputSlices[0].template conditional_add(baseReceiverInput, u, !useConstantBase); + senderOutputSlices[0].conditional_add(baseReceiverInput, u, !useConstantBase); } #ifdef OTEXT_TIMER gettimeofday(&commst2, NULL); @@ -302,13 +302,12 @@ void OTExtensionWithMatrix::transpose(int start, int slice) /* * Hash outputs to make into random OT */ -template void OTExtensionWithMatrix::hash_outputs(int nOTs) { - hash_outputs(nOTs, senderOutputMatrices, receiverOutputMatrix); + hash_outputs(nOTs, senderOutputMatrices, receiverOutputMatrix); } -template +template void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput) { //cout << "Hashing... " << flush; @@ -319,6 +318,7 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r gettimeofday(&startv, NULL); #endif + typedef typename V::PartType::RowType T; int n_rows = V::PartType::N_ROWS_ALLOCATED; int n = (nOTs + n_rows - 1) / n_rows * V::PartType::N_ROWS; @@ -326,11 +326,6 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r senderOutput[i].resize_vertical(n); receiverOutput.resize_vertical(n); - if (V::PartType::N_ROW_BYTES != T::size()) - throw runtime_error( - "length mismatch for MMO hash: " - + to_string(V::PartType::N_ROW_BYTES) + " != " - + to_string(T::size())); if (nOTs % 8 != 0) throw runtime_error("number of OTs must be divisible by 8"); @@ -378,7 +373,7 @@ void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output) output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { - receiverOutputMatrix.squares[j].template sub(senderOutputMatrices[0].squares[j]).to(output[j]); + receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]); } } @@ -516,56 +511,36 @@ void OTExtensionWithMatrix::print_pre_expand() } template class OTCorrelator; -template void OTCorrelator::correlate(int start, int slice, - BitVector& newReceiverInput, bool useConstantBase, int repeat); #define Z(BM,GF) \ -template void OTCorrelator::correlate(int start, int slice, \ - BitVector& newReceiverInput, bool useConstantBase, int repeat); \ -template void OTCorrelator::expand(int start, int slice); \ +template class OTCorrelator; \ template void OTCorrelator::reduce_squares(unsigned int nTriples, \ vector& output); -template class OTCorrelator>; -Z(Matrix, gf2n_short) - -template class OTCorrelator>>; -Z(Matrix>, gf2n_long) - -template class OTCorrelator>>; -Z(Matrix>, gfp1) - #define ZZZZ(GF) \ template void OTExtensionWithMatrix::print_post_correlate( \ BitVector& newReceiverInput, int j, int offset, int sender); \ -template void OTExtensionWithMatrix::extend(int nOTs_requested, \ - BitVector& newReceiverInput); \ -#define ZZZ(GF, M) \ -template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&); -#define MM Matrix, Z2<160> > > +#define ZZZ(GF, M) Z(M, GF) \ +template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&); -ZZZZ(gfp1) ZZZZ(gf2n_long) -ZZZ(Z2<160>, MM) ZZZ(gf2n_short, Matrix) ZZZ(gf2n_long, Matrix>) ZZZ(gfp1, Matrix>) +ZZZ(BitVec, Matrix) #undef XX #define XX(T,U,N,L) \ template class OTCorrelator, Z2 > > >; \ -template void OTCorrelator, Z2 > > >::correlate(int start, int slice, \ - BitVector& newReceiverInput, bool useConstantBase, int repeat); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ vector& output); \ -template void OTExtensionWithMatrix::hash_outputs, Z2 > > >(int, \ +template void OTExtensionWithMatrix::hash_outputs(int, \ std::vector, Z2 > >, std::allocator, Z2 > > > >&, \ Matrix, Z2 > >&); #undef X #define X(N,L) \ -template void OTCorrelator, Z2 > > >::expand >(int start, int slice); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ vector >& output); \ XX(Z2,Z2,N,L) @@ -579,3 +554,8 @@ Y(64, 48) Y(66, 64) Y(66, 48) Y(32, 32) +Y(1, 40) +Y(72, 48) +Y(74, 48) +Y(72, 64) +Y(74, 64) diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index c76f6f4d..29af01f3 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -39,12 +39,10 @@ public: receiverOutputMatrix(matrices[0]), t1(matrices[1]) {} void resize(int nOTs); - template void expand(int start, int slice); void setup_for_correlation(BitVector& baseReceiverInput, vector& baseSenderOutputs, U& baseReceiverOutput); - template void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); template void reduce_squares(unsigned int nTriples, vector& output); @@ -76,14 +74,12 @@ public: void seed(vector& baseSenderInput, BitMatrix& baseReceiverOutput); void transfer(int nOTs, const BitVector& receiverInput); - template void extend(int nOTs, BitVector& newReceiverInput); - void extend_correlated(BitVector& newReceiverInput); - void extend_correlated(int nOTs, BitVector& newReceiverInput); + void extend_correlated(const BitVector& newReceiverInput); + void extend_correlated(int nOTs, const BitVector& newReceiverInput); void transpose(int start, int slice); - template void expand_transposed(); - template + template void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput); void print(BitVector& newReceiverInput, int i = 0); @@ -100,7 +96,6 @@ public: octet* get_sender_output(int choice, int i); protected: - template void hash_outputs(int nOTs); }; diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 139b1e4f..0c1a6e61 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -79,7 +79,7 @@ public: }; template -class MascotMultiplier : public OTMultiplier> +class MascotMultiplier : public OTMultiplier { OTCorrelator > auth_ot_ext; void after_correlation(); @@ -88,13 +88,32 @@ class MascotMultiplier : public OTMultiplier> const vector& baseReceiverOutput); public: - vector c_output; + vector c_output; - MascotMultiplier(OTTripleGenerator>& generator, int thread_num); + MascotMultiplier(OTTripleGenerator& generator, int thread_num); void multiplyForInputs(MultJob job); }; +template +class TinyMultiplier : public OTMultiplier +{ + OTVole mac_vole; + + void after_correlation(); + void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput); + +public: + vector c_output; + + TinyMultiplier(OTTripleGenerator& generator, int thread_num); + + void multiplyForInputs(MultJob job) { (void) job; throw not_implemented(); } +}; + template class Spdz2kShare; template diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.hpp similarity index 82% rename from OT/OTMultiplier.cpp rename to OT/OTMultiplier.hpp index cdcbe5db..33839f8d 100644 --- a/OT/OTMultiplier.cpp +++ b/OT/OTMultiplier.hpp @@ -9,6 +9,7 @@ #include "OT/NPartyTripleGenerator.h" #include "OT/Rectangle.h" #include "Math/Z2k.h" +#include "Math/BitVec.h" #include "Protocols/SemiShare.h" #include "Protocols/Semi2kShare.h" #include "Protocols/Spdz2kShare.h" @@ -37,14 +38,28 @@ OTMultiplier::OTMultiplier(OTTripleGenerator& generator, } template -MascotMultiplier::MascotMultiplier(OTTripleGenerator>& generator, +MascotMultiplier::MascotMultiplier(OTTripleGenerator& generator, int thread_num) : - OTMultiplier>(generator, thread_num), + OTMultiplier(generator, thread_num), auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true) { c_output.resize(generator.nTriplesPerLoop); } +template +TinyMultiplier::TinyMultiplier(OTTripleGenerator& generator, + int thread_num) : + OTMultiplier(generator, thread_num), + mac_vole( + 128, 128, 0, 1, + generator.players[thread_num], + { }, + { }, + { }, BOTH, false) +{ + c_output.resize(generator.nTriplesPerLoop); +} + template Spdz2kMultiplier::Spdz2kMultiplier(OTTripleGenerator>& generator, int thread_num) : OTMultiplier> @@ -75,7 +90,7 @@ template void OTMultiplier::multiply() { keyBits.set(generator.machine.template get_mac_key()); - rot_ext.extend(keyBits.size(), keyBits); + rot_ext.extend(keyBits.size(), keyBits); this->outbox.push({}); senderOutput.resize(keyBits.size()); for (size_t j = 0; j < keyBits.size(); j++) @@ -123,7 +138,6 @@ void OTMultiplier::multiply() template void OTMultiplier::multiplyForTriples() { - typedef typename W::open_type T; typedef typename W::Rectangle X; // dummy input for OT correlator @@ -148,13 +162,13 @@ void OTMultiplier::multiplyForTriples() BitVector aBits = generator.valueBits[0]; //timers["Extension"].start(); rot_ext.extend_correlated(aBits); - rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); //timers["Extension"].stop(); //timers["Correlation"].start(); otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, baseReceiverOutput); - otCorrelator.template correlate(0, generator.nPreampTriplesPerLoop, + otCorrelator.correlate(0, generator.nPreampTriplesPerLoop, generator.valueBits[1], false, generator.nAmplify); //timers["Correlation"].stop(); @@ -171,6 +185,14 @@ void MascotMultiplier::init_authenticator(const BitVector& keyBits, this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput); } +template +void TinyMultiplier::init_authenticator(const BitVector& keyBits, + const vector >& senderOutput, + const vector& receiverOutput) +{ + mac_vole.init(keyBits, senderOutput, receiverOutput); +} + template void Spdz2kMultiplier::init_authenticator(const BitVector& keyBits, const vector< vector >& senderOutput, @@ -188,9 +210,11 @@ void SemiMultiplier::after_correlation() this->outbox.push({}); } -template -void MascotMultiplier::after_correlation() +template +void MascotMultiplier::after_correlation() { + typedef typename U::open_type T; + this->auth_ot_ext.resize( this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS); this->auth_ot_ext.set_role(BOTH); @@ -210,8 +234,8 @@ void MascotMultiplier::after_correlation() int nValues = this->generator.nTriplesPerLoop; if (this->generator.machine.check && (j % 2 == 0)) nValues *= 2; - this->auth_ot_ext.template expand(0, nValues); - this->auth_ot_ext.template correlate(0, nValues, + this->auth_ot_ext.expand(0, nValues); + this->auth_ot_ext.correlate(0, nValues, this->generator.valueBits[j], true); this->auth_ot_ext.reduce_squares(nValues, this->macs[j]); } @@ -219,6 +243,29 @@ void MascotMultiplier::after_correlation() } } +template +void TinyMultiplier::after_correlation() +{ + this->otCorrelator.reduce_squares(this->generator.nTriplesPerLoop, + this->c_output); + + this->outbox.push({}); + + this->macs.resize(3); + MultJob job; + this->inbox.pop(job); + for (int j = 0; j < 3; j++) + { + int nValues = this->generator.nTriplesPerLoop * T::default_length; + auto& bits = this->generator.valueBits[j]; + vector values(nValues); + for (int i = 0; i < nValues; i++) + values[i] = bits.get_bit(i); + mac_vole.evaluate(this->macs[j], values); + } + this->outbox.push(job); +} + template void Spdz2kMultiplier::after_correlation() { @@ -295,9 +342,9 @@ void OTMultiplier>::multiplyForBits() for (int i = 0; i < generator.nloops; i++) { - auth_ot_ext.expand(0, nBlocks); + auth_ot_ext.expand(0, nBlocks); inbox.pop(job); - auth_ot_ext.correlate(0, nBlocks, generator.valueBits[0], true); + auth_ot_ext.correlate(0, nBlocks, generator.valueBits[0], true); auth_ot_ext.transpose(0, nBlocks); for (int j = 0; j < nBits; j++) @@ -311,8 +358,8 @@ void OTMultiplier>::multiplyForBits() } } -template -void MascotMultiplier::multiplyForInputs(MultJob job) +template +void MascotMultiplier::multiplyForInputs(MultJob job) { assert(job.input); auto& generator = this->generator; @@ -320,10 +367,10 @@ void MascotMultiplier::multiplyForInputs(MultJob job) auth_ot_ext.set_role(mine ? RECEIVER : SENDER); int nOTs = job.n_inputs * generator.field_size; auth_ot_ext.resize(nOTs); - auth_ot_ext.template expand(0, job.n_inputs); + auth_ot_ext.expand(0, job.n_inputs); if (mine) this->inbox.pop(); - auth_ot_ext.template correlate(0, job.n_inputs, generator.valueBits[0], true); + auth_ot_ext.correlate(0, job.n_inputs, generator.valueBits[0], true); auto& input_macs = this->input_macs; input_macs.resize(job.n_inputs); if (mine) @@ -355,24 +402,3 @@ void OTMultiplier::multiplyForBits() { throw runtime_error("bit generation not implemented in this case"); } - -template class OTMultiplier>; -template class OTMultiplier>; -template class OTMultiplier>; -template class OTMultiplier>; -template class SemiMultiplier>; -template class SemiMultiplier>; -template class SemiMultiplier>; -template class SemiMultiplier>; -template class MascotMultiplier; -template class MascotMultiplier; -template class MascotMultiplier; - -#define X(K, S) \ - template class Spdz2kMultiplier; \ - template class OTMultiplier>; -X(64, 64) -X(64, 48) -X(66, 64) -X(66, 48) -X(32, 32) diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 34e5f12e..52ae7e07 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -3,7 +3,6 @@ #include "Networking/Player.h" #include "OT/BaseOT.h" -#include "OT/OTMachine.h" #include "Tools/random.h" #include "Tools/time-func.h" diff --git a/OT/Rectangle.h b/OT/Rectangle.h index 03bf0a7c..56198208 100644 --- a/OT/Rectangle.h +++ b/OT/Rectangle.h @@ -39,21 +39,16 @@ public: Rectangle& operator+=(const Rectangle& other); Rectangle operator-(const Rectangle & other); - template Rectangle& sub(Rectangle& other) { return other.rsub_(*this); } - template Rectangle& rsub(Rectangle& other) { return rsub_(other); } Rectangle& rsub_(Rectangle& other); - template Rectangle& sub(const void* other) { return sub_(other); } Rectangle& sub_(const void* other); void mul(const BitVector& a, const V& b); void randomize(PRNG& G); - template void randomize(int row, PRNG& G) { rows[row].randomize(G); } - template void conditional_add(BitVector& conditions, Rectangle& other, int offset) { conditional_add_(conditions, other, offset); } void conditional_add_(BitVector& conditions, Rectangle& other, diff --git a/OT/Rectangle.hpp b/OT/Rectangle.hpp index 56037ffb..eae2a880 100644 --- a/OT/Rectangle.hpp +++ b/OT/Rectangle.hpp @@ -3,6 +3,9 @@ * */ +#ifndef OT_RECTANGLE_HPP_ +#define OT_RECTANGLE_HPP_ + #include "Rectangle.h" #include "Math/Z2k.h" #include "OT/BitMatrix.h" @@ -114,3 +117,5 @@ void Rectangle::print(int i, int j) (void) j; cout << dec << i << ": " << hex << rows[i] << endl; } + +#endif diff --git a/OT/Triple.hpp b/OT/Triple.hpp index 9565386e..842f13b1 100644 --- a/OT/Triple.hpp +++ b/OT/Triple.hpp @@ -6,6 +6,8 @@ #ifndef OT_TRIPLE_HPP_ #define OT_TRIPLE_HPP_ +template class NPartyTripleGenerator; + template class Triple { diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index a559ee4c..e49906d9 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -12,7 +12,7 @@ #include "Tools/OfflineMachineBase.h" #include "OT/OTTripleSetup.h" -template class NPartyTripleGenerator; +class GeneratorThread; class MascotParams : virtual public OfflineParams { @@ -54,7 +54,7 @@ public: TripleMachine(int argc, const char** argv); template - NPartyTripleGenerator* new_generator(OTTripleSetup& setup, int i); + GeneratorThread* new_generator(OTTripleSetup& setup, int i); void run(); diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index 564d4c51..74a1cfc8 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -20,13 +20,12 @@ void DataPositions::set_num_players(int num_players) void DataPositions::increase(const DataPositions& delta) { - if (inputs.size() != delta.inputs.size()) - throw invalid_length(); + inputs.resize(max(inputs.size(), delta.inputs.size()), vector(N_DATA_FIELD_TYPE)); for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) { for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++) files[field_type][dtype] += delta.files[field_type][dtype]; - for (unsigned int j = 0; j < inputs.size(); j++) + for (unsigned int j = 0; j < delta.inputs.size(); j++) inputs[j][field_type] += delta.inputs[j][field_type]; map::const_iterator it; diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index f33cd880..663dfecf 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -109,6 +109,8 @@ public: void get_input(T& a, typename T::open_type& x, int i); void get(vector& S, DataTag tag, const vector& regs, int vector_size); + virtual array get_triple(int n_bits); + virtual void buffer_triples() {} }; @@ -291,6 +293,15 @@ inline void Preprocessing::get(vector& S, DataTag tag, get_no_count(S, tag, regs, vector_size); } +template +array Preprocessing::get_triple(int n_bits) +{ + (void) n_bits; + array res; + get(DATA_TRIPLE, res.data()); + return res; +} + template inline void Data_Files::purge() { diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index b9545fb2..c155071a 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -1,3 +1,5 @@ +#ifndef PROCESSOR_DATA_FILES_HPP_ +#define PROCESSOR_DATA_FILES_HPP_ #include "Processor/Data_Files.h" #include "Processor/Processor.h" @@ -223,3 +225,5 @@ void Sub_Data_Files::get_no_count(vector& S, DataTag tag, const vector using namespace std; +#include "Math/BitVec.h" + class Player; template class SubProcessor; @@ -31,6 +33,7 @@ public: } }; +template class NotImplementedInput { public: @@ -39,6 +42,10 @@ public: { (void) proc, (void) MC; } + NotImplementedInput(Player& P) + { + (void) P; + } void start(int n, vector regs) { (void) n, (void) regs; @@ -65,6 +72,25 @@ public: (void) proc, (void) regs; throw not_implemented(); } + void reset_all(Player& P) + { + (void) P; + throw not_implemented(); + } + void add_mine(int a, int b) + { + (void) a, (void) b; + throw not_implemented(); + } + void exchange() + { + throw not_implemented(); + } + V finalize(int a, int b) + { + (void) a, (void) b; + throw not_implemented(); + } }; class NotImplementedOutput diff --git a/Processor/FixInput.h b/Processor/FixInput.h index 768a1067..b191a71f 100644 --- a/Processor/FixInput.h +++ b/Processor/FixInput.h @@ -17,6 +17,8 @@ public: const static int N_PARAM = 1; const static char* NAME; + const static int TYPE = 1; + bigint items[N_DEST]; void read(std::istream& in, const int* params); diff --git a/Processor/FloatInput.h b/Processor/FloatInput.h index 14921ab0..8a0b1425 100644 --- a/Processor/FloatInput.h +++ b/Processor/FloatInput.h @@ -17,6 +17,8 @@ public: const static int N_PARAM = 1; const static char* NAME; + const static int TYPE = 2; + long items[N_DEST]; void read(std::istream& in, const int* params); diff --git a/Processor/Input.h b/Processor/Input.h index 17b2f348..ecc9b1f4 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -34,13 +34,19 @@ public: template static void input(SubProcessor& Proc, const vector& args, int size); - InputBase(ArithmeticProcessor* proc); + static void input_mixed(SubProcessor& Proc, const vector& args, int size); + template + static void prepare(SubProcessor& Proc, int player, const int* params, int size); + template + static void finalize(SubProcessor& Proc, int player, const int* params, int size); + + InputBase(ArithmeticProcessor* proc = 0); virtual ~InputBase(); virtual void reset(int player) = 0; void reset_all(Player& P); - virtual void add_mine(const typename T::open_type& input) = 0; + virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; virtual void add_other(int player) = 0; void add_from_all(const clear& input); @@ -48,8 +54,8 @@ public: virtual void exchange(); virtual T finalize_mine() = 0; - virtual void finalize_other(int player, T& target, octetStream& o) = 0; - T finalize(int player); + virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; + T finalize(int player, int n_bits = -1); }; template @@ -59,24 +65,27 @@ class Input : public InputBase typedef typename T::clear clear; typedef typename T::MAC_Check MAC_Check; - SubProcessor& proc; + SubProcessor* proc; MAC_Check& MC; + Preprocessing& prep; + Player& P; vector< PointerVector > shares; open_type rr, t, xi; public: Input(SubProcessor& proc, MAC_Check& mc); Input(SubProcessor* proc, Player& P); + Input(MAC_Check& MC, Preprocessing& prep, Player& P); void reset(int player); - void add_mine(const open_type& input); + void add_mine(const open_type& input, int n_bits = -1); void add_other(int player); void send_mine(); T finalize_mine(); - void finalize_other(int player, T& target, octetStream& o); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); void start(int player, int n_inputs); void stop(int player, const vector& targets); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 4d8d0d65..f2a22350 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -3,9 +3,16 @@ * */ +#ifndef PROCESSOR_INPUT_HPP_ +#define PROCESSOR_INPUT_HPP_ + #include "Input.h" #include "Processor.h" +#include "IntInput.h" +#include "FixInput.h" +#include "FloatInput.h" + template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) @@ -16,19 +23,25 @@ InputBase::InputBase(ArithmeticProcessor* proc) : template Input::Input(SubProcessor& proc, MAC_Check& mc) : - InputBase(&proc.Proc), proc(proc), MC(mc), + InputBase(&proc.Proc), proc(&proc), MC(mc), prep(proc.DataF), P(proc.P), shares(proc.P.num_players()) { } template Input::Input(SubProcessor* proc, Player& P) : - InputBase(&proc->Proc), proc(*proc), MC(proc->MC), shares( - P.num_players()) + InputBase(&proc->Proc), proc(proc), MC(proc->MC), prep(proc->DataF), P( + proc->P), shares(P.num_players()) { assert (proc != 0); } +template +Input::Input(MAC_Check& MC, Preprocessing& prep, Player& P) : + proc(0), MC(MC), prep(prep), P(P), shares(P.num_players()) +{ +} + template InputBase::~InputBase() { @@ -62,13 +75,14 @@ void InputBase::reset_all(Player& P) } template -void Input::add_mine(const open_type& input) +void Input::add_mine(const open_type& input, int n_bits) { - int player = proc.P.my_num(); + (void) n_bits; + int player = P.my_num(); shares[player].push_back({}); T& share = shares[player].back(); - proc.DataF.get_input(share, rr, player); - t.sub(input, rr); + prep.get_input(share, rr, player); + t = input - rr; t.pack(this->os[player]); share += T::constant(t, 0, MC.get_alphai()); this->values_input++; @@ -79,7 +93,7 @@ void Input::add_other(int player) { open_type t; shares[player].push_back({}); - proc.DataF.get_input(shares[player].back(), t, player); + prep.get_input(shares[player].back(), t, player); } template @@ -95,7 +109,7 @@ void InputBase::add_from_all(const clear& input) template void Input::send_mine() { - proc.P.send_all(this->os[proc.P.my_num()], true); + P.send_all(this->os[P.my_num()], true); } template @@ -112,7 +126,7 @@ template void Input::start(int player, int n_inputs) { reset(player); - if (player == proc.P.my_num()) + if (player == P.my_num()) { for (int i = 0; i < n_inputs; i++) { @@ -139,18 +153,19 @@ void Input::start(int player, int n_inputs) template void Input::stop(int player, const vector& targets) { - if (proc.P.my_num() == player) + assert(proc != 0); + if (P.my_num() == player) for (unsigned int i = 0; i < targets.size(); i++) - proc.get_S_ref(targets[i]) = finalize_mine(); + proc->get_S_ref(targets[i]) = finalize_mine(); else { octetStream o; this->timer.start(); - proc.P.receive_player(player, o, true); + P.receive_player(player, o, true); this->timer.stop(); for (unsigned int i = 0; i < targets.size(); i++) { - finalize_other(player, proc.get_S_ref(targets[i]), o); + finalize_other(player, proc->get_S_ref(targets[i]), o); } } } @@ -158,31 +173,66 @@ void Input::stop(int player, const vector& targets) template T Input::finalize_mine() { - return shares[proc.P.my_num()].next(); + return shares[P.my_num()].next(); } template void Input::finalize_other(int player, T& target, - octetStream& o) + octetStream& o, int n_bits) { + (void) n_bits; target = shares[player].next(); t.unpack(o); target += T::constant(t, 1, MC.get_alphai()); } template -T InputBase::finalize(int player) +T InputBase::finalize(int player, int n_bits) { if (player == P->my_num()) return finalize_mine(); else { T res; - finalize_other(player, res, os[player]); + finalize_other(player, res, os[player], n_bits); return res; } } +template +template +void InputBase::prepare(SubProcessor& Proc, int player, const int* params, + int size) +{ + auto& input = Proc.input; + if (player == Proc.P.my_num()) + { + for (int j = 0; j < size; j++) + { + U tuple = Proc.Proc.template get_input(Proc.Proc.use_stdin(), + params); + for (auto x : tuple.items) + input.add_mine(x); + } + } + else + { + for (int j = 0; j < U::N_DEST * size; j++) + input.add_other(player); + } +} + +template +template +void InputBase::finalize(SubProcessor& Proc, int player, const int* dest, + int size) +{ + auto& input = Proc.input; + for (int k = 0; k < size; k++) + for (int j = 0; j < U::N_DEST; j++) + Proc.get_S_ref(dest[j] + k) = input.finalize(player); +} + template template void InputBase::input(SubProcessor& Proc, @@ -195,7 +245,7 @@ void InputBase::input(SubProcessor& Proc, int n_from_me = 0; - if (Proc.Proc.opts.interactive and Proc.Proc.thread_num == 0) + if (Proc.Proc.use_stdin()) { for (size_t i = n_arg_tuple - 1; i < args.size(); i += n_arg_tuple) n_from_me += (args[i] == Proc.P.my_num()) * size; @@ -206,21 +256,7 @@ void InputBase::input(SubProcessor& Proc, for (size_t i = U::N_DEST; i < args.size(); i += n_arg_tuple) { int n = args[i + U::N_PARAM]; - if (n == Proc.P.my_num()) - { - for (int j = 0; j < size; j++) - { - U tuple = Proc.Proc.template get_input(n_from_me > 0, - &args[i]); - for (auto x : tuple.items) - input.add_mine(x); - } - } - else - { - for (int j = 0; j < U::N_DEST * size; j++) - input.add_other(n); - } + InputBase::prepare(Proc, n, &args[i], size); } if (n_from_me > 0) @@ -231,8 +267,63 @@ void InputBase::input(SubProcessor& Proc, for (size_t i = 0; i < args.size(); i += n_arg_tuple) { int player = args[i + n_arg_tuple - 1]; - for (int k = 0; k < size; k++) - for (int j = 0; j < U::N_DEST; j++) - Proc.get_S_ref(args[i + j] + k) = input.finalize(player); + finalize(Proc, player, &args[i], size); } } + +template +void InputBase::input_mixed(SubProcessor& Proc, const vector& args, + int size) +{ + auto& input = Proc.input; + input.reset_all(Proc.P); + int last_type = -1; + + for (size_t i = 0; i < args.size();) + { + int n_arg_tuple; + int type = args[i]; + int player; + switch (type) + { +#undef X +#define X(U) \ + case U::TYPE: \ + n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ + player = args[i + n_arg_tuple - 1]; \ + if (type != last_type and Proc.Proc.use_stdin()) \ + cout << "Please input " << U::NAME << "s:" << endl; \ + prepare(Proc, player, &args[i + U::N_DEST + 1], size); \ + break; + X(IntInput) X(FixInput) X(FloatInput) +#undef X + default: + throw runtime_error("unknown input type: " + to_string(type)); + } + i += n_arg_tuple; + last_type = type; + } + + input.exchange(); + + for (size_t i = 0; i < args.size();) + { + int n_arg_tuple; + int type = args[i]; + switch (type) + { +#define X(U) \ + case U::TYPE: \ + n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ + finalize(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \ + break; + X(IntInput) X(FixInput) X(FloatInput) +#undef X + default: + throw runtime_error("unknown input type: " + to_string(type)); + } + i += n_arg_tuple; + } +} + +#endif diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 7c4883f5..f4772da2 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -92,6 +92,7 @@ enum MULS = 0xA6, MULRS = 0xA7, DOTPRODS = 0xA8, + TRUNC_PR = 0xA9, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -103,6 +104,7 @@ enum INPUT = 0x60, INPUTFIX = 0xF0, INPUTFLOAT = 0xF1, + INPUTMIXED = 0xF2, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -169,6 +171,7 @@ enum READFILESHARE = 0xBE, CONDPRINTSTR = 0xBF, PRINTFLOATPREC = 0xE0, + CONDPRINTPLAIN = 0xE1, // GF(2^n) versions diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e08818f0..8fd8fdbb 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -153,6 +153,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GPROTECTMEMS: case GPROTECTMEMC: case PROTECTMEMINT: + case CONDPRINTPLAIN: r[0]=get_int(s); r[1]=get_int(s); break; @@ -293,6 +294,8 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GINPUT: case INPUTFIX: case INPUTFLOAT: + case INPUTMIXED: + case TRUNC_PR: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; @@ -594,6 +597,11 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Proc.get_S2_ref(r[0] + i).mul(Proc.read_S2(r[1] + i),Proc.read_C2(r[2] + i)); return; + case LDI: + Proc.temp.assign_ansp(n); + for (int i = 0; i < size; i++) + Proc.write_Cp(r[0] + i,Proc.temp.ansp); + return; case ADDC: for (int i = 0; i < size; i++) Proc.get_Cp_ref(r[0] + i).add(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i)); @@ -602,10 +610,24 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i)); return; + case ADDM: + for (int i = 0; i < size; i++) + Proc.get_Sp_ref(r[0] + i).add(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i),Proc.P.my_num(),Proc.MCp.get_alphai()); + return; + case ADDCI: + Proc.temp.assign_ansp(n); + for (int i = 0; i < size; i++) + Proc.get_Cp_ref(r[0] + i).add(Proc.temp.ansp,Proc.read_Cp(r[1] + i)); + return; case SUBS: for (int i = 0; i < size; i++) Proc.get_Sp_ref(r[0] + i).sub(Proc.read_Sp(r[1] + i),Proc.read_Sp(r[2] + i)); return; + case SUBSFI: + Proc.temp.assign_ansp(n); + for (int i = 0; i < size; i++) + Proc.get_Sp_ref(r[0] + i).sub(Proc.temp.ansp,Proc.read_Sp(r[1] + i),Proc.P.my_num(),Proc.MCp.get_alphai()); + return; case MULM: for (int i = 0; i < size; i++) Proc.get_Sp_ref(r[0] + i).mul(Proc.read_Sp(r[1] + i),Proc.read_Cp(r[2] + i)); @@ -614,6 +636,11 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Proc.get_Cp_ref(r[0] + i).mul(Proc.read_Cp(r[1] + i),Proc.read_Cp(r[2] + i)); return; + case MULCI: + Proc.temp.assign_ansp(n); + for (int i = 0; i < size; i++) + Proc.get_Cp_ref(r[0] + i).mul(Proc.temp.ansp,Proc.read_Cp(r[1] + i)); + return; case TRIPLE: for (int i = 0; i < size; i++) Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i), @@ -623,6 +650,33 @@ inline void Instruction::execute(Processor& Proc) const for (int i = 0; i < size; i++) Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0] + i)); return; + case LDINT: + for (int i = 0; i < size; i++) + Proc.write_Ci(r[0] + i, int(n)); + return; + case ADDINT: + for (int i = 0; i < size; i++) + Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) + Proc.read_Ci(r[2] + i); + return; + case SUBINT: + for (int i = 0; i < size; i++) + Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) - Proc.read_Ci(r[2] + i); + return; + case MULINT: + for (int i = 0; i < size; i++) + Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) * Proc.read_Ci(r[2] + i); + return; + case DIVINT: + for (int i = 0; i < size; i++) + Proc.get_Ci_ref(r[0] + i) = Proc.read_Ci(r[1] + i) / Proc.read_Ci(r[2] + i); + return; + case CONVINT: + for (int i = 0; i < size; i++) + { + Proc.temp.assign_ansp(Proc.read_Ci(r[1] + i)); + Proc.get_Cp_ref(r[0] + i) = Proc.temp.ansp; + } + return; } int r[3] = {this->r[0], this->r[1], this->r[2]}; @@ -1003,6 +1057,9 @@ inline void Instruction::execute(Processor& Proc) const case INPUTFLOAT: sint::Input::template input(Proc.Procp, start, size); return; + case INPUTMIXED: + sint::Input::input_mixed(Proc.Procp, start, size); + return; case STARTINPUT: Proc.Procp.input.start(r[0],n); break; @@ -1150,6 +1207,9 @@ inline void Instruction::execute(Processor& Proc) const case GDOTPRODS: Proc.Proc2.dotprods(start, size); return; + case TRUNC_PR: + Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); + return; case JMP: Proc.PC += (signed int) n; break; @@ -1257,6 +1317,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.out << Proc.read_Cp(r[0]) << flush; } break; + case CONDPRINTPLAIN: + if (not Proc.read_Cp(r[0]).is_zero()) + Proc.out << Proc.read_Cp(r[1]) << flush; + break; case GPRINTREGPLAIN: { Proc.out << Proc.read_C2(r[0]) << flush; @@ -1462,9 +1526,6 @@ void Program::execute(Processor& Proc) const { unsigned int size = p.size(); Proc.PC=0; - octet seed[SEED_SIZE]; - memset(seed, 0, SEED_SIZE); - Proc.shared_prng.SetSeed(seed); while (Proc.PC::Machine(int my_number, Names& playerNames, else if (memtype.compare("old")==0) { inpf.open(memory_filename(), ios::in | ios::binary); - if (inpf.fail()) { throw file_error(); } + if (inpf.fail()) { throw file_error(memory_filename()); } inpf >> M2 >> Mp >> Mi; inpf.close(); } diff --git a/Processor/NoLivePrep.h b/Processor/NoLivePrep.h index 77302ba5..600f74ed 100644 --- a/Processor/NoLivePrep.h +++ b/Processor/NoLivePrep.h @@ -21,6 +21,11 @@ public: (void) proc; throw not_implemented(); } + template + NoLivePrep(DataPositions& usage, U& _) : NoLivePrep(0, usage) + { + (void) _; + } }; #endif /* PROCESSOR_NOLIVEPREP_H_ */ diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 0b4bc11f..8f3b2298 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -16,12 +16,16 @@ OnlineOptions::OnlineOptions() : playerno(-1) lgp = gfp::MAX_N_BITS; live_prep = true; batch_size = 10000; + memtype = "empty"; } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size, bool default_live_prep) : OnlineOptions() { + if (default_batch_size <= 0) + default_batch_size = batch_size; + opt.syntax = std::string(argv[0]) + " [OPTIONS] [] "; opt.add( @@ -82,6 +86,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-b", // Flag token. "--batch-size" // Flag token. ); + opt.add( + memtype.c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Where to obtain memory, new|old|empty (default: empty)\n\t" + "new: copy from Player-Memory-P file\n\t" + "old: reuse previous memory in Memory-P\n\t" + "empty: create new empty memory", // Help description. + "-m", // Flag token. + "--memory" // Flag token. + ); opt.parse(argc, argv); @@ -92,6 +108,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, else live_prep = opt.get("-L")->isSet; opt.get("-b")->getInt(batch_size); + opt.get("--memory")->getString(memtype); opt.resetArgs(); } diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 323e4f22..ccd54ab9 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -19,10 +19,11 @@ public: int playerno; std::string progname; int batch_size; + std::string memtype; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, - int default_batch_size = 10000, bool default_live_prep = true); + int default_batch_size = 0, bool default_live_prep = true); void finalize(ez::ezOptionParser& opt, int argc, const char** argv); }; diff --git a/Processor/Processor.h b/Processor/Processor.h index 48651e13..4baa9281 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -96,6 +96,11 @@ public: ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), sent(0), rounds(0), opts(opts) {} + + bool use_stdin() + { + return thread_num == 0 and opts.interactive; + } }; template diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 5b813873..0a96bd74 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -42,6 +42,7 @@ Processor::Processor(int thread_num,Player& P, open_input_file(P.my_num(), thread_num); secure_prng.ReSeed(); + shared_prng.SeedGlobally(P); out.activate(P.my_num() == 0 or machine.opts.interactive); } diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.cpp index 9b0dc38c..03f40813 100644 --- a/Processor/ProcessorBase.cpp +++ b/Processor/ProcessorBase.cpp @@ -43,7 +43,14 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co if (input_file.eof()) throw IO_Error("not enough inputs in " + input_filename); if (input_file.fail()) - throw IO_Error(string() + "cannot read " + T::NAME + " from " + input_filename); + { + input_file.clear(); + string token; + input_file >> token; + throw IO_Error( + string() + "cannot read " + T::NAME + " from " + input_filename + + ", problem with '" + token + "'"); + } return res; } diff --git a/Programs/Source/benchmark_mobilenet.mpc b/Programs/Source/benchmark_mobilenet.mpc new file mode 100644 index 00000000..08afad12 --- /dev/null +++ b/Programs/Source/benchmark_mobilenet.mpc @@ -0,0 +1,576 @@ +import ml + +network = program.args[1] + +if len(program.args) > 2: + if program.args[2] == '1': + program.use_trunc_pr = True + elif program.args[2] == '2': + squant.round_nearest = True + elif not program.args[2] == '0': + raise Exception('option invalid') + +ml.QuantBase.n_threads = 8 + +if len(program.args) > 3: + ml.QuantBase.n_threads = int(program.args[3]) + +from ml import * + +if network == 'v1_0.25_128': + layers = [ + QuantConv2d((1, 128, 128, 3), (8, 3, 3, 3), (8,), (1, 64, 64, 8), (2, 2)), + QuantDepthwiseConv2d((1, 64, 64, 8), (1, 3, 3, 8), (8,), (1, 64, 64, 8), (1, 1)), + QuantConv2d((1, 64, 64, 8), (16, 1, 1, 8), (16,), (1, 64, 64, 16), (1, 1)), + QuantDepthwiseConv2d((1, 64, 64, 16), (1, 3, 3, 16), (16,), (1, 32, 32, 16), (2, 2)), + QuantConv2d((1, 32, 32, 16), (32, 1, 1, 16), (32,), (1, 32, 32, 32), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 32), (1, 3, 3, 32), (32,), (1, 32, 32, 32), (1, 1)), + QuantConv2d((1, 32, 32, 32), (32, 1, 1, 32), (32,), (1, 32, 32, 32), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 32), (1, 3, 3, 32), (32,), (1, 16, 16, 32), (2, 2)), + QuantConv2d((1, 16, 16, 32), (64, 1, 1, 32), (64,), (1, 16, 16, 64), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 64), (1, 3, 3, 64), (64,), (1, 16, 16, 64), (1, 1)), + QuantConv2d((1, 16, 16, 64), (64, 1, 1, 64), (64,), (1, 16, 16, 64), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 64), (1, 3, 3, 64), (64,), (1, 8, 8, 64), (2, 2)), + QuantConv2d((1, 8, 8, 64), (128, 1, 1, 64), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantConv2d((1, 8, 8, 128), (128, 1, 1, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantConv2d((1, 8, 8, 128), (128, 1, 1, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantConv2d((1, 8, 8, 128), (128, 1, 1, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantConv2d((1, 8, 8, 128), (128, 1, 1, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantConv2d((1, 8, 8, 128), (128, 1, 1, 128), (128,), (1, 8, 8, 128), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 128), (1, 3, 3, 128), (128,), (1, 4, 4, 128), (2, 2)), + QuantConv2d((1, 4, 4, 128), (256, 1, 1, 128), (256,), (1, 4, 4, 256), (1, 1)), + QuantDepthwiseConv2d((1, 4, 4, 256), (1, 3, 3, 256), (256,), (1, 4, 4, 256), (1, 1)), + QuantConv2d((1, 4, 4, 256), (256, 1, 1, 256), (256,), (1, 4, 4, 256), (1, 1)), + QuantAveragePool2d((1, 4, 4, 256), (1, 1, 1, 256), (4, 4)), + QuantConv2d((1, 1, 1, 256), (1001, 1, 1, 256), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.25_160': + layers = [ + QuantConv2d((1, 160, 160, 3), (8, 3, 3, 3), (8,), (1, 80, 80, 8), (2, 2)), + QuantDepthwiseConv2d((1, 80, 80, 8), (1, 3, 3, 8), (8,), (1, 80, 80, 8), (1, 1)), + QuantConv2d((1, 80, 80, 8), (16, 1, 1, 8), (16,), (1, 80, 80, 16), (1, 1)), + QuantDepthwiseConv2d((1, 80, 80, 16), (1, 3, 3, 16), (16,), (1, 40, 40, 16), (2, 2)), + QuantConv2d((1, 40, 40, 16), (32, 1, 1, 16), (32,), (1, 40, 40, 32), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 32), (1, 3, 3, 32), (32,), (1, 40, 40, 32), (1, 1)), + QuantConv2d((1, 40, 40, 32), (32, 1, 1, 32), (32,), (1, 40, 40, 32), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 32), (1, 3, 3, 32), (32,), (1, 20, 20, 32), (2, 2)), + QuantConv2d((1, 20, 20, 32), (64, 1, 1, 32), (64,), (1, 20, 20, 64), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 64), (1, 3, 3, 64), (64,), (1, 20, 20, 64), (1, 1)), + QuantConv2d((1, 20, 20, 64), (64, 1, 1, 64), (64,), (1, 20, 20, 64), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 64), (1, 3, 3, 64), (64,), (1, 10, 10, 64), (2, 2)), + QuantConv2d((1, 10, 10, 64), (128, 1, 1, 64), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantConv2d((1, 10, 10, 128), (128, 1, 1, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantConv2d((1, 10, 10, 128), (128, 1, 1, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantConv2d((1, 10, 10, 128), (128, 1, 1, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantConv2d((1, 10, 10, 128), (128, 1, 1, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantConv2d((1, 10, 10, 128), (128, 1, 1, 128), (128,), (1, 10, 10, 128), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 128), (1, 3, 3, 128), (128,), (1, 5, 5, 128), (2, 2)), + QuantConv2d((1, 5, 5, 128), (256, 1, 1, 128), (256,), (1, 5, 5, 256), (1, 1)), + QuantDepthwiseConv2d((1, 5, 5, 256), (1, 3, 3, 256), (256,), (1, 5, 5, 256), (1, 1)), + QuantConv2d((1, 5, 5, 256), (256, 1, 1, 256), (256,), (1, 5, 5, 256), (1, 1)), + QuantAveragePool2d((1, 5, 5, 256), (1, 1, 1, 256), (5, 5)), + QuantConv2d((1, 1, 1, 256), (1001, 1, 1, 256), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.25_192': + layers = [ + QuantConv2d((1, 192, 192, 3), (8, 3, 3, 3), (8,), (1, 96, 96, 8), (2, 2)), + QuantDepthwiseConv2d((1, 96, 96, 8), (1, 3, 3, 8), (8,), (1, 96, 96, 8), (1, 1)), + QuantConv2d((1, 96, 96, 8), (16, 1, 1, 8), (16,), (1, 96, 96, 16), (1, 1)), + QuantDepthwiseConv2d((1, 96, 96, 16), (1, 3, 3, 16), (16,), (1, 48, 48, 16), (2, 2)), + QuantConv2d((1, 48, 48, 16), (32, 1, 1, 16), (32,), (1, 48, 48, 32), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 32), (1, 3, 3, 32), (32,), (1, 48, 48, 32), (1, 1)), + QuantConv2d((1, 48, 48, 32), (32, 1, 1, 32), (32,), (1, 48, 48, 32), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 32), (1, 3, 3, 32), (32,), (1, 24, 24, 32), (2, 2)), + QuantConv2d((1, 24, 24, 32), (64, 1, 1, 32), (64,), (1, 24, 24, 64), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 64), (1, 3, 3, 64), (64,), (1, 24, 24, 64), (1, 1)), + QuantConv2d((1, 24, 24, 64), (64, 1, 1, 64), (64,), (1, 24, 24, 64), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 64), (1, 3, 3, 64), (64,), (1, 12, 12, 64), (2, 2)), + QuantConv2d((1, 12, 12, 64), (128, 1, 1, 64), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantConv2d((1, 12, 12, 128), (128, 1, 1, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantConv2d((1, 12, 12, 128), (128, 1, 1, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantConv2d((1, 12, 12, 128), (128, 1, 1, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantConv2d((1, 12, 12, 128), (128, 1, 1, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantConv2d((1, 12, 12, 128), (128, 1, 1, 128), (128,), (1, 12, 12, 128), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 128), (1, 3, 3, 128), (128,), (1, 6, 6, 128), (2, 2)), + QuantConv2d((1, 6, 6, 128), (256, 1, 1, 128), (256,), (1, 6, 6, 256), (1, 1)), + QuantDepthwiseConv2d((1, 6, 6, 256), (1, 3, 3, 256), (256,), (1, 6, 6, 256), (1, 1)), + QuantConv2d((1, 6, 6, 256), (256, 1, 1, 256), (256,), (1, 6, 6, 256), (1, 1)), + QuantAveragePool2d((1, 6, 6, 256), (1, 1, 1, 256), (6, 6)), + QuantConv2d((1, 1, 1, 256), (1001, 1, 1, 256), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.25_224': + layers = [ + QuantConv2d((1, 224, 224, 3), (8, 3, 3, 3), (8,), (1, 112, 112, 8), (2, 2)), + QuantDepthwiseConv2d((1, 112, 112, 8), (1, 3, 3, 8), (8,), (1, 112, 112, 8), (1, 1)), + QuantConv2d((1, 112, 112, 8), (16, 1, 1, 8), (16,), (1, 112, 112, 16), (1, 1)), + QuantDepthwiseConv2d((1, 112, 112, 16), (1, 3, 3, 16), (16,), (1, 56, 56, 16), (2, 2)), + QuantConv2d((1, 56, 56, 16), (32, 1, 1, 16), (32,), (1, 56, 56, 32), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 32), (1, 3, 3, 32), (32,), (1, 56, 56, 32), (1, 1)), + QuantConv2d((1, 56, 56, 32), (32, 1, 1, 32), (32,), (1, 56, 56, 32), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 32), (1, 3, 3, 32), (32,), (1, 28, 28, 32), (2, 2)), + QuantConv2d((1, 28, 28, 32), (64, 1, 1, 32), (64,), (1, 28, 28, 64), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 64), (1, 3, 3, 64), (64,), (1, 28, 28, 64), (1, 1)), + QuantConv2d((1, 28, 28, 64), (64, 1, 1, 64), (64,), (1, 28, 28, 64), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 64), (1, 3, 3, 64), (64,), (1, 14, 14, 64), (2, 2)), + QuantConv2d((1, 14, 14, 64), (128, 1, 1, 64), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantConv2d((1, 14, 14, 128), (128, 1, 1, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantConv2d((1, 14, 14, 128), (128, 1, 1, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantConv2d((1, 14, 14, 128), (128, 1, 1, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantConv2d((1, 14, 14, 128), (128, 1, 1, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantConv2d((1, 14, 14, 128), (128, 1, 1, 128), (128,), (1, 14, 14, 128), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 128), (1, 3, 3, 128), (128,), (1, 7, 7, 128), (2, 2)), + QuantConv2d((1, 7, 7, 128), (256, 1, 1, 128), (256,), (1, 7, 7, 256), (1, 1)), + QuantDepthwiseConv2d((1, 7, 7, 256), (1, 3, 3, 256), (256,), (1, 7, 7, 256), (1, 1)), + QuantConv2d((1, 7, 7, 256), (256, 1, 1, 256), (256,), (1, 7, 7, 256), (1, 1)), + QuantAveragePool2d((1, 7, 7, 256), (1, 1, 1, 256), (7, 7)), + QuantConv2d((1, 1, 1, 256), (1001, 1, 1, 256), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.5_128': + layers = [ + QuantConv2d((1, 128, 128, 3), (16, 3, 3, 3), (16,), (1, 64, 64, 16), (2, 2)), + QuantDepthwiseConv2d((1, 64, 64, 16), (1, 3, 3, 16), (16,), (1, 64, 64, 16), (1, 1)), + QuantConv2d((1, 64, 64, 16), (32, 1, 1, 16), (32,), (1, 64, 64, 32), (1, 1)), + QuantDepthwiseConv2d((1, 64, 64, 32), (1, 3, 3, 32), (32,), (1, 32, 32, 32), (2, 2)), + QuantConv2d((1, 32, 32, 32), (64, 1, 1, 32), (64,), (1, 32, 32, 64), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 64), (1, 3, 3, 64), (64,), (1, 32, 32, 64), (1, 1)), + QuantConv2d((1, 32, 32, 64), (64, 1, 1, 64), (64,), (1, 32, 32, 64), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 64), (1, 3, 3, 64), (64,), (1, 16, 16, 64), (2, 2)), + QuantConv2d((1, 16, 16, 64), (128, 1, 1, 64), (128,), (1, 16, 16, 128), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 128), (1, 3, 3, 128), (128,), (1, 16, 16, 128), (1, 1)), + QuantConv2d((1, 16, 16, 128), (128, 1, 1, 128), (128,), (1, 16, 16, 128), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 128), (1, 3, 3, 128), (128,), (1, 8, 8, 128), (2, 2)), + QuantConv2d((1, 8, 8, 128), (256, 1, 1, 128), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantConv2d((1, 8, 8, 256), (256, 1, 1, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantConv2d((1, 8, 8, 256), (256, 1, 1, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantConv2d((1, 8, 8, 256), (256, 1, 1, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantConv2d((1, 8, 8, 256), (256, 1, 1, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantConv2d((1, 8, 8, 256), (256, 1, 1, 256), (256,), (1, 8, 8, 256), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 256), (1, 3, 3, 256), (256,), (1, 4, 4, 256), (2, 2)), + QuantConv2d((1, 4, 4, 256), (512, 1, 1, 256), (512,), (1, 4, 4, 512), (1, 1)), + QuantDepthwiseConv2d((1, 4, 4, 512), (1, 3, 3, 512), (512,), (1, 4, 4, 512), (1, 1)), + QuantConv2d((1, 4, 4, 512), (512, 1, 1, 512), (512,), (1, 4, 4, 512), (1, 1)), + QuantAveragePool2d((1, 4, 4, 512), (1, 1, 1, 512), (4, 4)), + QuantConv2d((1, 1, 1, 512), (1001, 1, 1, 512), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.5_160': + layers = [ + QuantConv2d((1, 160, 160, 3), (16, 3, 3, 3), (16,), (1, 80, 80, 16), (2, 2)), + QuantDepthwiseConv2d((1, 80, 80, 16), (1, 3, 3, 16), (16,), (1, 80, 80, 16), (1, 1)), + QuantConv2d((1, 80, 80, 16), (32, 1, 1, 16), (32,), (1, 80, 80, 32), (1, 1)), + QuantDepthwiseConv2d((1, 80, 80, 32), (1, 3, 3, 32), (32,), (1, 40, 40, 32), (2, 2)), + QuantConv2d((1, 40, 40, 32), (64, 1, 1, 32), (64,), (1, 40, 40, 64), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 64), (1, 3, 3, 64), (64,), (1, 40, 40, 64), (1, 1)), + QuantConv2d((1, 40, 40, 64), (64, 1, 1, 64), (64,), (1, 40, 40, 64), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 64), (1, 3, 3, 64), (64,), (1, 20, 20, 64), (2, 2)), + QuantConv2d((1, 20, 20, 64), (128, 1, 1, 64), (128,), (1, 20, 20, 128), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 128), (1, 3, 3, 128), (128,), (1, 20, 20, 128), (1, 1)), + QuantConv2d((1, 20, 20, 128), (128, 1, 1, 128), (128,), (1, 20, 20, 128), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 128), (1, 3, 3, 128), (128,), (1, 10, 10, 128), (2, 2)), + QuantConv2d((1, 10, 10, 128), (256, 1, 1, 128), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantConv2d((1, 10, 10, 256), (256, 1, 1, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantConv2d((1, 10, 10, 256), (256, 1, 1, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantConv2d((1, 10, 10, 256), (256, 1, 1, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantConv2d((1, 10, 10, 256), (256, 1, 1, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantConv2d((1, 10, 10, 256), (256, 1, 1, 256), (256,), (1, 10, 10, 256), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 256), (1, 3, 3, 256), (256,), (1, 5, 5, 256), (2, 2)), + QuantConv2d((1, 5, 5, 256), (512, 1, 1, 256), (512,), (1, 5, 5, 512), (1, 1)), + QuantDepthwiseConv2d((1, 5, 5, 512), (1, 3, 3, 512), (512,), (1, 5, 5, 512), (1, 1)), + QuantConv2d((1, 5, 5, 512), (512, 1, 1, 512), (512,), (1, 5, 5, 512), (1, 1)), + QuantAveragePool2d((1, 5, 5, 512), (1, 1, 1, 512), (5, 5)), + QuantConv2d((1, 1, 1, 512), (1001, 1, 1, 512), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.5_192': + layers = [ + QuantConv2d((1, 192, 192, 3), (16, 3, 3, 3), (16,), (1, 96, 96, 16), (2, 2)), + QuantDepthwiseConv2d((1, 96, 96, 16), (1, 3, 3, 16), (16,), (1, 96, 96, 16), (1, 1)), + QuantConv2d((1, 96, 96, 16), (32, 1, 1, 16), (32,), (1, 96, 96, 32), (1, 1)), + QuantDepthwiseConv2d((1, 96, 96, 32), (1, 3, 3, 32), (32,), (1, 48, 48, 32), (2, 2)), + QuantConv2d((1, 48, 48, 32), (64, 1, 1, 32), (64,), (1, 48, 48, 64), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 64), (1, 3, 3, 64), (64,), (1, 48, 48, 64), (1, 1)), + QuantConv2d((1, 48, 48, 64), (64, 1, 1, 64), (64,), (1, 48, 48, 64), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 64), (1, 3, 3, 64), (64,), (1, 24, 24, 64), (2, 2)), + QuantConv2d((1, 24, 24, 64), (128, 1, 1, 64), (128,), (1, 24, 24, 128), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 128), (1, 3, 3, 128), (128,), (1, 24, 24, 128), (1, 1)), + QuantConv2d((1, 24, 24, 128), (128, 1, 1, 128), (128,), (1, 24, 24, 128), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 128), (1, 3, 3, 128), (128,), (1, 12, 12, 128), (2, 2)), + QuantConv2d((1, 12, 12, 128), (256, 1, 1, 128), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantConv2d((1, 12, 12, 256), (256, 1, 1, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantConv2d((1, 12, 12, 256), (256, 1, 1, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantConv2d((1, 12, 12, 256), (256, 1, 1, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantConv2d((1, 12, 12, 256), (256, 1, 1, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantConv2d((1, 12, 12, 256), (256, 1, 1, 256), (256,), (1, 12, 12, 256), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 256), (1, 3, 3, 256), (256,), (1, 6, 6, 256), (2, 2)), + QuantConv2d((1, 6, 6, 256), (512, 1, 1, 256), (512,), (1, 6, 6, 512), (1, 1)), + QuantDepthwiseConv2d((1, 6, 6, 512), (1, 3, 3, 512), (512,), (1, 6, 6, 512), (1, 1)), + QuantConv2d((1, 6, 6, 512), (512, 1, 1, 512), (512,), (1, 6, 6, 512), (1, 1)), + QuantAveragePool2d((1, 6, 6, 512), (1, 1, 1, 512), (6, 6)), + QuantConv2d((1, 1, 1, 512), (1001, 1, 1, 512), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.5_224': + layers = [ + QuantConv2d((1, 224, 224, 3), (16, 3, 3, 3), (16,), (1, 112, 112, 16), (2, 2)), + QuantDepthwiseConv2d((1, 112, 112, 16), (1, 3, 3, 16), (16,), (1, 112, 112, 16), (1, 1)), + QuantConv2d((1, 112, 112, 16), (32, 1, 1, 16), (32,), (1, 112, 112, 32), (1, 1)), + QuantDepthwiseConv2d((1, 112, 112, 32), (1, 3, 3, 32), (32,), (1, 56, 56, 32), (2, 2)), + QuantConv2d((1, 56, 56, 32), (64, 1, 1, 32), (64,), (1, 56, 56, 64), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 64), (1, 3, 3, 64), (64,), (1, 56, 56, 64), (1, 1)), + QuantConv2d((1, 56, 56, 64), (64, 1, 1, 64), (64,), (1, 56, 56, 64), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 64), (1, 3, 3, 64), (64,), (1, 28, 28, 64), (2, 2)), + QuantConv2d((1, 28, 28, 64), (128, 1, 1, 64), (128,), (1, 28, 28, 128), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 128), (1, 3, 3, 128), (128,), (1, 28, 28, 128), (1, 1)), + QuantConv2d((1, 28, 28, 128), (128, 1, 1, 128), (128,), (1, 28, 28, 128), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 128), (1, 3, 3, 128), (128,), (1, 14, 14, 128), (2, 2)), + QuantConv2d((1, 14, 14, 128), (256, 1, 1, 128), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantConv2d((1, 14, 14, 256), (256, 1, 1, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantConv2d((1, 14, 14, 256), (256, 1, 1, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantConv2d((1, 14, 14, 256), (256, 1, 1, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantConv2d((1, 14, 14, 256), (256, 1, 1, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantConv2d((1, 14, 14, 256), (256, 1, 1, 256), (256,), (1, 14, 14, 256), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 256), (1, 3, 3, 256), (256,), (1, 7, 7, 256), (2, 2)), + QuantConv2d((1, 7, 7, 256), (512, 1, 1, 256), (512,), (1, 7, 7, 512), (1, 1)), + QuantDepthwiseConv2d((1, 7, 7, 512), (1, 3, 3, 512), (512,), (1, 7, 7, 512), (1, 1)), + QuantConv2d((1, 7, 7, 512), (512, 1, 1, 512), (512,), (1, 7, 7, 512), (1, 1)), + QuantAveragePool2d((1, 7, 7, 512), (1, 1, 1, 512), (7, 7)), + QuantConv2d((1, 1, 1, 512), (1001, 1, 1, 512), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.75_128': + layers = [ + QuantConv2d((1, 128, 128, 3), (24, 3, 3, 3), (24,), (1, 64, 64, 24), (2, 2)), + QuantDepthwiseConv2d((1, 64, 64, 24), (1, 3, 3, 24), (24,), (1, 64, 64, 24), (1, 1)), + QuantConv2d((1, 64, 64, 24), (48, 1, 1, 24), (48,), (1, 64, 64, 48), (1, 1)), + QuantDepthwiseConv2d((1, 64, 64, 48), (1, 3, 3, 48), (48,), (1, 32, 32, 48), (2, 2)), + QuantConv2d((1, 32, 32, 48), (96, 1, 1, 48), (96,), (1, 32, 32, 96), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 96), (1, 3, 3, 96), (96,), (1, 32, 32, 96), (1, 1)), + QuantConv2d((1, 32, 32, 96), (96, 1, 1, 96), (96,), (1, 32, 32, 96), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 96), (1, 3, 3, 96), (96,), (1, 16, 16, 96), (2, 2)), + QuantConv2d((1, 16, 16, 96), (192, 1, 1, 96), (192,), (1, 16, 16, 192), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 192), (1, 3, 3, 192), (192,), (1, 16, 16, 192), (1, 1)), + QuantConv2d((1, 16, 16, 192), (192, 1, 1, 192), (192,), (1, 16, 16, 192), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 192), (1, 3, 3, 192), (192,), (1, 8, 8, 192), (2, 2)), + QuantConv2d((1, 8, 8, 192), (384, 1, 1, 192), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantConv2d((1, 8, 8, 384), (384, 1, 1, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantConv2d((1, 8, 8, 384), (384, 1, 1, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantConv2d((1, 8, 8, 384), (384, 1, 1, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantConv2d((1, 8, 8, 384), (384, 1, 1, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantConv2d((1, 8, 8, 384), (384, 1, 1, 384), (384,), (1, 8, 8, 384), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 384), (1, 3, 3, 384), (384,), (1, 4, 4, 384), (2, 2)), + QuantConv2d((1, 4, 4, 384), (768, 1, 1, 384), (768,), (1, 4, 4, 768), (1, 1)), + QuantDepthwiseConv2d((1, 4, 4, 768), (1, 3, 3, 768), (768,), (1, 4, 4, 768), (1, 1)), + QuantConv2d((1, 4, 4, 768), (768, 1, 1, 768), (768,), (1, 4, 4, 768), (1, 1)), + QuantAveragePool2d((1, 4, 4, 768), (1, 1, 1, 768), (4, 4)), + QuantConv2d((1, 1, 1, 768), (1001, 1, 1, 768), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.75_160': + layers = [ + QuantConv2d((1, 160, 160, 3), (24, 3, 3, 3), (24,), (1, 80, 80, 24), (2, 2)), + QuantDepthwiseConv2d((1, 80, 80, 24), (1, 3, 3, 24), (24,), (1, 80, 80, 24), (1, 1)), + QuantConv2d((1, 80, 80, 24), (48, 1, 1, 24), (48,), (1, 80, 80, 48), (1, 1)), + QuantDepthwiseConv2d((1, 80, 80, 48), (1, 3, 3, 48), (48,), (1, 40, 40, 48), (2, 2)), + QuantConv2d((1, 40, 40, 48), (96, 1, 1, 48), (96,), (1, 40, 40, 96), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 96), (1, 3, 3, 96), (96,), (1, 40, 40, 96), (1, 1)), + QuantConv2d((1, 40, 40, 96), (96, 1, 1, 96), (96,), (1, 40, 40, 96), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 96), (1, 3, 3, 96), (96,), (1, 20, 20, 96), (2, 2)), + QuantConv2d((1, 20, 20, 96), (192, 1, 1, 96), (192,), (1, 20, 20, 192), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 192), (1, 3, 3, 192), (192,), (1, 20, 20, 192), (1, 1)), + QuantConv2d((1, 20, 20, 192), (192, 1, 1, 192), (192,), (1, 20, 20, 192), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 192), (1, 3, 3, 192), (192,), (1, 10, 10, 192), (2, 2)), + QuantConv2d((1, 10, 10, 192), (384, 1, 1, 192), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantConv2d((1, 10, 10, 384), (384, 1, 1, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantConv2d((1, 10, 10, 384), (384, 1, 1, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantConv2d((1, 10, 10, 384), (384, 1, 1, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantConv2d((1, 10, 10, 384), (384, 1, 1, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantConv2d((1, 10, 10, 384), (384, 1, 1, 384), (384,), (1, 10, 10, 384), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 384), (1, 3, 3, 384), (384,), (1, 5, 5, 384), (2, 2)), + QuantConv2d((1, 5, 5, 384), (768, 1, 1, 384), (768,), (1, 5, 5, 768), (1, 1)), + QuantDepthwiseConv2d((1, 5, 5, 768), (1, 3, 3, 768), (768,), (1, 5, 5, 768), (1, 1)), + QuantConv2d((1, 5, 5, 768), (768, 1, 1, 768), (768,), (1, 5, 5, 768), (1, 1)), + QuantAveragePool2d((1, 5, 5, 768), (1, 1, 1, 768), (5, 5)), + QuantConv2d((1, 1, 1, 768), (1001, 1, 1, 768), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.75_192': + layers = [ + QuantConv2d((1, 192, 192, 3), (24, 3, 3, 3), (24,), (1, 96, 96, 24), (2, 2)), + QuantDepthwiseConv2d((1, 96, 96, 24), (1, 3, 3, 24), (24,), (1, 96, 96, 24), (1, 1)), + QuantConv2d((1, 96, 96, 24), (48, 1, 1, 24), (48,), (1, 96, 96, 48), (1, 1)), + QuantDepthwiseConv2d((1, 96, 96, 48), (1, 3, 3, 48), (48,), (1, 48, 48, 48), (2, 2)), + QuantConv2d((1, 48, 48, 48), (96, 1, 1, 48), (96,), (1, 48, 48, 96), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 96), (1, 3, 3, 96), (96,), (1, 48, 48, 96), (1, 1)), + QuantConv2d((1, 48, 48, 96), (96, 1, 1, 96), (96,), (1, 48, 48, 96), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 96), (1, 3, 3, 96), (96,), (1, 24, 24, 96), (2, 2)), + QuantConv2d((1, 24, 24, 96), (192, 1, 1, 96), (192,), (1, 24, 24, 192), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 192), (1, 3, 3, 192), (192,), (1, 24, 24, 192), (1, 1)), + QuantConv2d((1, 24, 24, 192), (192, 1, 1, 192), (192,), (1, 24, 24, 192), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 192), (1, 3, 3, 192), (192,), (1, 12, 12, 192), (2, 2)), + QuantConv2d((1, 12, 12, 192), (384, 1, 1, 192), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantConv2d((1, 12, 12, 384), (384, 1, 1, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantConv2d((1, 12, 12, 384), (384, 1, 1, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantConv2d((1, 12, 12, 384), (384, 1, 1, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantConv2d((1, 12, 12, 384), (384, 1, 1, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantConv2d((1, 12, 12, 384), (384, 1, 1, 384), (384,), (1, 12, 12, 384), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 384), (1, 3, 3, 384), (384,), (1, 6, 6, 384), (2, 2)), + QuantConv2d((1, 6, 6, 384), (768, 1, 1, 384), (768,), (1, 6, 6, 768), (1, 1)), + QuantDepthwiseConv2d((1, 6, 6, 768), (1, 3, 3, 768), (768,), (1, 6, 6, 768), (1, 1)), + QuantConv2d((1, 6, 6, 768), (768, 1, 1, 768), (768,), (1, 6, 6, 768), (1, 1)), + QuantAveragePool2d((1, 6, 6, 768), (1, 1, 1, 768), (6, 6)), + QuantConv2d((1, 1, 1, 768), (1001, 1, 1, 768), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_0.75_224': + layers = [ + QuantConv2d((1, 224, 224, 3), (24, 3, 3, 3), (24,), (1, 112, 112, 24), (2, 2)), + QuantDepthwiseConv2d((1, 112, 112, 24), (1, 3, 3, 24), (24,), (1, 112, 112, 24), (1, 1)), + QuantConv2d((1, 112, 112, 24), (48, 1, 1, 24), (48,), (1, 112, 112, 48), (1, 1)), + QuantDepthwiseConv2d((1, 112, 112, 48), (1, 3, 3, 48), (48,), (1, 56, 56, 48), (2, 2)), + QuantConv2d((1, 56, 56, 48), (96, 1, 1, 48), (96,), (1, 56, 56, 96), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 96), (1, 3, 3, 96), (96,), (1, 56, 56, 96), (1, 1)), + QuantConv2d((1, 56, 56, 96), (96, 1, 1, 96), (96,), (1, 56, 56, 96), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 96), (1, 3, 3, 96), (96,), (1, 28, 28, 96), (2, 2)), + QuantConv2d((1, 28, 28, 96), (192, 1, 1, 96), (192,), (1, 28, 28, 192), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 192), (1, 3, 3, 192), (192,), (1, 28, 28, 192), (1, 1)), + QuantConv2d((1, 28, 28, 192), (192, 1, 1, 192), (192,), (1, 28, 28, 192), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 192), (1, 3, 3, 192), (192,), (1, 14, 14, 192), (2, 2)), + QuantConv2d((1, 14, 14, 192), (384, 1, 1, 192), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantConv2d((1, 14, 14, 384), (384, 1, 1, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantConv2d((1, 14, 14, 384), (384, 1, 1, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantConv2d((1, 14, 14, 384), (384, 1, 1, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantConv2d((1, 14, 14, 384), (384, 1, 1, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantConv2d((1, 14, 14, 384), (384, 1, 1, 384), (384,), (1, 14, 14, 384), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 384), (1, 3, 3, 384), (384,), (1, 7, 7, 384), (2, 2)), + QuantConv2d((1, 7, 7, 384), (768, 1, 1, 384), (768,), (1, 7, 7, 768), (1, 1)), + QuantDepthwiseConv2d((1, 7, 7, 768), (1, 3, 3, 768), (768,), (1, 7, 7, 768), (1, 1)), + QuantConv2d((1, 7, 7, 768), (768, 1, 1, 768), (768,), (1, 7, 7, 768), (1, 1)), + QuantAveragePool2d((1, 7, 7, 768), (1, 1, 1, 768), (7, 7)), + QuantConv2d((1, 1, 1, 768), (1001, 1, 1, 768), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_1.0_128': + layers = [ + QuantConv2d((1, 128, 128, 3), (32, 3, 3, 3), (32,), (1, 64, 64, 32), (2, 2)), + QuantDepthwiseConv2d((1, 64, 64, 32), (1, 3, 3, 32), (32,), (1, 64, 64, 32), (1, 1)), + QuantConv2d((1, 64, 64, 32), (64, 1, 1, 32), (64,), (1, 64, 64, 64), (1, 1)), + QuantDepthwiseConv2d((1, 64, 64, 64), (1, 3, 3, 64), (64,), (1, 32, 32, 64), (2, 2)), + QuantConv2d((1, 32, 32, 64), (128, 1, 1, 64), (128,), (1, 32, 32, 128), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 128), (1, 3, 3, 128), (128,), (1, 32, 32, 128), (1, 1)), + QuantConv2d((1, 32, 32, 128), (128, 1, 1, 128), (128,), (1, 32, 32, 128), (1, 1)), + QuantDepthwiseConv2d((1, 32, 32, 128), (1, 3, 3, 128), (128,), (1, 16, 16, 128), (2, 2)), + QuantConv2d((1, 16, 16, 128), (256, 1, 1, 128), (256,), (1, 16, 16, 256), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 256), (1, 3, 3, 256), (256,), (1, 16, 16, 256), (1, 1)), + QuantConv2d((1, 16, 16, 256), (256, 1, 1, 256), (256,), (1, 16, 16, 256), (1, 1)), + QuantDepthwiseConv2d((1, 16, 16, 256), (1, 3, 3, 256), (256,), (1, 8, 8, 256), (2, 2)), + QuantConv2d((1, 8, 8, 256), (512, 1, 1, 256), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantConv2d((1, 8, 8, 512), (512, 1, 1, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantConv2d((1, 8, 8, 512), (512, 1, 1, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantConv2d((1, 8, 8, 512), (512, 1, 1, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantConv2d((1, 8, 8, 512), (512, 1, 1, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantConv2d((1, 8, 8, 512), (512, 1, 1, 512), (512,), (1, 8, 8, 512), (1, 1)), + QuantDepthwiseConv2d((1, 8, 8, 512), (1, 3, 3, 512), (512,), (1, 4, 4, 512), (2, 2)), + QuantConv2d((1, 4, 4, 512), (1024, 1, 1, 512), (1024,), (1, 4, 4, 1024), (1, 1)), + QuantDepthwiseConv2d((1, 4, 4, 1024), (1, 3, 3, 1024), (1024,), (1, 4, 4, 1024), (1, 1)), + QuantConv2d((1, 4, 4, 1024), (1024, 1, 1, 1024), (1024,), (1, 4, 4, 1024), (1, 1)), + QuantAveragePool2d((1, 4, 4, 1024), (1, 1, 1, 1024), (4, 4)), + QuantConv2d((1, 1, 1, 1024), (1001, 1, 1, 1024), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_1.0_160': + layers = [ + QuantConv2d((1, 160, 160, 3), (32, 3, 3, 3), (32,), (1, 80, 80, 32), (2, 2)), + QuantDepthwiseConv2d((1, 80, 80, 32), (1, 3, 3, 32), (32,), (1, 80, 80, 32), (1, 1)), + QuantConv2d((1, 80, 80, 32), (64, 1, 1, 32), (64,), (1, 80, 80, 64), (1, 1)), + QuantDepthwiseConv2d((1, 80, 80, 64), (1, 3, 3, 64), (64,), (1, 40, 40, 64), (2, 2)), + QuantConv2d((1, 40, 40, 64), (128, 1, 1, 64), (128,), (1, 40, 40, 128), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 128), (1, 3, 3, 128), (128,), (1, 40, 40, 128), (1, 1)), + QuantConv2d((1, 40, 40, 128), (128, 1, 1, 128), (128,), (1, 40, 40, 128), (1, 1)), + QuantDepthwiseConv2d((1, 40, 40, 128), (1, 3, 3, 128), (128,), (1, 20, 20, 128), (2, 2)), + QuantConv2d((1, 20, 20, 128), (256, 1, 1, 128), (256,), (1, 20, 20, 256), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 256), (1, 3, 3, 256), (256,), (1, 20, 20, 256), (1, 1)), + QuantConv2d((1, 20, 20, 256), (256, 1, 1, 256), (256,), (1, 20, 20, 256), (1, 1)), + QuantDepthwiseConv2d((1, 20, 20, 256), (1, 3, 3, 256), (256,), (1, 10, 10, 256), (2, 2)), + QuantConv2d((1, 10, 10, 256), (512, 1, 1, 256), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantConv2d((1, 10, 10, 512), (512, 1, 1, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantConv2d((1, 10, 10, 512), (512, 1, 1, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantConv2d((1, 10, 10, 512), (512, 1, 1, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantConv2d((1, 10, 10, 512), (512, 1, 1, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantConv2d((1, 10, 10, 512), (512, 1, 1, 512), (512,), (1, 10, 10, 512), (1, 1)), + QuantDepthwiseConv2d((1, 10, 10, 512), (1, 3, 3, 512), (512,), (1, 5, 5, 512), (2, 2)), + QuantConv2d((1, 5, 5, 512), (1024, 1, 1, 512), (1024,), (1, 5, 5, 1024), (1, 1)), + QuantDepthwiseConv2d((1, 5, 5, 1024), (1, 3, 3, 1024), (1024,), (1, 5, 5, 1024), (1, 1)), + QuantConv2d((1, 5, 5, 1024), (1024, 1, 1, 1024), (1024,), (1, 5, 5, 1024), (1, 1)), + QuantAveragePool2d((1, 5, 5, 1024), (1, 1, 1, 1024), (5, 5)), + QuantConv2d((1, 1, 1, 1024), (1001, 1, 1, 1024), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_1.0_192': + layers = [ + QuantConv2d((1, 192, 192, 3), (32, 3, 3, 3), (32,), (1, 96, 96, 32), (2, 2)), + QuantDepthwiseConv2d((1, 96, 96, 32), (1, 3, 3, 32), (32,), (1, 96, 96, 32), (1, 1)), + QuantConv2d((1, 96, 96, 32), (64, 1, 1, 32), (64,), (1, 96, 96, 64), (1, 1)), + QuantDepthwiseConv2d((1, 96, 96, 64), (1, 3, 3, 64), (64,), (1, 48, 48, 64), (2, 2)), + QuantConv2d((1, 48, 48, 64), (128, 1, 1, 64), (128,), (1, 48, 48, 128), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 128), (1, 3, 3, 128), (128,), (1, 48, 48, 128), (1, 1)), + QuantConv2d((1, 48, 48, 128), (128, 1, 1, 128), (128,), (1, 48, 48, 128), (1, 1)), + QuantDepthwiseConv2d((1, 48, 48, 128), (1, 3, 3, 128), (128,), (1, 24, 24, 128), (2, 2)), + QuantConv2d((1, 24, 24, 128), (256, 1, 1, 128), (256,), (1, 24, 24, 256), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 256), (1, 3, 3, 256), (256,), (1, 24, 24, 256), (1, 1)), + QuantConv2d((1, 24, 24, 256), (256, 1, 1, 256), (256,), (1, 24, 24, 256), (1, 1)), + QuantDepthwiseConv2d((1, 24, 24, 256), (1, 3, 3, 256), (256,), (1, 12, 12, 256), (2, 2)), + QuantConv2d((1, 12, 12, 256), (512, 1, 1, 256), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantConv2d((1, 12, 12, 512), (512, 1, 1, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantConv2d((1, 12, 12, 512), (512, 1, 1, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantConv2d((1, 12, 12, 512), (512, 1, 1, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantConv2d((1, 12, 12, 512), (512, 1, 1, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantConv2d((1, 12, 12, 512), (512, 1, 1, 512), (512,), (1, 12, 12, 512), (1, 1)), + QuantDepthwiseConv2d((1, 12, 12, 512), (1, 3, 3, 512), (512,), (1, 6, 6, 512), (2, 2)), + QuantConv2d((1, 6, 6, 512), (1024, 1, 1, 512), (1024,), (1, 6, 6, 1024), (1, 1)), + QuantDepthwiseConv2d((1, 6, 6, 1024), (1, 3, 3, 1024), (1024,), (1, 6, 6, 1024), (1, 1)), + QuantConv2d((1, 6, 6, 1024), (1024, 1, 1, 1024), (1024,), (1, 6, 6, 1024), (1, 1)), + QuantAveragePool2d((1, 6, 6, 1024), (1, 1, 1, 1024), (6, 6)), + QuantConv2d((1, 1, 1, 1024), (1001, 1, 1, 1024), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] +if network == 'v1_1.0_224': + layers = [ + QuantConv2d((1, 224, 224, 3), (32, 3, 3, 3), (32,), (1, 112, 112, 32), (2, 2)), + QuantDepthwiseConv2d((1, 112, 112, 32), (1, 3, 3, 32), (32,), (1, 112, 112, 32), (1, 1)), + QuantConv2d((1, 112, 112, 32), (64, 1, 1, 32), (64,), (1, 112, 112, 64), (1, 1)), + QuantDepthwiseConv2d((1, 112, 112, 64), (1, 3, 3, 64), (64,), (1, 56, 56, 64), (2, 2)), + QuantConv2d((1, 56, 56, 64), (128, 1, 1, 64), (128,), (1, 56, 56, 128), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 128), (1, 3, 3, 128), (128,), (1, 56, 56, 128), (1, 1)), + QuantConv2d((1, 56, 56, 128), (128, 1, 1, 128), (128,), (1, 56, 56, 128), (1, 1)), + QuantDepthwiseConv2d((1, 56, 56, 128), (1, 3, 3, 128), (128,), (1, 28, 28, 128), (2, 2)), + QuantConv2d((1, 28, 28, 128), (256, 1, 1, 128), (256,), (1, 28, 28, 256), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 256), (1, 3, 3, 256), (256,), (1, 28, 28, 256), (1, 1)), + QuantConv2d((1, 28, 28, 256), (256, 1, 1, 256), (256,), (1, 28, 28, 256), (1, 1)), + QuantDepthwiseConv2d((1, 28, 28, 256), (1, 3, 3, 256), (256,), (1, 14, 14, 256), (2, 2)), + QuantConv2d((1, 14, 14, 256), (512, 1, 1, 256), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantConv2d((1, 14, 14, 512), (512, 1, 1, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantConv2d((1, 14, 14, 512), (512, 1, 1, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantConv2d((1, 14, 14, 512), (512, 1, 1, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantConv2d((1, 14, 14, 512), (512, 1, 1, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantConv2d((1, 14, 14, 512), (512, 1, 1, 512), (512,), (1, 14, 14, 512), (1, 1)), + QuantDepthwiseConv2d((1, 14, 14, 512), (1, 3, 3, 512), (512,), (1, 7, 7, 512), (2, 2)), + QuantConv2d((1, 7, 7, 512), (1024, 1, 1, 512), (1024,), (1, 7, 7, 1024), (1, 1)), + QuantDepthwiseConv2d((1, 7, 7, 1024), (1, 3, 3, 1024), (1024,), (1, 7, 7, 1024), (1, 1)), + QuantConv2d((1, 7, 7, 1024), (1024, 1, 1, 1024), (1024,), (1, 7, 7, 1024), (1, 1)), + QuantAveragePool2d((1, 7, 7, 1024), (1, 1, 1, 1024), (7, 7)), + QuantConv2d((1, 1, 1, 1024), (1001, 1, 1, 1024), (1001,), (1, 1, 1, 1001), (1, 1)), + QuantReshape((1, 1, 1, 1001), (2,), (1, 1001)), + QuantSoftmax((1, 1001), (1, 1001)) + ] + +QuantConvBase.init_temp(layers) + +for layer in layers: + layer.input_from(0) + +layers[0].X.input_from(1) + +opt = Optimizer() +opt.layers = layers +start_timer(1) +opt.forward(1) +stop_timer(1) diff --git a/Programs/Source/idash_predict.mpc b/Programs/Source/idash_predict.mpc new file mode 100644 index 00000000..637ec3e9 --- /dev/null +++ b/Programs/Source/idash_predict.mpc @@ -0,0 +1,41 @@ +import ml +import random + +program.use_trunc_pr = True +sfix.round_nearest = True + +sfix.set_precision(16, 31) +cfix.set_precision(16, 31) + +N = int(program.args[1]) +n_features = int(program.args[2]) + +program.allocated_mem['s'] = 1 + n_features + +b = sfix.load_mem(0) +W = sfix.load_mem(1, size=n_features) + +#sint.load_mem(100).reveal().print_reg() + +dense = ml.Dense(N, n_features, 1) +dense.b[0] = b +dense.W.assign_vector(W) + +print_ln('b=%s W[-1]=%s', dense.b[0].reveal(), + dense.W[n_features - 1][0][0].reveal()) + +@for_range_opt(n_features) +def _(i): + @for_range_opt(N) + def _(j): + dense.X[j][0][i] = sfix.get_input_from(0) + +dense.forward() + +print_str('predictions: ') + +@for_range(N) +def _(i): + pred = ml.sigmoid(dense.Y[i][0][0]) + print_str('%s', pred.reveal() >= 0.5) +print_ln() diff --git a/Programs/Source/idash_train.mpc b/Programs/Source/idash_train.mpc new file mode 100644 index 00000000..25bcd693 --- /dev/null +++ b/Programs/Source/idash_train.mpc @@ -0,0 +1,51 @@ +import ml +import random + +program.use_trunc_pr = True +sfix.round_nearest = True + +sfix.set_precision(16, 31) +cfix.set_precision(16, 31) +sfloat.vlen = sfix.f + +n_epochs = 200 + +n_normal = int(program.args[1]) +n_pos = int(program.args[2]) +n_features = int(program.args[3]) + +debug = 'debug' in program.args + +n_examples = n_normal + n_pos +N = max(n_normal, n_pos) * 2 + +X_normal = sfix.Matrix(n_normal, n_features) +X_pos = sfix.Matrix(n_pos, n_features) + +@for_range_opt(n_features) +def _(i): + @for_range_opt(n_normal) + def _(j): + X_normal[j][i] = sfix.get_input_from(0) + @for_range_opt(n_pos) + def _(j): + X_pos[j][i] = sfix.get_input_from(0) + +dense = ml.Dense(N, n_features, 1) +layers = [dense, ml.Output(N)] + +sgd = ml.SGD(layers, n_epochs, report_loss=debug) +sgd.reset([X_normal, X_pos]) +sgd.run() + +if debug: + @for_range(N) + def _(i): + print_ln('%s %s', layers[-1].Y[i].reveal(), + ml.sigmoid(layers[-1].X[i]).reveal()) + +layers[0].b[0].store_in_mem(0) +layers[0].W.get_vector().store_in_mem(1) + +print_ln('b=%s W[-1]=%s', layers[0].b[0].reveal(), + layers[0].W[n_features - 1][0][0].reveal()) diff --git a/Programs/Source/regression.mpc b/Programs/Source/regression.mpc new file mode 100644 index 00000000..992a5db0 --- /dev/null +++ b/Programs/Source/regression.mpc @@ -0,0 +1,195 @@ +import ml +import random +import re + +program.use_trunc_pr = True +ml.set_n_threads(8) + +debug = False + +if 'halfprec' in program.args: + print '8-bit precision after point' + sfix.set_precision(8, 31) + cfix.set_precision(8, 31) +else: + sfix.set_precision(16, 31) + cfix.set_precision(16, 31) + +sfloat.vlen = sfix.f + +if 'nearest' in program.args: + sfix.round_nearest = True + +n_examples = 227 +n_normal = 84 +n_features = 12634 + +if len(program.args) > 2: + if 'bc' in program.args: + print 'Compiling for BC-TCGA' + n_examples = 472 + n_normal = 49 + n_features = 17814 + +n_pos = n_examples - n_normal +n_epochs = 1 +if len(program.args) > 1: + n_epochs = int(program.args[1]) + +try: + ml.set_n_threads(int(program.args[2])) +except: + pass + +print 'Using %d threads' % ml.Layer.n_threads + +n_fold = 5 +test_share = 1. / n_fold +n_ex = [n_normal, n_pos] +n_tests = [int(test_share * x) for x in n_ex] +n_train = [x - y for x, y in zip(n_ex, n_tests)] + +weighted = 'weighted' in program.args + +if 'fast' in program.args: + N = min(n_train) * 2 +elif 'mini' in program.args: + N = 32 +elif weighted: + N = sum(n_train) +else: + N = max(n_train) * 2 + +n_test = sum(n_tests) + +indices = [regint.Array(n) for n in n_ex] +indices[0].assign(range(n_pos, n_pos + n_normal)) +indices[1].assign(range(n_pos)) + +test = regint.Array(n_test) + +if 'quant' in program.args: + dense = ml.QuantizedDense(N, n_features, 1) +else: + dense = ml.Dense(N, n_features, 1) + +layers = [dense, ml.Output(N, debug=debug)] + +Y = sfix.Array(n_examples) +X = sfix.Matrix(n_examples, n_features) +Y.input_from(0) + +@for_range_opt(n_features) +def _(i): + @for_range_opt(n_examples) + def _(j): + X[j][i] = sfix.get_input_from(0) + +print_ln('X[0][%s] = %s', n_features - 1, X[0][n_features - 1].reveal()) +print_ln('X[%s][0] = %s', n_examples - 1, X[n_examples - 1][0].reveal()) + +sgd = ml.SGD(layers, n_epochs, debug=debug, report_loss=True) + +if 'tol' in program.args: + sgd.tol = 0.001 + +for arg in program.args: + m = re.match('tol=(.*)', arg) + if m: + sgd.tol = float(m.group(1)) + print 'Stop with tolerance', sgd.tol + +sum_acc = cfix.MemValue(0) + +@for_range(100) +def _(i_run): + for idx in indices: + idx.shuffle() + + @for_range(n_fold) + def _(i_fold): + i_test = regint.MemValue(0) + Xs = [sfix.Matrix(n, n_features) for n in n_train] + training = [regint.Array(n) for n in n_train] + for label in 0, 1: + i_train = regint.MemValue(0) + @for_range(len(indices[label])) + def _(i): + @if_e(i / n_tests[label] == i_fold) + def _(): + test[i_test] = indices[label][i] + i_test.iadd(1) + @else_ + def _(): + j = indices[label][i] + training[label][i_train] = j + Xs[label][i_train] = X[j] + i_train.iadd(1) + + print_ln('training on %s', [list(x) for x in training]) + print_ln('testing on %s', list(test)) + + if 'static' in program.args or weighted: + sgd.reset() + if weighted: + if 'doublenormal' in program.args: + factor = 2 + else: + factor = 1 + layers[-1].set_weights([factor * n_train[1]] * n_train[0] + + [n_train[0]] * n_train[1]) + n_indices = n_train + n = n_train[0] + else: + n = N / 2 + assert 2 * n == N + n_indices = [n, n] + for label, idx in enumerate(training): + @for_range(n_indices[label]) + def _(i): + layers[0].X[i + label * n] = X[idx[i % len(idx)]] + layers[-1].Y[i + label * n] = label + else: + sgd.reset(Xs) + + sgd.run() + + match = lambda x, y: (y.v >> y.f).if_else(x > 0.5, x < 0.5) + + def get_acc(N, noise=False): + pos_acc = cfix.MemValue(0) + neg_acc = cfix.MemValue(0) + @for_range(N) + def _(i): + y = layers[-1].Y[i].reveal() + x = ml.sigmoid(layers[0].Y[i][0][0]).reveal() + m = match(x, y) + pos_acc.iadd(m * y) + neg_acc.iadd(m * (1 - y)) + if noise: + print_ln('%s %s %s', match(x, y), y, x) + return neg_acc.read(), pos_acc.read() + + print_ln('train_acc: %s', sum(get_acc(N)) / N) + + @for_range(n_test) + def _(i): + j = test[i] + layers[0].X[i] = X[j] + layers[-1].Y[i] = Y[j] + + sgd.forward(n_test) + print_ln('test_loss: %s', sgd.layers[-1].l.reveal()) + + accs = get_acc(n_test, True) + acc = sum(accs) + real_accs = [x / y for x, y in zip(accs, n_tests)] + real_acc = sum(real_accs) / 2 + print_ln('test_acc: %s (%s=%s/%s %s=%s/%s)', real_acc, + real_accs[0], accs[0], n_tests[0], real_accs[1], accs[1], + n_tests[1]) + sum_acc.iadd(real_acc) + #print_ln('test set: %s', test) + + print_ln('average test acc: %s (%s/%s)', + sum_acc / (n_fold * (i_run + 1)), sum_acc, (n_fold * (i_run + 1))) diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 80bf1c4e..66d53f38 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_BEAVER_HPP_ +#define PROTOCOLS_BEAVER_HPP_ + #include "Beaver.h" #include "Replicated.hpp" @@ -33,7 +36,7 @@ typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) (void) n; triples.push_back({{}}); auto& triple = triples.back(); - prep->get(DATA_TRIPLE, triple.data()); + triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); return 0; @@ -62,3 +65,5 @@ T Beaver::finalize_mul(int n) triple++; return tmp; } + +#endif diff --git a/Protocols/BrainPrep.h b/Protocols/BrainPrep.h index 167e5355..834c93a6 100644 --- a/Protocols/BrainPrep.h +++ b/Protocols/BrainPrep.h @@ -16,6 +16,7 @@ public: BrainPrep(SubProcessor* proc, DataPositions& usage) : MaliciousRingPrep(proc, usage) {} void buffer_triples(); + void buffer_inputs(int player); }; #endif /* PROTOCOLS_BRAINPREP_H_ */ diff --git a/Protocols/BrainPrep.hpp b/Protocols/BrainPrep.hpp index c2820b76..4a7c2ef9 100644 --- a/Protocols/BrainPrep.hpp +++ b/Protocols/BrainPrep.hpp @@ -150,3 +150,9 @@ void BrainPrep::buffer_triples() for (auto& x : triples) this->triples.push_back({{x[0], x[1], x[2]}}); } + +template +void BrainPrep::buffer_inputs(int player) +{ + this->buffer_inputs_as_usual(player, this->proc); +} diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index aa41c490..2a1ced00 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -25,7 +25,7 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef BrainPrep LivePrep; const static int N_MASK_BITS = clear::N_BITS + S; diff --git a/Protocols/CowGearOptions.cpp b/Protocols/CowGearOptions.cpp index de67840f..20533561 100644 --- a/Protocols/CowGearOptions.cpp +++ b/Protocols/CowGearOptions.cpp @@ -59,7 +59,7 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, exit(1); } if (covert_security > (1 << lowgear_security)) - insecure("LowGear security less than key generation security"); + insecure(", LowGear security less than key generation security"); } else lowgear_from_covert(); diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index 577680c4..bac48021 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -123,7 +123,7 @@ void CowGearPrep::buffer_triples() assert(not producer.triples.empty()); for (auto& triple : producer.triples) this->triples.push_back({{triple[0], triple[1], triple[2]}}); -#ifdef VERBOSE +#ifdef VERBOSE_HE cerr << "generated " << producer.triples.size() << " triples, now got " << this->triples.size() << endl; #endif @@ -133,7 +133,7 @@ template void CowGearPrep::buffer_inverses() { assert(this->proc != 0); - BufferPrep::buffer_inverses(this->proc->MC, this->proc->P); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); } template @@ -145,7 +145,7 @@ void CowGearPrep::buffer_inputs(int player) this->inputs.resize(this->proc->P.num_players()); for (auto& input : generator.inputs) this->inputs[player].push_back(input); -#ifdef VERBOSE +#ifdef VERBOSE_HE cerr << "generated " << generator.inputs.size() << " inputs, now got " << this->inputs[player].size() << endl; #endif diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 04ef6f8c..d520504f 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -12,6 +12,7 @@ using namespace std; #include "Networking/ServerSocket.h" #include "Protocols/Summer.h" #include "Protocols/MAC_Check_Base.h" +#include "Protocols/RandomPrep.h" #include "Tools/time-func.h" @@ -119,7 +120,7 @@ class MAC_Check_Z2k : public MAC_Check_ { protected: vector shares; - MascotPrep* prep; + RandomPrep* prep; W get_random_element(); @@ -134,7 +135,7 @@ public: MAC_Check_Z2k(const T& ai, Names& Nms, int thread_num); virtual void Check(const Player& P); void set_random_element(const W& random_element); - void set_prep(MascotPrep& prep); + void set_prep(RandomPrep& prep); virtual ~MAC_Check_Z2k() {}; }; diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index 04e52faa..e0f93048 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -287,7 +287,7 @@ void MAC_Check_Z2k::set_random_element(const W& random_element) { } template -void MAC_Check_Z2k::set_prep(MascotPrep& prep) +void MAC_Check_Z2k::set_prep(RandomPrep& prep) { this->prep = &prep; } @@ -515,12 +515,12 @@ void Direct_MAC_Check::POpen_Begin(vector& values,const vector& } template -void direct_add_openings(vector& values, const Player& P, vector& os) +void direct_add_openings(vector& values, const PlayerBase& P, vector& os) { for (unsigned int i=0; i(os[j]); + values[i].template add(os.at(j)); } template diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 71d6913a..3e638d76 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -30,7 +30,7 @@ void MAC_Check_Base::CheckFor(const typename T::open_type& value, vector opened; POpen(opened, shares, P); for (auto& check : opened) - if (check != value) + if (typename T::clear(check) != value) { cout << check << " != " << value << endl; throw Offline_Check_Error("CheckFor"); diff --git a/Protocols/MalRepRingPrep.h b/Protocols/MalRepRingPrep.h index 216ce34d..2970f353 100644 --- a/Protocols/MalRepRingPrep.h +++ b/Protocols/MalRepRingPrep.h @@ -19,6 +19,8 @@ public: void shuffle_buffer_triples(); void buffer_squares(); + + void buffer_inputs(int player); }; // extra class to avoid recursion diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index babdb2d4..0eb26fba 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -6,6 +6,7 @@ #include "MalRepRingPrep.h" #include "MaliciousRepPrep.h" #include "MalRepRingOptions.h" +#include "ShuffleSacrifice.h" #include "Processor/OnlineOptions.h" template @@ -68,16 +69,26 @@ template void shuffle_triple_generation(vector>& triples, Player& P, typename T::MAC_Check& MC, int n_bits = -1) { - int N = max(1 << 20, OnlineOptions::singleton.batch_size); - int B = 3; - int C = 3; - int buffer_size = B * N + C; + ShuffleSacrifice sacrifice; vector> check_triples; + int buffer_size = sacrifice.minimum_n_inputs(OnlineOptions::singleton.batch_size); // optimistic triple generation Replicated protocol(P); generate_triples(check_triples, buffer_size, &protocol, n_bits); + sacrifice.triple_sacrifice(triples, check_triples, P, MC); +} + +template +void ShuffleSacrifice::triple_sacrifice(vector>& triples, + vector>& check_triples, Player& P, + typename T::MAC_Check& MC) +{ + int buffer_size = check_triples.size(); + assert(buffer_size >= minimum_n_inputs()); + int N = (buffer_size - C) / B; + // shuffle GlobalPRNG G(P); for (int i = 0; i < buffer_size; i++) @@ -98,7 +109,7 @@ void shuffle_triple_generation(vector>& triples, Player& P, vector opened; MC.POpen(opened, shares, P); for (int i = 0; i < C; i++) - if (typename T::open_type(opened[3 * i] * opened[3 * i + 1]) != opened[3 * i + 2]) + if (typename T::clear(opened[3 * i] * opened[3 * i + 1]) != opened[3 * i + 2]) throw Offline_Check_Error("shuffle opening"); // sacrifice buckets @@ -130,7 +141,7 @@ void shuffle_triple_generation(vector>& triples, Player& P, T& h = check_triples[i + N * j][2]; typename T::open_type& rho = *(it++); typename T::open_type& sigma = *(it++); - checks.push_back(c - h - rho * b - sigma * f); + checks.push_back(c - h - b * rho - f * sigma); } } MC.CheckFor(0, checks, P); @@ -151,3 +162,9 @@ void MalRepRingPrepWithBits::buffer_bits() prep.set_proc(&bit_proc); bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep); } + +template +void MalRepRingPrep::buffer_inputs(int player) +{ + this->buffer_inputs_as_usual(player, this->proc); +} diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index b1dedc52..f67ae04e 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -25,7 +25,7 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 3766f7a8..ce6da327 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -22,7 +22,7 @@ public: typedef HashMaliciousRepMC> MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; - typedef ReplicatedPrivateOutput> PrivateOutput; + typedef ::PrivateOutput> PrivateOutput; typedef Rep3Share Honest; typedef MaliciousRepPrep LivePrep; typedef MaliciousRep3Share prep_type; diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 122c2179..32d8af01 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPMC_HPP_ +#define PROTOCOLS_MALICIOUSREPMC_HPP_ + #include "MaliciousRepMC.h" #include "GC/Machine.h" #include "Math/BitVec.h" @@ -148,3 +151,5 @@ void CommMaliciousRepMC::Check(const Player& P) { (void)P; } + +#endif diff --git a/Protocols/MaliciousRepPrep.h b/Protocols/MaliciousRepPrep.h index 6e51cf8a..12ec9d03 100644 --- a/Protocols/MaliciousRepPrep.h +++ b/Protocols/MaliciousRepPrep.h @@ -30,6 +30,7 @@ class MaliciousRepPrep : public BufferPrep ReplicatedPrep honest_prep; typename T::Honest::Protocol* replicated; typename T::MAC_Check MC; + SubProcessor* proc; vector masked; vector checks; @@ -44,6 +45,7 @@ class MaliciousRepPrep : public BufferPrep void buffer_squares(); void buffer_inverses(); void buffer_bits(); + void buffer_inputs(int player); public: MaliciousRepPrep(SubProcessor* proc, DataPositions& usage); diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index a36ca398..4dbe345b 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -11,13 +11,13 @@ template MaliciousRepPrep::MaliciousRepPrep(SubProcessor* proc, DataPositions& usage) : MaliciousRepPrep(usage) { - (void) proc; + this->proc = proc; } template MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage) : BufferPrep(usage), honest_usage(usage.num_players()), - honest_prep(0, honest_usage), replicated(0) + honest_prep(0, honest_usage), replicated(0), proc(0) { } @@ -150,7 +150,7 @@ void MaliciousRepPrep::buffer_squares() template void MaliciousRepPrep::buffer_inverses() { - BufferPrep::buffer_inverses(MC, honest_prep.protocol->P); + ::buffer_inverses(this->inverses, *this, MC, honest_prep.protocol->P); } template @@ -188,3 +188,9 @@ void MaliciousRepPrep::buffer_bits() } MC.CheckFor(0, checks, P); } + +template +void MaliciousRepPrep::buffer_inputs(int player) +{ + this->buffer_inputs_as_usual(player, proc); +} diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index fb64a21b..0699633b 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -22,7 +22,7 @@ public: typedef MaliciousShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef ShamirShare Honest; typedef MaliciousRepPrep LivePrep; typedef T random_type; diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 92542fff..721b1f61 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -7,7 +7,8 @@ #define PROTOCOLS_MASCOTPREP_H_ #include "ReplicatedPrep.h" -#include "OT/NPartyTripleGenerator.h" +#include "RandomPrep.h" +#include "OT/TripleMachine.h" template class OTPrep : public virtual RingPrep @@ -27,7 +28,7 @@ public: }; template -class MascotPrep : public OTPrep +class MascotPrep : public OTPrep, public RandomPrep { public: MascotPrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index fcbf9b2d..3ddd06ed 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -3,11 +3,15 @@ * */ +#ifndef PROTOCOLS_MASCOTPREP_HPP_ +#define PROTOCOLS_MASCOTPREP_HPP_ + #include "MascotPrep.h" #include "Processor/Processor.h" #include "Processor/BaseMachine.h" #include "OT/OTTripleSetup.h" #include "OT/Triple.hpp" +#include "OT/NPartyTripleGenerator.hpp" template OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : @@ -58,7 +62,7 @@ template void MascotFieldPrep::buffer_inverses() { assert(this->proc != 0); - BufferPrep::buffer_inverses(this->proc->MC, this->proc->P); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); } template @@ -89,12 +93,18 @@ template T MascotPrep::get_random() { assert(this->proc); + return BufferPrep::get_random_from_inputs(this->proc->P.num_players()); +} + +template +T BufferPrep::get_random_from_inputs(int nplayers) +{ T res; - for (int j = 0; j < this->proc->P.num_players(); j++) + for (int j = 0; j < nplayers; j++) { T tmp; typename T::open_type _; - this->get_input(tmp, _, j); + this->get_input_no_count(tmp, _, j); res += tmp; } return res; @@ -108,3 +118,5 @@ size_t OTPrep::data_sent() else return 0; } + +#endif diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index c80476e6..c002cfcd 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -23,7 +23,7 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef MaliciousRepPrep LivePrep; PostSacriRepFieldShare() diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index ec10c20c..0d32cb0a 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -28,7 +28,7 @@ public: typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; static string type_short() diff --git a/Protocols/RandomPrep.h b/Protocols/RandomPrep.h new file mode 100644 index 00000000..ab13bf30 --- /dev/null +++ b/Protocols/RandomPrep.h @@ -0,0 +1,18 @@ +/* + * RandomPrep.h + * + */ + +#ifndef PROTOCOLS_RANDOMPREP_H_ +#define PROTOCOLS_RANDOMPREP_H_ + +template +class RandomPrep +{ +public: + virtual ~RandomPrep() {} + + virtual T get_random() = 0; +}; + +#endif /* PROTOCOLS_RANDOMPREP_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 140bac4e..05b07147 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -11,6 +11,7 @@ #include "Protocols/Replicated.h" template class ReplicatedRingPrep; +template class PrivateOutput; template class Rep3Share : public FixedVec @@ -25,7 +26,7 @@ public: typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef ReplicatedRingPrep LivePrep; typedef Rep3Share Honest; diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index ddf18300..8612aa9d 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -21,6 +21,7 @@ template class ReplicatedPrivateOutput; template class Share; template class Rep3Share; template class MAC_Check_Base; +template class Preprocessing; class ReplicatedBase { @@ -58,6 +59,9 @@ public: void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } void next_dotprod() {} T finalize_dotprod(int length); + + virtual void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { (void) regs, (void) size; (void) proc; throw not_implemented(); } }; template @@ -70,7 +74,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase public: typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; Replicated(Player& P); @@ -83,6 +86,8 @@ public: } void init_mul(SubProcessor* proc); + void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); + void init_mul(); typename T::clear prepare_mul(const T& x, const T& y, int n = -1); void exchange(); @@ -95,6 +100,8 @@ public: void next_dotprod(); T finalize_dotprod(int length); + void trunc_pr(const vector& regs, int size, SubProcessor& proc); + T get_random(); }; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index d532ac91..aeafc670 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -3,10 +3,21 @@ * */ +#ifndef PROTOCOLS_REPLICATED_HPP_ +#define PROTOCOLS_REPLICATED_HPP_ + #include "Replicated.h" #include "Processor/Processor.h" #include "Tools/benchmarking.h" +#include "SemiShare.h" +#include "SemiMC.h" +#include "ReplicatedInput.h" +#include "Rep3Share.h" + +#include "SemiMC.hpp" +#include "Math/Z2k.hpp" + template ProtocolBase::ProtocolBase() : counter(0) { @@ -73,6 +84,13 @@ void Replicated::init_mul(SubProcessor* proc) init_mul(); } +template +void Replicated::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) +{ + (void) prep, (void) MC; + init_mul(); +} + template void Replicated::init_mul() { @@ -156,3 +174,125 @@ T Replicated::get_random() res[i].randomize(shared_prngs[i]); return res; } + +template +void trunc_pr(const vector& regs, int size, + SubProcessor>>& proc) +{ + assert(regs.size() % 4 == 0); + assert(proc.P.num_players() == 3); + typedef SignedZ2 value_type; + typedef Rep3Share T; + bool generate = proc.P.my_num() == 2; + if (generate) + { + octetStream os[2]; + for (size_t i = 0; i < regs.size(); i += 4) + { + int k = regs[i + 2]; + int m = regs[i + 3]; + int n_shift = value_type::N_BITS - 1 - k; + assert(m < k); + assert(0 < k); + assert(m < value_type::N_BITS); + for (int l = 0; l < size; l++) + { + auto& res = proc.get_S_ref(regs[i] + l); + auto& G = proc.Proc.secure_prng; + auto mask = G.template get(); + auto unmask = (mask << (n_shift + 1)) >> (n_shift + m + 1); + T shares[4]; + shares[0].randomize_to_sum(mask, G); + shares[1].randomize_to_sum(unmask, G); + shares[2].randomize_to_sum( + (mask << (n_shift)) >> (value_type::N_BITS - 1), G); + res.randomize(G); + shares[3] = res; + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 4; j++) + shares[j][i].pack(os[i]); + } + } + } + for (int i = 0; i < 2; i++) + proc.P.send_to(i, os[i], true); + } + else + { + octetStream os; + proc.P.receive_player(2, os, true); + OffsetPlayer player(proc.P, 1 - 2 * proc.P.my_num()); + typedef SemiShare semi_type; + vector> to_open; + PointerVector> mask_shares[3]; + for (size_t i = 0; i < regs.size(); i += 4) + for (int l = 0; l < size; l++) + { + SemiShare share; + auto& x = proc.get_S_ref(regs[i + 1] + l); + if (proc.P.my_num() == 0) + share = x.sum(); + else + share = x[0]; + for (auto& mask_share : mask_shares) + mask_share.push_back(os.get()); + to_open.push_back(share + mask_shares[0].next()); + auto& res = proc.get_S_ref(regs[i] + l); + auto& a = res[1 - proc.P.my_num()]; + a.unpack(os); + } + PointerVector opened; + DirectSemiMC> MC; + MC.POpen_(opened, to_open, player); + os.reset_write_head(); + for (size_t i = 0; i < regs.size(); i += 4) + { + int k = regs[i + 2]; + int m = regs[i + 3]; + int n_shift = value_type::N_BITS - 1 - k; + assert(m < k); + assert(0 < k); + assert(m < value_type::N_BITS); + for (int l = 0; l < size; l++) + { + auto& res = proc.get_S_ref(regs[i] + l); + auto masked = opened.next() << n_shift; + auto shifted = (masked << 1) >> (n_shift + m + 1); + auto diff = SemiShare::constant(shifted, + player.my_num()) - mask_shares[1].next(); + auto msb = masked >> (value_type::N_BITS - 1); + auto bit_mask = mask_shares[2].next(); + auto overflow = (bit_mask + + SemiShare::constant(msb, player.my_num()) + - bit_mask * msb * 2); + auto res_share = diff + (overflow << (k - m)); + auto& a = res[1 - proc.P.my_num()]; + auto& b = res[proc.P.my_num()]; + b = res_share - a; + b.pack(os); + } + } + player.exchange(os); + for (size_t i = 0; i < regs.size(); i += 4) + for (int l = 0; l < size; l++) + proc.get_S_ref(regs[i] + l)[proc.P.my_num()] += + os.get(); + } +} + +template +void trunc_pr(const vector& regs, int size, SubProcessor& proc) +{ + (void) regs, (void) size, (void) proc; + throw not_implemented(); +} + +template +void Replicated::trunc_pr(const vector& regs, int size, + SubProcessor& proc) +{ + ::trunc_pr(regs, size, proc); +} + +#endif diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index fa319531..5d1d1e78 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -27,10 +27,12 @@ public: void stop(int player, vector targets); virtual void reset(int player) = 0; - virtual void add_mine(const typename T::open_type& input) = 0; + virtual void add_mine(const typename T::open_type& input, + int n_bits = -1) = 0; virtual void add_other(int player) = 0; virtual void send_mine() = 0; - virtual void finalize_other(int player, T& target, octetStream& o) = 0; + virtual void finalize_other(int player, T& target, octetStream& o, + int n_bits = -1) = 0; T finalize_mine(); }; @@ -54,6 +56,15 @@ public: { (void) MC; } + ReplicatedInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : + ReplicatedInput(P) + { + (void) MC, (void) prep; + } + ReplicatedInput(Player& P) : + ReplicatedInput(0, P) + { + } ReplicatedInput(SubProcessor* proc, Player& P) : PrepLessInput(proc), proc(proc), P(P), protocol(P) { @@ -61,11 +72,11 @@ public: } void reset(int player); - void add_mine(const typename T::open_type& input); + void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player); void send_mine(); void exchange(); - void finalize_other(int player, T& target, octetStream& o); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_REPLICATEDINPUT_H_ */ diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 90d71812..fe033dd9 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_REPLICATEDINPUT_HPP_ +#define PROTOCOLS_REPLICATEDINPUT_HPP_ + #include "ReplicatedInput.h" #include "Processor/Processor.h" @@ -21,14 +24,14 @@ void ReplicatedInput::reset(int player) } template -inline void ReplicatedInput::add_mine(const typename T::open_type& input) +inline void ReplicatedInput::add_mine(const typename T::open_type& input, int n_bits) { auto& shares = this->shares; shares.push_back({}); T& my_share = shares.back(); - my_share[0].randomize(protocol.shared_prngs[0]); + my_share[0].randomize(protocol.shared_prngs[0], n_bits); my_share[1] = input - my_share[0]; - my_share[1].pack(os[1]); + my_share[1].pack(os[1], n_bits); this->values_input++; } @@ -96,19 +99,19 @@ void PrepLessInput::stop(int player, vector targets) template inline void ReplicatedInput::finalize_other(int player, T& target, - octetStream& o) + octetStream& o, int n_bits) { if (P.get_offset(player) == 1) { typename T::value_type t; - t.unpack(o); + t.unpack(o, n_bits); target[0] = t; target[1] = 0; } else { target[0] = 0; - target[1].randomize(protocol.shared_prngs[1]); + target[1].randomize(protocol.shared_prngs[1], n_bits); } } @@ -117,3 +120,5 @@ T PrepLessInput::finalize_mine() { return this->shares[this->i_share++]; } + +#endif diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 5b999aeb..1f3c781d 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_REPLICATEDMC_HPP_ +#define PROTOCOLS_REPLICATEDMC_HPP_ + #include "ReplicatedMC.h" template @@ -31,3 +34,5 @@ void ReplicatedMC::POpen_End(vector& values, values[i] = S[i].sum() + tmp; } } + +#endif diff --git a/Protocols/ReplicatedMachine.hpp b/Protocols/ReplicatedMachine.hpp index c454b6ef..00838c31 100644 --- a/Protocols/ReplicatedMachine.hpp +++ b/Protocols/ReplicatedMachine.hpp @@ -19,7 +19,7 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, OnlineOptions online_opts(opt, argc, argv); OnlineOptions::singleton = online_opts; - NetworkOptions network_opts(opt, argc, argv); + NetworkOptionsWithNumber network_opts(opt, argc, argv, nplayers, false); opt.add( "", // Default. 0, // Required? @@ -34,16 +34,14 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, int playerno = online_opts.playerno; string progname = online_opts.progname; - int pnb = network_opts.portnum_base; - string hostname = network_opts.hostname; bool use_encryption = not opt.get("-u")->isSet; if (not use_encryption) insecure("unencrypted communication"); Names N; - Server* server = Server::start_networking(N, playerno, nplayers, hostname, pnb); + Server* server = network_opts.start_networking(N, playerno); - Machine(playerno, N, progname, "empty", + Machine(playerno, N, progname, online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, 0, use_encryption, online_opts.live_prep, online_opts).run(); diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index f7915292..1c4ab044 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -13,6 +13,10 @@ #include +template +void buffer_inverses(vector>& inverses, Preprocessing& prep, + MAC_Check_Base& MC, Player& P); + template class BufferPrep : public Preprocessing { @@ -31,7 +35,8 @@ protected: virtual void buffer_bits() = 0; virtual void buffer_inputs(int player); - virtual void buffer_inverses(MAC_Check_Base& MC, Player& P); + // don't call this if T::Input requires input tuples + void buffer_inputs_as_usual(int player, SubProcessor* proc); public: typedef T share_type; @@ -53,6 +58,8 @@ public: void get_input_no_count(T& a, typename T::open_type& x, int i); void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); + + T get_random_from_inputs(int nplayers); }; template @@ -91,6 +98,8 @@ public: virtual ~SemiHonestRingPrep() {} virtual void buffer_bits() { this->buffer_bits_without_check(); } + virtual void buffer_inputs(int player) + { this->buffer_inputs_as_usual(player, this->proc); } }; template diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 7e8e8a93..5203690c 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -3,8 +3,12 @@ * */ +#ifndef PROTOCOlS_REPLICATEDPREP_HPP_ +#define PROTOCOlS_REPLICATEDPREP_HPP_ + #include "ReplicatedPrep.h" #include "Math/gfp.h" +#include "Processor/OnlineOptions.h" template BufferPrep::BufferPrep(DataPositions& usage) : @@ -136,17 +140,20 @@ void ReplicatedPrep::buffer_inverses() auto protocol = this->protocol; assert(protocol != 0); typename T::MAC_Check MC; - BufferPrep::buffer_inverses(MC, protocol->P); + ::buffer_inverses(this->inverses, *this, MC, protocol->P); } template -void BufferPrep::buffer_inverses(MAC_Check_Base& MC, Player& P) +void buffer_inverses(vector>& inverses, Preprocessing& prep, + MAC_Check_Base& MC, Player& P) { + int buffer_size = OnlineOptions::singleton.batch_size; vector> triples(buffer_size); vector c; for (int i = 0; i < buffer_size; i++) { - get_three_no_count(DATA_TRIPLE, triples[i][0], triples[i][1], triples[i][2]); + prep.get_three_no_count(DATA_TRIPLE, triples[i][0], triples[i][1], + triples[i][2]); c.push_back(triples[i][2]); } vector c_open; @@ -369,6 +376,39 @@ inline void BufferPrep::buffer_inputs(int player) throw not_implemented(); } +template +void BufferPrep::buffer_inputs_as_usual(int player, SubProcessor* proc) +{ + assert(proc != 0); + auto& P = proc->P; + this->inputs.resize(P.num_players()); + typename T::Input input(proc, P); + input.reset(player); + auto buffer_size = OnlineOptions::singleton.batch_size; + if (P.my_num() == player) + { + for (int i = 0; i < buffer_size; i++) + { + typename T::clear r; + r.randomize(proc->Proc.secure_prng); + input.add_mine(r); + this->inputs[player].push_back({input.finalize_mine(), r}); + } + input.send_mine(); + } + else + { + octetStream os; + P.receive_player(player, os, true); + T share; + for (int i = 0; i < buffer_size; i++) + { + input.finalize_other(player, share, os); + this->inputs[player].push_back({share, 0}); + } + } +} + template void BufferPrep::get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size) @@ -376,3 +416,5 @@ void BufferPrep::get_no_count(vector& S, DataTag tag, (void) S, (void) tag, (void) regs, (void) vector_size; throw not_implemented(); } + +#endif diff --git a/Protocols/ReplicatedPrivateOutput.hpp b/Protocols/ReplicatedPrivateOutput.hpp index 55eb851c..d3487223 100644 --- a/Protocols/ReplicatedPrivateOutput.hpp +++ b/Protocols/ReplicatedPrivateOutput.hpp @@ -20,7 +20,7 @@ void ReplicatedPrivateOutput::start(int player, int target, int source) { (void)player, (void)target, (void)source; - throw not_implemented(); + throw runtime_error("not implemented, use PrivateOutput"); } template diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 77e83c98..0cc348d9 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -27,7 +27,18 @@ public: { } - void add_mine(const typename T::clear& input); + SemiInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : + SemiInput(P) + { + (void) MC, (void) prep; + } + + SemiInput(Player& P) : + IndividualInput(0, P) + { + } + + void add_mine(const typename T::clear& input, int n_bits = -1); }; #endif /* PROTOCOLS_SEMIINPUT_H_ */ diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index 11d9c1ad..28673250 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -3,25 +3,30 @@ * */ +#ifndef PROTOCOLS_SEMIINPUT_HPP_ +#define PROTOCOLS_SEMIINPUT_HPP_ + #include "SemiInput.h" #include "ShamirInput.hpp" template -void SemiInput::add_mine(const typename T::clear& input) +void SemiInput::add_mine(const typename T::clear& input, int n_bits) { auto& P = this->P; typename T::open_type sum, share; for (int i = 0; i < P.num_players(); i++) { if (i < P.num_players() - 1) - share.randomize(secure_prng); + share.randomize(secure_prng, n_bits); else share = input - sum; sum += share; if (i == P.my_num()) this->shares.push_back(share); else - share.pack(this->os[i]); + share.pack(this->os[i], n_bits); } } + +#endif diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 130dea35..97fcc9c6 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -27,11 +27,14 @@ template class DirectSemiMC : public SemiMC { public: + DirectSemiMC() {} // emulate Direct_MAC_Check DirectSemiMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } - void POpen_Begin(vector& values,const vector& S,const Player& P); + void POpen_(vector& values,const vector& S,const PlayerBase& P); + void POpen_Begin(vector& values,const vector& S,const Player& P) + { POpen_(values, S, P); } void POpen_End(vector& values,const vector& S,const Player& P); void Check(const Player& P) { (void)P; } diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index e7f87a96..386b68da 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_SEMIMC_HPP_ +#define PROTOCOLS_SEMIMC_HPP_ + #include "SemiMC.h" #include "MAC_Check.hpp" @@ -26,8 +29,8 @@ void SemiMC::POpen_End(vector& values, } template -void DirectSemiMC::POpen_Begin(vector& values, - const vector& S, const Player& P) +void DirectSemiMC::POpen_(vector& values, + const vector& S, const PlayerBase& P) { values.clear(); values.insert(values.begin(), S.begin(), S.end()); @@ -44,3 +47,5 @@ void DirectSemiMC::POpen_End(vector& values, { (void) values, (void) S, (void) P; } + +#endif diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index 82026b2c..9b9fac16 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -29,5 +29,5 @@ template void SemiPrep::buffer_inverses() { assert(this->proc != 0); - BufferPrep::buffer_inverses(this->proc->MC, this->proc->P); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); } diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index a61ae1fc..2a42f636 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -52,7 +52,7 @@ public: static string type_short() { return "D" + string(1, T::type_char()); } - static SemiShare constant(const clear& other, int my_num, const T& alphai) + static SemiShare constant(const clear& other, int my_num, const T& alphai = {}) { return SemiShare(other, my_num, alphai); } @@ -115,6 +115,15 @@ public: (void)full; super::unpack(os); } + + void pack(octetStream& os, int n_bits) const + { + super::pack(os, n_bits); + } + void unpack(octetStream& os, int n_bits) + { + super::unpack(os, n_bits); + } }; #endif /* PROTOCOLS_SEMISHARE_H_ */ diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 1726d2ad..541d6ecf 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -30,7 +30,7 @@ public: void reset(int player); void add_other(int player); void send_mine(); - void finalize_other(int player, T& target, octetStream& o); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; template @@ -55,7 +55,7 @@ public: { } - void add_mine(const typename T::clear& input); + void add_mine(const typename T::clear& input, int n_bits = -1); }; #endif /* PROTOCOLS_SHAMIRINPUT_H_ */ diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 40d88826..32ab32b0 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_SHAMIRINPUT_HPP_ +#define PROTOCOLS_SHAMIRINPUT_HPP_ + #include "ShamirInput.h" #include "Machines/ShamirMachine.h" @@ -19,8 +22,9 @@ void IndividualInput::reset(int player) } template -void ShamirInput::add_mine(const typename T::clear& input) +void ShamirInput::add_mine(const typename T::clear& input, int n_bits) { + (void) n_bits; auto& P = this->P; int n = P.num_players(); int t = ShamirMachine::s().threshold; @@ -69,8 +73,11 @@ void IndividualInput::send_mine() } template -void IndividualInput::finalize_other(int player, T& target, octetStream& o) +void IndividualInput::finalize_other(int player, T& target, octetStream& o, + int n_bits) { (void) player; - target.unpack(o); + target.unpack(o, n_bits); } + +#endif diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index f82324c4..9863d4a6 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -23,7 +23,7 @@ public: typedef ShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; - typedef ReplicatedPrivateOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; typedef ShamirShare Honest; diff --git a/Protocols/Share.h b/Protocols/Share.h index ffe41727..80394fce 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -20,7 +20,8 @@ template class MAC_Check_; template class Direct_MAC_Check; template class MascotMultiplier; template class MascotFieldPrep; -template class NPartyTripleGenerator; +template class MascotTripleGenerator; +template class MascotPrep; union square128; @@ -38,10 +39,11 @@ class Share typedef T clear; typedef Share prep_type; - typedef MascotMultiplier Multiplier; - typedef NPartyTripleGenerator TripleGenerator; + typedef MascotMultiplier Multiplier; + typedef MascotTripleGenerator TripleGenerator; typedef T sacri_type; typedef typename T::Square Rectangle; + typedef Rectangle Square; typedef MAC_Check_ MAC_Check; typedef Direct_MAC_Check Direct_MC; @@ -49,6 +51,7 @@ class Share typedef ::PrivateOutput PrivateOutput; typedef SPDZ Protocol; typedef MascotFieldPrep LivePrep; + typedef MascotPrep RandomPrep; const static bool needs_ot = true; const static bool dishonest_majority = true; diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h new file mode 100644 index 00000000..9a792e09 --- /dev/null +++ b/Protocols/ShuffleSacrifice.h @@ -0,0 +1,36 @@ +/* + * ShuffleSacrifice.h + * + */ + +#ifndef PROTOCOLS_SHUFFLESACRIFICE_H_ +#define PROTOCOLS_SHUFFLESACRIFICE_H_ + +#include +#include +using namespace std; + +class Player; + +template +class ShuffleSacrifice +{ + static const int B = 3; + static const int C = 3; + +public: + static int minimum_n_inputs(int n_outputs = 1) + { + return max(n_outputs, minimum_n_outputs()) * B + C; + } + static int minimum_n_outputs() + { + return 1 << 20; + } + + void triple_sacrifice(vector>& triples, + vector>& check_triples, Player& P, + typename T::MAC_Check& MC); +}; + +#endif /* PROTOCOLS_SHUFFLESACRIFICE_H_ */ diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 4bb29f25..df410fcf 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SPDZ2KPREP_H_ #include "MascotPrep.h" +#include "Spdz2kShare.h" template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index bb2cd0d7..66d41a99 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -94,7 +94,7 @@ void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) BitShare a, a2; bit_prep->get_two(DATA_SQUARE, a, a2); squares.push_back((a2 + a) * 4 + one); - random_shares.push_back(2 * a + one); + random_shares.push_back(a * 2 + one); } vector opened; bit_MC->POpen(opened, squares, bit_proc->P); diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index a9fb5c06..c94807f6 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -13,6 +13,7 @@ #include "OT/Rectangle.h" template class Spdz2kMultiplier; +template class Spdz2kTripleGenerator; template class Spdz2kShare : public Share> @@ -30,14 +31,14 @@ public: typedef Spdz2kShare prep_type; typedef Spdz2kMultiplier Multiplier; - typedef NPartyTripleGenerator TripleGenerator; + typedef Spdz2kTripleGenerator TripleGenerator; typedef Z2 sacri_type; typedef Z2kRectangle Rectangle; typedef MAC_Check_Z2k, Z2, open_type, Spdz2kShare> MAC_Check; typedef MAC_Check Direct_MC; typedef ::Input Input; - typedef NotImplementedOutput PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef SPDZ Protocol; typedef Spdz2kPrep LivePrep; diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 59efc1ca..4c363dc1 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -5,6 +5,8 @@ #include using namespace std; +#include "Networking/Player.h" + template void check_share(vector& Sa, typename T::clear& value, typename T::value_type& mac, int N, const typename T::value_type& key); @@ -26,6 +28,8 @@ void write_mac_keys(const string& directory, int player_num, int nplayers, U key template void read_mac_keys(const string& directory, int player_num, int nplayers, U& keyp, T& key2); +template +void read_mac_keys(const string& directory, const Names& N, U& keyp, T& key2); template class Files diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 30659d93..1dbbfff7 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -11,6 +11,11 @@ template class Share; template class SemiShare; template class FixedVec; +namespace GC +{ +template class TinySecret; +} + template void make_share(Share* Sa,const U& a,int N,const V& key,PRNG& G) { @@ -31,6 +36,21 @@ void make_share(Share* Sa,const U& a,int N,const V& key,PRNG& G) Sa[N-1]=S; } +template +void make_share(GC::TinySecret* Sa,const U& a,int N,const V& key,PRNG& G) +{ + int length = Sa[0].default_length; + for (int i = 0; i < N; i++) + Sa[i].resize_regs(length); + for (int j = 0; j < length; j++) + { + typename GC::TinySecret::part_type shares[N]; + make_share(shares, a.get_bit(j), N, key, G); + for (int i = 0; i < N; i++) + Sa[i].get_reg(j) = shares[i]; + } +} + template void make_share(SemiShare* Sa,const T& a,int N,const T& key,PRNG& G) { @@ -51,7 +71,7 @@ void make_share(FixedVec* Sa, const T& a, int N, const T& key, PRNG& G); template inline void make_share(vector& Sa, - const typename T::clear& a, int N, const typename T::mac_type& key, + const typename T::clear& a, int N, const typename T::mac_key_type& key, PRNG& G) { Sa.resize(N); @@ -148,11 +168,13 @@ inline void generate_keys(const string& directory, int nplayers) } } +template inline string mac_filename(string directory, int playerno) { if (directory.empty()) directory = "."; - return directory + "/Player-MAC-Keys-P" + to_string(playerno); + return directory + "/Player-MAC-Keys-" + string(1, T::type_char()) + "-P" + + to_string(playerno); } template @@ -160,7 +182,7 @@ void write_mac_keys(const string& directory, int i, int nplayers, U macp, T mac2 { ofstream outf; stringstream filename; - filename << mac_filename(directory, i); + filename << mac_filename(directory, i); cout << "Writing to " << filename.str().c_str() << endl; outf.open(filename.str().c_str()); outf << nplayers << endl; @@ -182,8 +204,11 @@ void read_mac_keys(const string& directory, int player_num, int nplayers, U& key { int nn; - string filename = directory + "Player-MAC-Keys-P" + to_string(player_num); + string filename = mac_filename(directory, player_num); ifstream inpf; +#ifdef VERBOSE + cerr << "Reading MAC keys from " << filename << endl; +#endif inpf.open(filename); if (inpf.fail()) { @@ -217,7 +242,9 @@ void generate_mac_keys(typename T::mac_key_type::Scalar& keyp, gf2n& key2, for (int i = 0; i < nplayers; i++) { stringstream filename; - filename << prep_data_prefix << "Player-MAC-Keys-P" << i; + filename + << mac_filename(prep_data_prefix, + i); inpf.open(filename.str().c_str()); typename T::mac_key_type::Scalar pp; gf2n p2; @@ -254,7 +281,7 @@ void generate_mac_keys(typename T::mac_key_type::Scalar& keyp, gf2n& key2, * str = "2" or "p" */ template -void make_mult_triples(const typename T::mac_type& key, int N, int ntrip, +void make_mult_triples(const typename T::mac_key_type& key, int N, int ntrip, bool zero, string prep_data_prefix, int thread_num = -1) { PRNG G; diff --git a/README.md b/README.md index f9866cf0..3ab0d11b 100644 --- a/README.md +++ b/README.md @@ -66,16 +66,15 @@ on how to activate them. #### Protocols -The following table lists all protocols that are fully supported. Rep3 -stands for three-party replicated secret sharing. +The following table lists all protocols that are fully supported. -| Security model | Mod prime / GF(2^n) | Mod 2^k | Binary | -| --- | --- | --- | --- | -| Malicious, dishonest majority | [MASCOT](#arithmetic-circuits) | [SPDZ2k](#arithmetic-circuits) | [BMR](#bmr) | -| Covert, dishonest majority | [CowGear](#arithmetic-circuits) | N/A | N/A | -| Semi-honest, dishonest majority | [Semi](#arithmetic-circuits) | [Semi2k](#arithmetic-circuits) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3](#honest-majority) / [BMR](#bmr) | -| Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3](#honest-majority) / [BMR](#bmr) | +| Security model | Mod prime / GF(2^n) | Mod 2^k | Bin. SS | Garbling | +| --- | --- | --- | --- | --- | +| Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny](#secret-sharing) | [BMR](#bmr) | +| Covert, dishonest majority | [CowGear](#secret-sharing) | N/A | N/A | N/A | +| Semi-honest, dishonest majority | [Semi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | +| Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | #### History @@ -126,7 +125,7 @@ run it with different protocols. The section on offline phases will then explain how to benchmark the offline phases required for the SPDZ protocol. Running the online phase outputs the amount of offline material required, which allows to -compute the preprocessing time for a particulor computation. +compute the preprocessing time for a particular computation. #### Requirements - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7) @@ -229,9 +228,9 @@ Some full implementations require oblivious transfer, which is implemented as OT extension based on https://github.com/mkskeller/SimpleOT. -### Arithmetic circuits +### Secret sharing -The following table shows all programs for arithmetic dishonest-majority computation: +The following table shows all programs for dishonest-majority computation using secret sharing: | Program | Protocol | Domain | Security | Script | | --- | --- | --- | --- | --- | @@ -240,12 +239,23 @@ The following table shows all programs for arithmetic dishonest-majority computa | `semi-party.x` | OT-based | Mod prime | Semi-honest | `semi.sh` | | `semi2k-party.x` | OT-based | Mod 2^k | Semi-honest | `semi2k.sh` | | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | +| `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | +| `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | Semi and Semi2k denote the result of stripping MASCOT/SPDZ2k of all steps required for malicious security, namely amplifying, sacrificing, MAC generation, and OT correlation checks. What remains is the generation of additively shared Beaver triples using OT. +Similarly, SemiBin denotes a protocol that generates bit-wise +multiplication triples using OT without any element of malicious +security. + +Tiny denotes the adaption of SPDZ2k to the binary setting. In +particular, the SPDZ2k sacrifice does not work for bits, so we replace +it by cut-and-choose according to [Furukawa et +al.](https://eprint.iacr.org/2016/944.pdf). + CowGear denotes a covertly secure version of LowGear. The reason for this is the key generation that only achieves covert security. It is possible however to run full LowGear for triple generation by using @@ -259,7 +269,7 @@ First compile the virtual machine: `make -j8 mascot-party.x` and a high-level program, for example the tutorial (use `-R 64` for -SPDZ2k and Semi2k): +SPDZ2k and Semi2k and `-B ` for SemiBin): `./compile.py -F 64 tutorial` diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index ccc26d78..b1b5e2eb 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -32,15 +32,18 @@ run_player() { $SPDZROOT/Server.x $players $port & fi rem=$(($players - 2)) + if test "$1"; then + log_prefix=$1- + fi for i in $(seq 0 $rem); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params - log=$SPDZROOT/logs/$i + log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | { if test $i = 0; then tee $log; else cat > $log; fi; } & done last_player=$(($players - 1)) >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params - $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$last_player 2>&1 || return 1 + $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 } sleep 0.5 diff --git a/Scripts/semi-bin.sh b/Scripts/semi-bin.sh new file mode 100755 index 00000000..d7c41e3c --- /dev/null +++ b/Scripts/semi-bin.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player semi-bin-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 86abc8b6..4853d7ff 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -6,7 +6,7 @@ done function test { - if ! Scripts/$1.sh tutorial | grep 'expected -0.2, got -0.2'; then + if ! Scripts/$1.sh tutorial | grep 'weighted average: 2.333'; then Scripts/$1.sh tutorial exit 1 fi @@ -26,6 +26,6 @@ done ./compile.py -B 16 tutorial -for i in replicated mal-rep-bin yao rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do +for i in replicated mal-rep-bin semi-bin yao tiny rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do test $i done diff --git a/Scripts/tiny.sh b/Scripts/tiny.sh new file mode 100755 index 00000000..4df971fa --- /dev/null +++ b/Scripts/tiny.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player tiny-party.x $* || exit 1 diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 0d13b62e..31b4e686 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -94,7 +94,7 @@ void BitVector::input(istream& s,bool human) if (s.tellg() == 0) { cout << "IO problem. Empty file?" << endl; - throw file_error(); + throw file_error("BitVector input"); } throw end_of_file(); } diff --git a/Tools/Buffer.h b/Tools/Buffer.h index ff8910d6..758bc3b6 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -101,7 +101,7 @@ inline void Buffer::fill_buffer() } else { - char read_buffer[sizeof(buffer)]; + char read_buffer[BUFFER_SIZE * T::size()]; read(read_buffer); //memset(buffer, 0, sizeof(buffer)); for (int i = 0; i < BUFFER_SIZE; i++) diff --git a/Tools/Bundle.h b/Tools/Bundle.h index 01176861..14f66808 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -15,7 +15,7 @@ class Bundle : public vector public: T& mine; - Bundle(const Player& P) : + Bundle(const PlayerBase& P) : vector(P.num_players()), mine(this->at(P.my_num())) { } diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index 27990d57..070979cc 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -9,6 +9,7 @@ #include "Math/gfp.h" #include "Math/bigint.h" #include "Math/Z2k.h" +#include "Math/BitVec.h" #include @@ -44,10 +45,10 @@ void MMO::encrypt_and_xor(void* output, const void* input, const octet* key, _mm_storeu_si128(((__m128i*)output) + indices[i], out[i]); } -template -void MMO::hashBlocks(void* output, const void* input, size_t alloc_size, - size_t used_size) +template +void MMO::hashBlocks(void* output, const void* input, size_t alloc_size) { + size_t used_size = N_BYTES; int n_blocks = DIV_CEIL(used_size, 16); if (n_blocks > N_KEYS) throw runtime_error("not enough MMO keys"); @@ -65,7 +66,7 @@ void MMO::hashBlocks(void* output, const void* input, size_t alloc_size, template void MMO::hashBlocks(void* output, const void* input) { - hashBlocks(output, input, sizeof(T), T::size()); + hashBlocks(output, input, sizeof(T)); for (int j = 0; j < N; j++) ((T*)output + j)->normalize(); } @@ -93,7 +94,7 @@ void MMO::hashBlocks(void* output, const void* input) if (gfp1::get_ZpD().get_t() < 2) throw not_implemented(); gfp1* out = (gfp1*)output; - hashBlocks<8>(output, input, sizeof(gfp1), gfp1::size()); + hashBlocks<8, gfp1::N_BYTES>(output, input, sizeof(gfp1)); for (int i = 0; i < 8; i++) out[i].zero_overhang(); int left = 8; @@ -142,3 +143,6 @@ Z(gf2n_long) Z(Z2<64>) Z(Z2<112>) Z(Z2<128>) Z(Z2<160>) Z(Z2<114>) Z(Z2<130>) Z(Z2<72>) Z(SignedZ2<64>) Z(SignedZ2<72>) Z(gf2n_short) +Z(BitVec) +Z(Z2<41>) +Z(Z2<120>) Z(Z2<122>) Z(Z2<136>) Z(Z2<138>) diff --git a/Tools/MMO.h b/Tools/MMO.h index a6646dc7..8cab59e4 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -28,9 +28,8 @@ public: void setIV(int i, octet key[AES_BLK_SIZE]); template void hashOneBlock(void* output, const void* input) { hashBlocks((T*)output, input); } - template - void hashBlocks(void* output, const void* input, size_t alloc_size, - size_t used_size); + template + void hashBlocks(void* output, const void* input, size_t alloc_size); template void hashBlocks(void* output, const void* input); template diff --git a/Tools/NetworkOptions.cpp b/Tools/NetworkOptions.cpp index 4b3c507c..aa939cf4 100644 --- a/Tools/NetworkOptions.cpp +++ b/Tools/NetworkOptions.cpp @@ -4,6 +4,9 @@ */ #include "NetworkOptions.h" +#include "Networking/Server.h" + +using namespace std; NetworkOptions::NetworkOptions(ez::ezOptionParser& opt, int argc, const char** argv) @@ -31,3 +34,51 @@ NetworkOptions::NetworkOptions(ez::ezOptionParser& opt, int argc, opt.get("-h")->getString(hostname); opt.resetArgs(); } + +NetworkOptionsWithNumber::NetworkOptionsWithNumber(ez::ezOptionParser& opt, + int argc, const char** argv, int default_nplayers, bool variable_nplayers) : + NetworkOptions(opt, argc, argv) +{ + if (variable_nplayers) + opt.add( + to_string(default_nplayers).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of players", // Help description. + "-N", // Flag token. + "--nparties" // Flag token. + ); + + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Filename containing list of party ip addresses. Alternative to --hostname for startup coordination.", // Help description. + "-ip", // Flag token. + "--ip-file-name" // Flag token. + ); + + opt.parse(argc, argv); + + if (variable_nplayers) + opt.get("-N")->getInt(nplayers); + else + nplayers = default_nplayers; + + opt.get("-ip")->getString(ip_filename); + + opt.resetArgs(); +} + +Server* NetworkOptionsWithNumber::start_networking(Names& N, int my_num) +{ + if (ip_filename.length() > 0) + { + N.init(my_num, portnum_base, ip_filename, nplayers); + return 0; + } + else + return Server::start_networking(N, my_num, nplayers, hostname, portnum_base); +} diff --git a/Tools/NetworkOptions.h b/Tools/NetworkOptions.h index d3c11a5a..8a74271d 100644 --- a/Tools/NetworkOptions.h +++ b/Tools/NetworkOptions.h @@ -7,6 +7,8 @@ #define TOOLS_NETWORKOPTIONS_H_ #include "ezOptionParser.h" +#include "Networking/Server.h" +#include "Networking/Player.h" #include @@ -19,4 +21,16 @@ public: NetworkOptions(ez::ezOptionParser& opt, int argc, const char** argv); }; +class NetworkOptionsWithNumber : NetworkOptions +{ +public: + int nplayers; + std::string ip_filename; + + NetworkOptionsWithNumber(ez::ezOptionParser& opt, int argc, + const char** argv, int default_nplayers, bool variable_nplayers); + + Server* start_networking(Names& N, int my_num); +}; + #endif /* TOOLS_NETWORKOPTIONS_H_ */ diff --git a/Check-Offline-Z2k.cpp b/Utils/Check-Offline-Z2k.cpp similarity index 100% rename from Check-Offline-Z2k.cpp rename to Utils/Check-Offline-Z2k.cpp diff --git a/Check-Offline.cpp b/Utils/Check-Offline.cpp similarity index 100% rename from Check-Offline.cpp rename to Utils/Check-Offline.cpp diff --git a/Fake-Offline.cpp b/Utils/Fake-Offline.cpp similarity index 95% rename from Fake-Offline.cpp rename to Utils/Fake-Offline.cpp index 2b61c9d0..39f305d2 100644 --- a/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -10,6 +10,8 @@ #include "Protocols/fake-stuff.h" #include "Exceptions/Exceptions.h" #include "GC/MaliciousRepSecret.h" +#include "GC/SemiSecret.h" +#include "GC/TinySecret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" @@ -20,6 +22,7 @@ #include "Protocols/fake-stuff.hpp" #include "Processor/Data_Files.hpp" #include "Math/Z2k.hpp" +#include "GC/Secret.hpp" #include #include @@ -114,7 +117,7 @@ void make_square_tuples(const typename T::mac_type& key,int N,int ntrip,const st * ntrip = Number bits needed */ template -void make_bits(const typename T::mac_type& key, int N, int ntrip, bool zero, +void make_bits(const typename T::mac_key_type& key, int N, int ntrip, bool zero, int thread_num = -1) { PRNG G; @@ -470,6 +473,9 @@ int generate(ez::ezOptionParser& opt) if (zero) cout << "Set all values to zero" << endl; + // check compatibility + gf2n::init_field(lg2); + PRNG G; G.ReSeed(); prep_data_prefix = get_prep_dir(nplayers, lgp, lg2); @@ -486,7 +492,7 @@ int generate(ez::ezOptionParser& opt) if (mkdir_p(PREP_DIR) == -1) { cerr << "mkdir_p(" PREP_DIR ") failed\n"; - throw file_error(); + throw file_error(PREP_DIR); } generate_mac_keys(keyp, key2, nplayers, prep_data_prefix); @@ -526,5 +532,15 @@ int generate(ez::ezOptionParser& opt) make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); + make_mult_triples({}, nplayers, default_num, zero, prep_data_prefix); + make_bits({}, nplayers, default_num, zero); + + gf2n _; + Z2<40> keyt; + generate_mac_keys>(keyt, _, nplayers, prep_data_prefix); + + make_mult_triples>(keyt, nplayers, default_num, zero, prep_data_prefix); + make_bits>(keyt, nplayers, default_num, zero); + return 0; } diff --git a/Server.cpp b/Utils/Server.cpp similarity index 100% rename from Server.cpp rename to Utils/Server.cpp diff --git a/Setup.cpp b/Utils/Setup.cpp similarity index 90% rename from Setup.cpp rename to Utils/Setup.cpp index bc2fac18..40e2a95b 100644 --- a/Setup.cpp +++ b/Utils/Setup.cpp @@ -1,4 +1,5 @@ #include "Math/Setup.h" +#include "Protocols/Share.h" #include "Protocols/fake-stuff.hpp" #include #include @@ -27,7 +28,7 @@ int main(int argc, char** argv) bool need_mac = false; for (int i = 0; i < n; i++) { - string filename = mac_filename(dir, i); + string filename = mac_filename>(dir, i); ifstream in(filename); need_mac |= not in.good(); } diff --git a/check-passive.cpp b/Utils/check-passive.cpp similarity index 100% rename from check-passive.cpp rename to Utils/check-passive.cpp diff --git a/client-setup.cpp b/Utils/client-setup.cpp similarity index 100% rename from client-setup.cpp rename to Utils/client-setup.cpp diff --git a/cnc-offline.cpp b/Utils/cnc-offline.cpp similarity index 100% rename from cnc-offline.cpp rename to Utils/cnc-offline.cpp diff --git a/default-prime-length.cpp b/Utils/default-prime-length.cpp similarity index 100% rename from default-prime-length.cpp rename to Utils/default-prime-length.cpp diff --git a/galois-degree.cpp b/Utils/galois-degree.cpp similarity index 100% rename from galois-degree.cpp rename to Utils/galois-degree.cpp diff --git a/Utils/gc-emulate.cpp b/Utils/gc-emulate.cpp new file mode 100644 index 00000000..bdeb357f --- /dev/null +++ b/Utils/gc-emulate.cpp @@ -0,0 +1,36 @@ +/* + * gc-emulate.cpp + * + */ + +#include +#include +#include +#include "GC/Machine.h" +#include "GC/Processor.h" + +#include "GC/Processor.hpp" +#include "GC/Machine.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "Processor/Machine.hpp" +#include "Processor/Instruction.hpp" + +int main(int argc, char** argv) +{ + if (argc < 2) + exit(1); + + GC::Memory dynamic_memory; + GC::Machine machine; + GC::Processor processor(machine); + GC::Program program; + program.parse(string(argv[1]) + "-0"); + machine.reset(program, dynamic_memory); + processor.reset(program); + if (argc > 2) + processor.open_input_file(argv[2]); + while (program.execute(processor, dynamic_memory) != GC::DONE_BREAK); +} diff --git a/Scripts/gen_input_f2n.cpp b/Utils/gen_input_f2n.cpp similarity index 100% rename from Scripts/gen_input_f2n.cpp rename to Utils/gen_input_f2n.cpp diff --git a/Scripts/gen_input_fp.cpp b/Utils/gen_input_fp.cpp similarity index 100% rename from Scripts/gen_input_fp.cpp rename to Utils/gen_input_fp.cpp diff --git a/ot-offline.cpp b/Utils/ot-offline.cpp similarity index 100% rename from ot-offline.cpp rename to Utils/ot-offline.cpp diff --git a/pairwise-offline.cpp b/Utils/pairwise-offline.cpp similarity index 100% rename from pairwise-offline.cpp rename to Utils/pairwise-offline.cpp diff --git a/simple-offline.cpp b/Utils/simple-offline.cpp similarity index 100% rename from simple-offline.cpp rename to Utils/simple-offline.cpp diff --git a/spdz2-offline.cpp b/Utils/spdz2-offline.cpp similarity index 100% rename from spdz2-offline.cpp rename to Utils/spdz2-offline.cpp