from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single from Compiler.instructions import * from Compiler.util import tuplify,untuplify from Compiler import instructions,instructions_base,comparison,program,util import inspect,math import random import collections def get_program(): return instructions.program def get_tape(): return get_program().curr_tape def get_block(): return get_program().curr_block def vectorize(function): def vectorized_function(*args, **kwargs): if len(args) > 0 and isinstance(args[0], program.Tape.Register): instructions_base.set_global_vector_size(args[0].size) res = function(*args, **kwargs) instructions_base.reset_global_vector_size() elif 'size' in kwargs: instructions_base.set_global_vector_size(kwargs['size']) del kwargs['size'] res = function(*args, **kwargs) instructions_base.reset_global_vector_size() else: res = function(*args, **kwargs) return res vectorized_function.__name__ = function.__name__ return vectorized_function def set_instruction_type(function): def instruction_typed_function(*args, **kwargs): if len(args) > 0 and isinstance(args[0], program.Tape.Register): if args[0].is_gf2n: instructions_base.set_global_instruction_type('gf2n') else: instructions_base.set_global_instruction_type('modp') res = function(*args, **kwargs) instructions_base.reset_global_instruction_type() else: res = function(*args, **kwargs) return res instruction_typed_function.__name__ = function.__name__ return instruction_typed_function def print_str(s, *args): """ Print a string, with optional args for adding variables/registers with %s """ def print_plain_str(ss): """ Print a plain string (no custom formatting options) """ i = 1 while 4*i < len(ss): print_char4(ss[4*(i-1):4*i]) i += 1 i = 4*(i-1) while i < len(ss): print_char(ss[i]) i += 1 if len(args) != s.count('%s'): raise CompilerError('Incorrect number of arguments for string format:', s) substrings = s.split('%s') for i,ss in enumerate(substrings): print_plain_str(ss) if i < len(args): if isinstance(args[i], MemValue): val = args[i].read() else: val = args[i] if isinstance(val, program.Tape.Register): if val.is_clear: val.print_reg_plain() else: raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfix): val.print_plain() elif isinstance(val, sfix) or isinstance(val, sfloat): raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfloat): val.print_float_plain() elif isinstance(val, list): print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val) else: try: val.output() except AttributeError: print_plain_str(str(val)) def print_ln(s='', *args): """ Print line, with optional args for adding variables/registers with %s """ print_str(s, *args) print_char('\n') def print_ln_if(cond, s): if util.is_constant(cond): if cond: print_ln(s) else: s += ' ' * ((len(s) + 3) % 4) s += '\n' while s: cond.print_if(s[:4]) s = s[4:] def print_float_precision(n): print_float_prec(n) def runtime_error(msg='', *args): """ Print an error message and abort the runtime. """ print_str('User exception: ') print_ln(msg, *args) crash() def public_input(): res = regint() pubinput(res) return res # mostly obsolete functions # use the equivalent from types.py def load_int(value, size=None): return regint(value, size=size) def load_int_to_secret(value, size=None): return sint(value, size=size) def load_int_to_secret_vector(vector): res = sint(size=len(vector)) for i,val in enumerate(vector): ldsi(res[i], val) return res @vectorize def load_float_to_secret(value, sec=40): def _bit_length(x): return len(bin(x).lstrip('-0b')) num,den = value.as_integer_ratio() exp = int(round(math.log(den, 2))) nbits = _bit_length(num) if nbits > sfloat.vlen: num >>= (nbits - sfloat.vlen) exp -= (nbits - sfloat.vlen) elif nbits < sfloat.vlen: num <<= (sfloat.vlen - nbits) exp += (sfloat.vlen - nbits) if _bit_length(exp) > sfloat.plen: raise CompilerException('Cannot load floating point to secret: overflow') if num < 0: s = load_int_to_secret(1) z = load_int_to_secret(0) else: s = load_int_to_secret(0) if num == 0: z = load_int_to_secret(1) else: z = load_int_to_secret(0) v = load_int_to_secret(num) p = load_int_to_secret(exp) return sfloat(v, p, s, z) def load_clear_mem(address): return cint.load_mem(address) def load_secret_mem(address): return sint.load_mem(address) def load_mem(address, value_type): if value_type in _types: value_type = _types[value_type] return value_type.load_mem(address) @vectorize def store_in_mem(value, address): if isinstance(value, int): value = load_int(value) try: value.store_in_mem(address) except AttributeError: # legacy if value.is_clear: if isinstance(address, cint): stmci(value, address) else: stmc(value, address) else: if isinstance(address, cint): stmsi(value, address) else: stms(value, address) @set_instruction_type @vectorize def reveal(secret): try: return secret.reveal() except AttributeError: if secret.is_gf2n: res = cgf2n() else: res = cint() instructions.asm_open(res, secret) return res @vectorize def compare_secret(a, b, length, sec=40): res = sint() instructions.lts(res, a, b, length, sec) def get_input_from(player, size=None): return sint.get_input_from(player, size=size) def get_random_triple(size=None): return sint.get_random_triple(size=size) def get_random_bit(size=None): return sint.get_random_bit(size=size) def get_random_square(size=None): return sint.get_random_square(size=size) def get_random_inverse(size=None): return sint.get_random_inverse(size=size) def get_random_int(bits, size=None): return sint.get_random_int(bits, size=size) @vectorize def get_thread_number(): res = regint() ldtn(res) return res @vectorize def get_arg(): res = regint() ldarg(res) return res def make_array(l): if isinstance(l, program.Tape.Register): res = Array(1, type(l)) res[0] = l else: l = list(l) res = Array(len(l), type(l[0]) if l else cint) res.assign(l) return res class FunctionTapeCall: def __init__(self, thread, base, bases): self.thread = thread self.base = base self.bases = bases def start(self): self.thread.start(self.base) return self def join(self): self.thread.join() instructions.program.free(self.base, 'ci') for reg_type,addr in self.bases.iteritems(): get_program().free(addr, reg_type.reg_type) class Function: def __init__(self, function, name=None, compile_args=[]): self.type_args = {} self.function = function self.name = name if name is None: self.name = self.function.__name__ self.compile_args = compile_args def __call__(self, *args): args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x) if len(args) not in self.type_args: # first call type_args = collections.defaultdict(list) for i,arg in enumerate(args): type_args[get_reg_type(arg)].append(i) def wrapped_function(*compile_args): base = get_arg() bases = dict((t, regint.load_mem(base + i)) \ for i,t in enumerate(sorted(type_args))) runtime_args = [None] * len(args) for t in sorted(type_args): for i,i_arg in enumerate(type_args[t]): runtime_args[i_arg] = t.load_mem(bases[t] + i) return self.function(*(list(compile_args) + runtime_args)) self.on_first_call(wrapped_function) self.type_args[len(args)] = type_args type_args = self.type_args[len(args)] base = instructions.program.malloc(len(type_args), 'ci') bases = dict((t, get_program().malloc(len(type_args[t]), t)) \ for t in type_args) for i,reg_type in enumerate(sorted(type_args)): store_in_mem(bases[reg_type], base + i) for j,i_arg in enumerate(type_args[reg_type]): if get_reg_type(args[i_arg]) != reg_type: raise CompilerError('type mismatch') store_in_mem(args[i_arg], bases[reg_type] + j) return self.on_call(base, bases) class FunctionTape(Function): # not thread-safe def on_first_call(self, wrapped_function): self.thread = MPCThread(wrapped_function, self.name, args=self.compile_args) def on_call(self, base, bases): return FunctionTapeCall(self.thread, base, bases) def function_tape(function): return FunctionTape(function) def function_tape_with_compile_args(*args): def wrapper(function): return FunctionTape(function, compile_args=args) return wrapper def memorize(x): if isinstance(x, (tuple, list)): return tuple(memorize(i) for i in x) else: return MemValue(x) def unmemorize(x): if isinstance(x, (tuple, list)): return tuple(unmemorize(i) for i in x) else: return x.read() class FunctionBlock(Function): def on_first_call(self, wrapped_function): old_block = get_tape().active_basicblock parent_node = get_tape().req_node get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) block = get_tape().active_basicblock block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node print 'Compiling function', self.name result = wrapped_function(*self.compile_args) if result is not None: self.result = memorize(result) else: self.result = None print 'Done compiling function', self.name p_return_address = get_tape().program.malloc(1, 'ci') get_tape().function_basicblocks[block] = p_return_address return_address = regint.load_mem(p_return_address) get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False)) self.last_sub_block = get_tape().active_basicblock get_tape().close_scope(old_block, parent_node, 'end-' + self.name) old_block.set_exit(instructions.jmp(0, add_to_prog=False), get_tape().active_basicblock) self.basic_block = block def on_call(self, base, bases): if base is not None: instructions.starg(regint(base)) block = self.basic_block if block not in get_tape().function_basicblocks: raise CompilerError('unknown function') old_block = get_tape().active_basicblock old_block.set_exit(instructions.jmp(0, add_to_prog=False), block) p_return_address = get_tape().function_basicblocks[block] return_address = get_tape().new_reg('ci') old_block.return_address_store = instructions.ldint(return_address, 0) instructions.stmint(return_address, p_return_address) get_tape().start_new_basicblock(name='call-' + self.name) get_tape().active_basicblock.set_return(old_block, self.last_sub_block) get_tape().req_node.children.append(self.node) if self.result is not None: return unmemorize(self.result) def function_block(function): return FunctionBlock(function) def function_block_with_compile_args(*args): def wrapper(function): return FunctionBlock(function, compile_args=args) return wrapper def method_block(function): # If you use this, make sure to use MemValue for all member # variables. compiled_functions = {} def wrapper(self, *args): if self in compiled_functions: return compiled_functions[self](*args) else: name = '%s-%s' % (type(self).__name__, function.__name__) block = FunctionBlock(function, name=name, compile_args=(self,)) compiled_functions[self] = block return block(*args) return wrapper def cond_swap(x,y): b = x < y if isinstance(x, sfloat): res = ([], []) for i,j in enumerate(('v','p','z','s')): xx = x.__getattribute__(j) yy = y.__getattribute__(j) bx = b * xx by = b * yy res[0].append(bx + yy - by) res[1].append(xx - bx + by) return sfloat(*res[0]), sfloat(*res[1]) bx = b * x by = b * y return bx + y - by, x - bx + by def sort(a): res = a for i in range(len(a)): for j in reversed(range(i)): res[j], res[j+1] = cond_swap(res[j], res[j+1]) return res def odd_even_merge(a): if len(a) == 2: a[0], a[1] = cond_swap(a[0], a[1]) else: even = a[::2] odd = a[1::2] odd_even_merge(even) odd_even_merge(odd) a[0] = even[0] for i in range(1, len(a) / 2): a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i]) a[-1] = odd[-1] def odd_even_merge_sort(a): if len(a) == 1: return elif len(a) % 2 == 0: lower = a[:len(a)/2] upper = a[len(a)/2:] odd_even_merge_sort(lower) odd_even_merge_sort(upper) a[:] = lower + upper odd_even_merge(a) else: raise CompilerError('Length of list must be power of two') def chunky_odd_even_merge_sort(a): for i,j in enumerate(a): j.store_in_mem(i * j.sizeof()) l = 1 while l < len(a): l *= 2 k = 1 while k < l: k *= 2 def round(): for i in range(len(a)): a[i] = type(a[i]).load_mem(i * a[i].sizeof()) for i in range(len(a) / l): for j in range(l / k): base = i * l + j step = l / k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: b = a[base:base+k*step:step] for m in range(base + step, base + (k - 1) * step, 2 * step): a[m], a[m+step] = cond_swap(a[m], a[m+step]) for i in range(len(a)): a[i].store_in_mem(i * a[i].sizeof()) chunk = MPCThread(round, 'sort-%d-%d' % (l,k)) chunk.start() chunk.join() #round() for i in range(len(a)): a[i] = type(a[i]).load_mem(i * a[i].sizeof()) def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False): if n is None: n = len(a) a_base = instructions.program.malloc(n, 's') for i,j in enumerate(a): store_in_mem(j, a_base + i) instructions.program.restart_main_thread() else: a_base = a tmp_base = instructions.program.malloc(n, 's') chunks = {} threads = [] def run_threads(): for thread in threads: thread.start() for thread in threads: thread.join() del threads[:] def run_chunk(size, base): if size not in chunks: def swap_list(list_base): for i in range(size / 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) store_in_mem(x, base) store_in_mem(y, base + 1) chunks[size] = FunctionTape(swap_list, 'sort-%d' % size) return chunks[size](base) def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) lower_size = size / n_chunks / 2 * 2 n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size base = 0 round_threads = [] for i in range(n_lower_size): round_threads.append(run_chunk(lower_size, tmp_base + base)) base += lower_size for i in range(n_chunks - n_lower_size): round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) base += lower_size + 2 run_threads_in_rounds(round_threads) postproc_chunks = [] wrap_chunks = {} post_threads = [] pre_threads = [] def load_and_store(x, y, to_right): if to_right: store_in_mem(load_secret_mem(x), y) else: store_in_mem(load_secret_mem(y), x) def run_setup(k, a_addr, step, tmp_addr): if k == 2: def mem_op(preproc, a_addr, step, tmp_addr): load_and_store(a_addr, tmp_addr, preproc) load_and_store(a_addr + step, tmp_addr + 1, preproc) res = 2 else: def mem_op(preproc, a_addr, step, tmp_addr): instructions.program.curr_tape.merge_opens = False # for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)): for i in range(k - 2): m = a_addr + step + i * step load_and_store(m, tmp_addr + i, preproc) res = k - 2 if not use_chunk_wraps or k <= 4: mem_op(True, a_addr, step, tmp_addr) postproc_chunks.append((mem_op, (a_addr, step, tmp_addr))) else: if k not in wrap_chunks: pre_chunk = FunctionTape(mem_op, 'pre-%d' % k, compile_args=[True]) post_chunk = FunctionTape(mem_op, 'post-%d' % k, compile_args=[False]) wrap_chunks[k] = (pre_chunk, post_chunk) pre_chunk, post_chunk = wrap_chunks[k] pre_threads.append(pre_chunk(a_addr, step, tmp_addr)) post_threads.append(post_chunk(a_addr, step, tmp_addr)) return res def run_threads_in_rounds(all_threads): for thread in all_threads: if len(threads) == n_threads: run_threads() threads.append(thread) run_threads() del all_threads[:] def run_postproc(): run_threads_in_rounds(post_threads) for chunk,args in postproc_chunks: chunk(False, *args) postproc_chunks[:] = [] l = 1 while l < n: l *= 2 k = 1 while k < l: k *= 2 size = 0 instructions.program.curr_tape.merge_opens = False for i in range(n / l): for j in range(l / k): base = i * l + j step = l / k size += run_setup(k, a_base + base, step, tmp_base + size) run_threads_in_rounds(pre_threads) run_round(size) run_postproc() if isinstance(a, list): instructions.program.restart_main_thread() for i in range(n): a[i] = load_secret_mem(a_base + i) instructions.program.free(a_base, 's') instructions.program.free(tmp_base, 's') def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7): if n is None: n = len(a) a_base = instructions.program.malloc(n, 's') for i,j in enumerate(a): store_in_mem(j, a_base + i) instructions.program.restart_main_thread() else: a_base = a tmp_base = instructions.program.malloc(n, 's') tmp_i = instructions.program.malloc(1, 'ci') chunks = {} threads = [] def run_threads(): for thread in threads: thread.start() for thread in threads: thread.join() del threads[:] def run_threads_in_rounds(all_threads): for thread in all_threads: if len(threads) == n_threads: run_threads() threads.append(thread) run_threads() del all_threads[:] def run_chunk(size, base): if size not in chunks: def swap_list(list_base): for i in range(size / 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) store_in_mem(x, base) store_in_mem(y, base + 1) chunks[size] = FunctionTape(swap_list, 'sort-%d' % size) return chunks[size](base) def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) lower_size = size / n_chunks / 2 * 2 n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size base = 0 round_threads = [] for i in range(n_lower_size): round_threads.append(run_chunk(lower_size, tmp_base + base)) base += lower_size for i in range(n_chunks - n_lower_size): round_threads.append(run_chunk(lower_size + 2, tmp_base + base)) base += lower_size + 2 run_threads_in_rounds(round_threads) l = 1 while l < n: l *= 2 k = 1 while k < l: k *= 2 def load_and_store(x, y): if to_tmp: store_in_mem(load_secret_mem(x), y) else: store_in_mem(load_secret_mem(y), x) def outer(i): def inner(j): base = j step = l / k if k == 2: tmp_addr = regint.load_mem(tmp_i) load_and_store(base, tmp_addr) load_and_store(base + step, tmp_addr + 1) store_in_mem(tmp_addr + 2, tmp_i) else: def inner2(m): tmp_addr = regint.load_mem(tmp_i) load_and_store(m, tmp_addr) store_in_mem(tmp_addr + 1, tmp_i) range_loop(inner2, base + step, base + (k - 1) * step, step) range_loop(inner, a_base + i * l, a_base + i * l + l / k) instructions.program.curr_tape.merge_opens = False to_tmp = True store_in_mem(tmp_base, tmp_i) range_loop(outer, n / l) if k == 2: run_round(n) else: run_round(n / k * (k - 2)) instructions.program.curr_tape.merge_opens = False to_tmp = False store_in_mem(tmp_base, tmp_i) range_loop(outer, n / l) if isinstance(a, list): instructions.program.restart_main_thread() for i in range(n): a[i] = load_secret_mem(a_base + i) instructions.program.free(a_base, 's') instructions.program.free(tmp_base, 's') instructions.program.free(tmp_i, 'ci') def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32): l = sorted_length while l < len(a): l *= 2 k = 1 while k < l: k *= 2 n_outer = len(a) / l n_inner = l / k n_innermost = 1 if k == 2 else k / 2 - 1 @for_range_parallel(n_parallel / n_innermost / n_inner, n_outer) def loop(i): @for_range_parallel(n_parallel / n_innermost, n_inner) def inner(j): base = i*l + j step = l/k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: @for_range_parallel(n_parallel, n_innermost) def f(i): m1 = step + i * 2 * step m2 = m1 + base a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step]) def mergesort(A): B = Array(len(A), sint) def merge(i_left, i_right, i_end): i0 = MemValue(i_left) i1 = MemValue(i_right) @for_range(i_left, i_end) def loop(j): if_then(and_(lambda: i0 < i_right, or_(lambda: i1 >= i_end, lambda: regint(reveal(A[i0] <= A[i1]))))) B[j] = A[i0] i0.iadd(1) else_then() B[j] = A[i1] i1.iadd(1) end_if() width = MemValue(1) @do_while def width_loop(): @for_range(0, len(A), 2 * width) def merge_loop(i): merge(i, i + width, i + 2 * width) A.assign(B) width.imul(2) return width < len(A) def range_loop(loop_body, start, stop=None, step=None): if stop is None: stop = start start = 0 if step is None: step = 1 def loop_fn(i): loop_body(i) return i + step if isinstance(step, int): if step > 0: condition = lambda x: x < stop elif step < 0: condition = lambda x: x > stop else: raise CompilerError('step must not be zero') else: b = step > 0 condition = lambda x: b * (x < stop) + (1 - b) * (x > stop) while_loop(loop_fn, condition, start) if isinstance(start, int) and isinstance(stop, int) \ and isinstance(step, int): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ lambda x: ((stop - start) / step) * x[0] def for_range(start, stop=None, step=None): def decorator(loop_body): range_loop(loop_body, start, stop, step) return loop_body return decorator def for_range_parallel(n_parallel, n_loops): return map_reduce_single(n_parallel, n_loops, \ lambda *x: [], lambda *x: []) def map_reduce_single(n_parallel, n_loops, initializer, reducer, mem_state=None): if not isinstance(n_parallel, int): raise CompilerException('Number of parallel executions' \ 'must be constant') n_parallel = n_parallel or 1 if mem_state is None: # default to list of MemValues to allow varying types mem_state = [MemValue(x) for x in initializer()] use_array = False else: # use Arrays for multithread version use_array = True def decorator(loop_body): if isinstance(n_loops, int): loop_rounds = n_loops / n_parallel \ if n_parallel < n_loops else 0 else: loop_rounds = n_loops / n_parallel def write_state_to_memory(r): if use_array: mem_state.assign(r) else: # cannot do mem_state = [...] due to scope issue for j,x in enumerate(r): mem_state[j].write(x) # will be optimized out if n_loops <= n_parallel @for_range(loop_rounds) def f(i): state = tuplify(initializer()) for k in range(n_parallel): j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) r = reducer(mem_state, state) write_state_to_memory(r) if isinstance(n_loops, int): state = mem_state for j in range(loop_rounds * n_parallel, n_loops): state = reducer(tuplify(loop_body(j)), state) else: @for_range(loop_rounds * n_parallel, n_loops) def f(j): r = reducer(tuplify(loop_body(j)), mem_state) write_state_to_memory(r) state = mem_state for i,x in enumerate(state): if use_array: mem_state[i] = x else: mem_state[i].write(x) def returner(): return untuplify(tuple(state)) return returner return decorator def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}): return map_reduce(n_threads, n_parallel, n_loops, \ lambda *x: [], lambda *x: [], thread_mem_req) def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ thread_mem_req={}): n_threads = n_threads or 1 if n_threads == 1 or n_loops == 1: dec = map_reduce_single(n_parallel, n_loops, initializer, reducer) if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint) return lambda loop_body: dec(lambda i: loop_body(i, thread_mem)) else: return dec def decorator(loop_body): thread_rounds = n_loops / n_threads remainder = n_loops % n_threads for t in thread_mem_req: if t != regint: raise CompilerError('Not implemented for other than regint') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') state = tuple(initializer()) def f(inc): if thread_mem_req: thread_mem = Array(thread_mem_req[regint], regint, \ args[get_arg()].address + 2) mem_state = Array(len(state), type(state[0]) \ if state else cint, args[get_arg()][1]) base = args[get_arg()][0] @map_reduce_single(n_parallel, thread_rounds + inc, \ initializer, reducer, mem_state) def f(i): if thread_mem_req: return loop_body(base + i, thread_mem) else: return loop_body(base + i) prog = get_program() threads = [] if thread_rounds: tape = prog.new_tape(f, (0,), 'multithread') for i in range(n_threads - remainder): mem_state = make_array(initializer()) args[remainder + i][0] = i * thread_rounds if len(mem_state): args[remainder + i][1] = mem_state.address threads.append(prog.run_tape(tape, remainder + i)) if remainder: tape1 = prog.new_tape(f, (1,), 'multithread1') for i in range(remainder): mem_state = make_array(initializer()) args[i][0] = (n_threads - remainder + i) * thread_rounds + i if len(mem_state): args[i][1] = mem_state.address threads.append(prog.run_tape(tape1, i)) for thread in threads: prog.join_tape(thread) if state: if thread_rounds: for i in range(n_threads - remainder): state = reducer(Array(len(state), type(state[0]), \ args[remainder + i][1]), state) if remainder: for i in range(remainder): state = reducer(Array(len(state), type(state[0]).reg_type, \ args[i][1]), state) def returner(): return untuplify(state) return returner return decorator def map_sum(n_threads, n_parallel, n_loops, n_items, value_types): value_types = tuplify(value_types) if len(value_types) == 1: value_types *= n_items elif len(value_types) != n_items: raise CompilerError('Incorrect number of value_types.') initializer = lambda: [t(0) for t in value_types] def summer(x,y): return tuple(a + b for a,b in zip(x,y)) return map_reduce(n_threads, n_parallel, n_loops, initializer, summer) def foreach_enumerate(a): for x in a: get_program().public_input(' '.join(str(y) for y in tuplify(x))) def decorator(loop_body): @for_range(len(a)) def f(i): loop_body(i, *(public_input() for j in range(len(tuplify(a[0]))))) return f return decorator def while_loop(loop_body, condition, arg): if not callable(condition): raise CompilerError('Condition must be callable') # store arg in stack pre_condition = condition(arg) if not isinstance(pre_condition, (bool,int)) or pre_condition: pushint(arg if isinstance(arg,regint) else regint(arg)) def loop_fn(): result = loop_body(regint.pop()) pushint(result) return condition(result) if_statement(pre_condition, lambda: do_while(loop_fn)) regint.pop() def while_do(condition, *args): def decorator(loop_body): while_loop(loop_body, condition, *args) return loop_body return decorator def do_loop(condition, loop_fn): # store initial condition to stack pushint(condition if isinstance(condition,regint) else regint(condition)) def wrapped_loop(): # save condition to stack new_cond = regint.pop() # run the loop condition = loop_fn(new_cond) pushint(condition) return condition do_while(wrapped_loop) regint.pop() def do_while(loop_fn): scope = instructions.program.curr_block parent_node = get_tape().req_node # possibly unknown loop count get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \ name='begin-loop') loop_block = instructions.program.curr_block condition = loop_fn() if callable(condition): condition = condition() branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False) instructions.program.curr_block.set_exit(branch, loop_block) get_tape().close_scope(scope, parent_node, 'end-loop') return loop_fn def if_then(condition): class State: pass state = State() if callable(condition): condition = condition() state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False instructions.program.curr_tape.if_states.append(state) def else_then(): try: state = instructions.program.curr_tape.if_states[-1] except IndexError: raise CompilerError('No open if block') if state.has_else: raise CompilerError('else block already defined') # run the else block state.if_exit_block = instructions.program.curr_block state.req_child.add_node(get_tape(), 'else-block') instructions.program.curr_tape.start_new_basicblock(state.start_block, \ name='else-block') state.else_block = instructions.program.curr_block state.has_else = True def end_if(): try: state = instructions.program.curr_tape.if_states.pop() except IndexError: raise CompilerError('No open if/else block') branch = instructions.jmpeqz(regint.conv(state.condition), 0, \ add_to_prog=False) # start next block get_tape().close_scope(state.start_block, state.req_child.parent, 'end-if') if state.has_else: # jump to else block if condition == 0 state.start_block.set_exit(branch, state.else_block) # set if block to skip else jump = instructions.jmp(0, add_to_prog=False) state.if_exit_block.set_exit(jump, instructions.program.curr_block) else: # set start block's conditional jump to next block state.start_block.set_exit(branch, instructions.program.curr_block) # nothing to compute without else state.req_child.aggregator = lambda x: x[0] def if_statement(condition, if_fn, else_fn=None): if condition is True or condition is False: # condition known at compile time if condition: if_fn() elif else_fn is not None: else_fn() else: state = if_then(condition) if_fn() if else_fn is not None: else_then() else_fn() end_if() def if_(condition): def decorator(body): if_then(condition) body() end_if() return decorator def if_e(condition): def decorator(body): if_then(condition) body() return decorator def else_(body): else_then() body() end_if() def and_(*terms): # not thread-safe p_res = instructions.program.malloc(1, 'ci') for term in terms: if_then(term()) store_in_mem(1, p_res) for term in terms: else_then() store_in_mem(0, p_res) end_if() def load_result(): res = regint.load_mem(p_res) instructions.program.free(p_res, 'ci') return res return load_result def or_(*terms): # not thread-safe p_res = instructions.program.malloc(1, 'ci') res = regint() for term in terms: if_then(term()) store_in_mem(1, p_res) else_then() store_in_mem(0, p_res) for term in terms: end_if() def load_result(): res = regint.load_mem(p_res) instructions.program.free(p_res, 'ci') return res return load_result def not_(term): return lambda: 1 - term() def start_timer(timer_id=0): get_tape().start_new_basicblock(name='pre-start-timer') start(timer_id) get_tape().start_new_basicblock(name='post-start-timer') def stop_timer(timer_id=0): get_tape().start_new_basicblock(name='pre-stop-timer') stop(timer_id) get_tape().start_new_basicblock(name='post-stop-timer') # Fixed point ops from math import ceil, log from floatingpoint import PreOR, TruncPr, two_power, shift_two def approximate_reciprocal(divisor, k, f, theta): """ returns aproximation of 1/divisor where type(divisor) = cint """ def twos_complement(x): bits = x.bit_decompose(k)[::-1] bit_array = Array(k, cint) bit_array.assign(bits) twos_result = MemValue(cint(0)) @for_range(k) def block(i): val = twos_result.read() val <<= 1 val += 1 - bit_array[i] twos_result.write(val) return twos_result.read() + 1 bit_array = Array(k, cint) bits = divisor.bit_decompose(k)[::-1] bit_array.assign(bits) cnt_leading_zeros = MemValue(regint(0)) flag = MemValue(regint(0)) cnt_leading_zeros = MemValue(regint(0)) normalized_divisor = MemValue(divisor) @for_range(k) def block(i): flag.write(flag.read() | bit_array[i] == 1) @if_(flag.read() == 0) def block(): cnt_leading_zeros.write(cnt_leading_zeros.read() + 1) normalized_divisor.write(normalized_divisor << 1) q = MemValue(two_power(k)) e = MemValue(twos_complement(normalized_divisor.read())) qr = q.read() er = e.read() for i in range(theta): qr = qr + shift_two(qr * er, k) er = shift_two(er * er, k) q = qr res = shift_two(q, (2*k - 2*f - cnt_leading_zeros)) return res def cint_cint_division(a, b, k, f): """ Goldschmidt method implemented with SE aproximation: http://stackoverflow.com/questions/2661541/picking-good-first-estimates-for-goldschmidt-division """ # theta can be replaced with something smaller # for safety we assume that is the same theta from previous GS method theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) sign_b = cint(1) - 2 * cint(b < 0) sign_a = cint(1) - 2 * cint(a < 0) absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) A = Array(theta, cint) B = Array(theta, cint) W = Array(theta, cint) A[0] = absolute_a B[0] = absolute_b W[0] = w0 for i in range(1, theta): A[i] = shift_two(A[i - 1] * W[i - 1], f) B[i] = shift_two(B[i - 1] * W[i - 1], f) W[i] = two - B[i] return (sign_a * sign_b) * A[theta - 1] from Compiler.program import Program def sint_cint_division(a, b, k, f, kappa): """ type(a) = sint, type(b) = cint """ theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) sign_b = cint(1) - 2 * cint(b < 0) sign_a = sint(1) - 2 * sint(a < 0) absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) A = Array(theta, sint) B = Array(theta, cint) W = Array(theta, cint) A[0] = absolute_a B[0] = absolute_b W[0] = w0 @for_range(1, theta) def block(i): A[i] = TruncPr(A[i - 1] * W[i - 1], 2*k, f, kappa) temp = shift_two(B[i - 1] * W[i - 1], f) # no reading and writing to the same variable in a for loop. W[i] = two - temp B[i] = temp return (sign_a * sign_b) * A[theta - 1] def IntDiv(a, b, k, kappa=None): return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k, kappa, nearest=True) def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): """ Goldschmidt method as presented in Catrina10, """ if 2 * k == int(get_program().options.ring): # not fitting otherwise nearest = True if get_program().options.binary: # no probabilistic truncation in binary circuits nearest = True res_f = f f = max((k - nearest) / 2 + 1, f) assert 2 * f > k - nearest theta = int(ceil(log(k/3.5) / log(2))) alpha = b.get_type(2 * k).two_power(2*f) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w y = a.extend(2 *k) * w y = y.round(2*k, f, kappa, nearest, signed=True) for i in range(theta): x = x.extend(2 * k) y = y.extend(2 * k) * (alpha + x).extend(2 * k) x = x * x y = y.round(2*k, 2*f, kappa, nearest, signed=True) x = x.round(2*k, 2*f, kappa, nearest, signed=True) y = y.extend(2 * k) * (alpha + x).extend(2 * k) y = y.round(k + 2 * f, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False): """ Approximate reciprocal of [b]: Given [b], compute [1/b] """ alpha = b.get_type(2 * k)(int(2.9142 * 2**k)) c, v = b.Norm(k, f, kappa, simplex_flag) #v should be 2**{k - m} where m is the length of the bitwise repr of [b] d = alpha - 2 * c w = d * v w = w.round(2 * k, 2 * (k - f), kappa, nearest, signed=True) # now w * 2 ^ {-f} should be an initial approximation of 1/b return w def Norm(b, k, f, kappa, simplex_flag=False): """ Computes secret integer values [c] and [v_prime] st. 2^{k-1} <= c < 2^k and c = b*v_prime """ # For simplex, we can get rid of computing abs(b) temp = None if simplex_flag == False: temp = b.less_than(0, 2 * k) elif simplex_flag == True: temp = cint(0) sign = 1 - 2 * temp # 1 - 2 * [b < 0] absolute_val = sign * b #next 2 lines actually compute the SufOR for little indian encoding bits = absolute_val.bit_decompose(k, kappa)[::-1] suffixes = PreOR(bits)[::-1] z = [0] * k for i in range(k - 1): z[i] = suffixes[i] - suffixes[i+1] z[k - 1] = suffixes[k-1] #doing complicated stuff to compute v = 2^{k-m} acc = cint(0) for i in range(k): acc += two_power(k-i-1) * z[i] part_reciprocal = absolute_val * acc signed_acc = sign * acc return part_reciprocal, signed_acc