from Compiler.oram import * from Compiler.path_oram import PathORAM, XOR from Compiler.util import bit_compose def first_diff(a_bits, b_bits): length = len(a_bits) level_bits = [None] * length not_found = 1 for i in range(length): # find first position where bits differ (i.e. first 0 in 1 - a XOR b) t = 1 - XOR(a_bits[i], b_bits[i]) prev_nf = not_found not_found *= t level_bits[i] = prev_nf - not_found return level_bits, not_found def find_deeper(a, b, path, start, length, compute_level=True): a_bits = a.value.bit_decompose(length) b_bits = b.value.bit_decompose(length) path_bits = [type(a_bits[0])(x) for x in path.bit_decompose(length)] a_bits.reverse() b_bits.reverse() path_bits.reverse() level_bits = [0] * length # make sure that winner is set at start if one input is empty any_empty = OR(a.empty, b.empty) a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)] b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)] diff = [XOR(ab, bb) for ab,bb in list(zip(a_bits, b_bits))[start:length]] diff_preor = type(a.value).PreOR([any_empty] + diff) diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])] winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty) winner_bits = [if_else(winner, bd, ad) for ad,bd in zip(a_diff, b_diff)] winner_preor = type(a.value).PreOR(winner_bits) level_bits = [x - y for x,y in zip(winner_preor, [0] + winner_preor)] return [0] * start + level_bits + [1 - sum(level_bits)], winner def find_deepest(paths, search_path, start, length, compute_level=True): if len(paths) == 1: return None, paths[0], 1 l = len(paths) // 2 _, a, a_index = find_deepest(paths[:l], search_path, start, length, False) _, b, b_index = find_deepest(paths[l:], search_path, start, length, False) level, winner = find_deeper(a, b, search_path, start, length, compute_level) return level, if_else(winner, b, a), if_else(winner, b_index << l, a_index) def ge_unary_public(a, b): return sum(a[b-1:]) def gu_step(high, low): greater = high[0] * (1 - high[1]) not_greater = high[1] return if_else(not_greater, 0, high[0] + low[0]), \ if_else(greater, 0, high[1] + low[1]) def greater_unary(a, b): if len(a) == 1: return a[0], b[0] else: l = len(a) // 2 return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l])) def comp_step(high, low): prod = high[0] * high[1] greater = high[0] - prod smaller = high[1] - prod deferred = 1 - greater - smaller indicator = greater, smaller, deferred return sum(map(operator.mul, indicator, (1, 0, low[0]))), \ sum(map(operator.mul, indicator, (0, 1, low[1]))) def comp_binary(a, b): if len(a) != len(b): raise CompilerError('Arguments must have same length: %s %s' % (str(a), str(b))) if len(a) == 1: return a[0], b[0] else: l = len(a) // 2 return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l])) def unary_to_binary(l): return sum(x * (i + 1) for i,x in enumerate(l)).bit_decompose(log2(len(l) + 1)) class CircuitORAM(PathORAM): def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \ stash_size=None, bucket_size=2, init_rounds=-1): self.bucket_oram = TrivialORAM self.bucket_size = bucket_size self.D = log2(size) self.logD = log2(self.D) self.L = self.D + 1 print('create oram of size %d with depth %d and %d buckets' \ % (size, self.D, self.n_buckets())) self.value_type = value_type self.index_type = value_type.get_type(self.D) if entry_size is not None: self.value_length = len(tuplify(entry_size)) self.entry_size = tuplify(entry_size) else: self.value_length = value_length self.entry_size = [None] * value_length self.index_size = log2(size) self.size = size empty_entry = Entry.get_empty(*self.internal_entry_size(), \ index_size=self.index_size) self.entry_type = empty_entry.types() self.ram = RAM(self.bucket_size * 2**(self.D+1), self.entry_type, \ self.get_array) self.buckets = self.ram if init_rounds != -1: # put memory initialization in different timer stop_timer() start_timer(1) self.ram.init_mem(self.empty_entry(apply_type=False)) if init_rounds != -1: stop_timer(1) start_timer() self.root = RefBucket(1, self) self.index = self.index_structure(size, self.D, value_type, init_rounds, True) stash_size = 20 vt, es = self.internal_entry_size() self.stash = TrivialORAM(stash_size, vt, entry_size=es, \ index_size=self.index_size) self.t = MemValue(regint(0)) self.state = MemValue(self.value_type.get_type(self.D)(0)) self.read_path = MemValue(value_type.clear_type(0)) # these include a read value from the stash self.read_value = [Array(self.D + 2, self.value_type.get_type(l)) for l in self.entry_size] self.read_empty = Array(self.D + 2, self.value_type.bit_type) def get_ram_index(self, path, level): clear_type = self.value_type.clear_type return ((2**(self.D) + clear_type.conv(path)) >> (self.D - (level - 1))) def get_bucket_ram(self, path, level): if level == 0: return self.stash.ram else: return RefRAM(self.get_ram_index(path, level), self) def get_bucket_oram(self, path, level): if level == 0: return self.stash else: return RefTrivialORAM(self.get_ram_index(path, level), self) def prepare_deepest(self, path): deepest = [None] * (self.D + 2) deepest_index = [None] * (self.D + 2) src = Value() stash_empty = self.stash.ram.is_empty() level, _, index = find_deepest(self.stash.ram.get_value_array(0), path, 0, self.D) goal = if_else(stash_empty, ValueTuple([0] * len(level)), unary_to_binary(level)) src = if_else(stash_empty, src, Value(0)) src_index = if_else(stash_empty, 0, index) buckets = [self.get_bucket_ram(path, i) for i in range(self.L + 1)] bucket_deepest = [(goal, src, src_index, None)] for i in range(1, self.L): l, _, index = find_deepest(buckets[i].get_value_array(0), path, i - 1, self.D) bucket_deepest.append((unary_to_binary(l), Value(i), index, i)) def op(left, right, void=None): goal, src, src_index, _ = left l, secret_i, index, i = right high, low = comp_binary(l, goal) replace = high * (1 - low) * (1 - buckets[i].is_empty()) goal = if_else(replace, bit_compose(l), \ bit_compose(goal)).bit_decompose(len(goal)) src = if_else(replace, secret_i, src) src_index = if_else(replace, index, src_index) return goal, src, src_index, i preop_bucket_deepest = self.value_type.PreOp(op, bucket_deepest) for i in range(1, self.L + 1): goal, src, src_index, _ = preop_bucket_deepest[i-1] high, low = comp_binary(goal, bit_decompose(i, len(goal))) cond = 1 - low * (1 - high) deepest[i] = if_else(cond, src, Value()) deepest_index[i] = if_else(cond, src_index, 0) return deepest, deepest_index def prepare_target(self, path, deepest): deepest, deepest_index = deepest dest = Value() src = Value() src_index = 0 target = [None] * (self.L + 1) target_index = [None] * (self.L + 1) for i in range(self.L, -1 , -1): i_eq_src = src.equal(i, self.logD + 1) target[i] = if_else(i_eq_src, dest, Value()) target_index[i] = if_else(i_eq_src, src_index, 0) dest = if_else(i_eq_src, Value(), dest) src = if_else(i_eq_src, Value(), src) if i == 0: break cond = or_op(dest.empty * self.get_bucket_ram(path, i).has_empty_entry(), \ (1 - target[i].empty)) * (1 - deepest[i].empty) src = if_else(cond, deepest[i], src) src_index = if_else(cond, deepest_index[i], src_index) dest = if_else(cond, Value(i), dest) return target, target_index def evict_once(self, path): deepest = self.prepare_deepest(path) target = self.prepare_target(path, deepest) evictor = self.evict_once_fast(path, target) next(evictor) towrite = next(evictor) yield self.add_evicted(path, towrite) yield def evict_once_fast(self, path, target): target, target_index = target empty_entry = Entry.get_empty(*self.internal_entry_size(), \ index_size=self.index_size) hold = empty_entry dest = Value() towrite = [None] * (self.L + 1) for i in range(self.L + 1): cond = (1 - hold.is_empty) * (dest.equal(i, self.logD + 1)) towrite[i] = if_else(cond, hold, empty_entry) hold = if_else(cond, empty_entry, hold) dest = if_else(cond, Value(), dest) cond = 1 - target[i].empty bucket = self.get_bucket_oram(path, i) if i != self.L: index = target_index[i].bit_decompose(bucket.size) hold = if_else(cond, bucket.read_and_remove_by_public(index), hold) dest = if_else(cond, target[i], dest) if i == 1: yield yield towrite def add_evicted(self, path, towrite): # make sure to add after removing for i in range(1, self.L + 1): self.get_bucket_oram(path, i).add(towrite[i]) def evict_rounds(self): get_path = lambda x: bit_compose(reversed(x.bit_decompose(self.D))) paths = [get_path(2 * self.t + i) for i in range(2)] for path in paths: for _ in self.evict_once(path): yield self.t.iadd(1) def evict(self): raise CompilerError('Using this function is likely an error. Use recursive_evict() instead.') Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-%d' % self.size) for i,_ in enumerate(self.evict_rounds()): Program.prog.curr_tape.start_new_basicblock(name='circuit-evict-round-%d-%d' % (i, self.size)) def recursive_evict(self): Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-%d' % self.size) for i,_ in enumerate(self.recursive_evict_rounds()): Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size)) def recursive_evict_rounds(self): for _ in zip(self.evict_rounds(), self.index.l.recursive_evict_rounds()): yield def bucket_indices_on_path_to(self, leaf): # root is at 1, different to PathORAM for level in range(self.D + 1): base = self.get_ram_index(leaf, level + 1) * self.bucket_size yield [base + i for i in range(self.bucket_size)] def output(self): print_ln('stash') self.stash.output() @for_range(1, 2**(self.D+1)) def f(i): print_ln('node %s', self.value_type.clear_type(i)) RefRAM(i, self).output() self.index.output() def __repr__(self): return repr(self.stash) + '\n' + repr(RefBucket(1, self)) class DebugCircuitORAM(CircuitORAM): """ Debugging only. Tree ORAM using index revealing the access pattern. """ index_structure = LocalIndexStructure threshold = 2**10 def OptimalCircuitORAM(size, value_type, *args, **kwargs): if size <= threshold: print(size, 'below threshold', threshold) return LinearORAM(size, value_type, *args, **kwargs) else: print(size, 'above threshold', threshold) return RecursiveCircuitORAM(size, value_type, *args, **kwargs) class RecursiveCircuitIndexStructure(PackedIndexStructure): """ Secure index using secure tree ORAM. """ storage = staticmethod(OptimalCircuitORAM) class RecursiveCircuitORAM(CircuitORAM): """ Secure tree ORAM using secure index. """ index_structure = RecursiveCircuitIndexStructure class AtLeastOneRecursionPackedCircuitORAM(PackedIndexStructure): storage = RecursiveCircuitORAM class AtLeastOneRecursionPackedCircuitORAMWithEmpty(PackedORAMWithEmpty): storage = RecursiveCircuitORAM