# operator.mul is used below; import it explicitly instead of relying on it
# leaking through the star-import.
import operator

if '_Array' not in dir():
    from Compiler.oram import *
    from Compiler import permutation
    _Array = Array

from Compiler import oram
from functools import reduce

prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))


class Counter:
    """ Counter held in a secret value.

    For sgf2n the count k is stored as the field element 2**k (an int
    ``val`` is converted with ``1 << val``), so incrementing multiplies by 2
    and decrementing multiplies by 1/2. For sint the count is stored as is.

    :param val: initial count (int) or an already encoded secret value
    :param max_val: largest count; sets the bit length. Required for sint,
        since ``log2(max_val+1)`` is evaluated unconditionally there.
    :param size: unused
    :param value_type: sgf2n or sint; anything else raises CompilerError
    """

    def __init__(self, val=0, max_val=None, size=None, value_type=sgf2n):
        if value_type is sgf2n:
            if isinstance(val, int):
                val = 1 << val
            if max_val is not None:
                self.bit_length = max_val+1
            else:
                self.bit_length = sgf2n.bit_length
        elif value_type is sint:
            self.bit_length = log2(max_val+1)
        else:
            raise CompilerError('Invalid value type for Counter')
        self.value = value_type(val)
        self.value_type = value_type
        # a counter created from an existing secret value counts as used
        self._used = isinstance(val, sgf2n)

    def used(self):
        """ Whether the counter was built from a secret value or modified
        since the last reset """
        return self._used

    def increment(self, b):
        """ Increment counter by a secret bit """
        if self.value_type is sgf2n:
            # value * (1 + b*(2 - 1)): doubles the power of two iff b = 1
            prod = self.value * b
            self.value = (2*prod + self.value - prod)
        else:
            self.value = (self.value + b)
        self._used = True

    def decrement(self, b):
        """ Decrement counter by a secret bit """
        if self.value_type is sgf2n:
            inv_2 = cgf2n(1) / cgf2n(2)
            prod = self.value * b
            self.value = (inv_2*prod + self.value - prod)
        else:
            # mirror of increment(); previously the sint case left the
            # value untouched
            self.value = (self.value - b)
        self._used = True

    def reset(self):
        """ Set the count back to zero (encoded as 1 for sgf2n) """
        if self.value_type is sgf2n:
            self.value = self.value_type(1)
        else:
            self.value = self.value_type(0)
        self._used = False

    def equal(self, i):
        """ Equality with clear int """
        if self.value_type is sgf2n:
            d = self.value - sgf2n(2**i)
            bits = d.bit_decompose(self.bit_length)
            return 1 - bits[i]
        else:
            return self.value.equal(i, self.bit_length)

    def equal_range(self, i):
        """ Vector of equality bits for 0, 1, ..., i-1 """
        return self.value.bit_decompose(self.bit_length)[:i]


def XOR(a, b):
    """ XOR of two bits.

    Uses ``^`` for two Python ints, ``+`` as soon as one operand is sgf2n,
    and otherwise tries ``^`` and falls back to ``a + b - 2*a*b`` when the
    operand types raise TypeError for it.
    """
    if isinstance(a, int) and isinstance(b, int):
        return a^b
    elif isinstance(a, sgf2n) or isinstance(b, sgf2n):
        return a + b
    else:
        try:
            return a ^ b
        except TypeError:
            return a + b - 2*a*b


def pow2_eq(a, i, bit_length=40):
    """ Test for equality with 2**i, when a is a power of 2 (gf2n only)"""
    d = a - sgf2n(2**i)
    bits = d.bit_decompose(bit_length)
    return 1 - bits[i]


def empty_entry_sorter(a, b):
    """ Sort by entry's empty bit (empty <= not empty) """
    return (1 - a.empty()) * b.empty()


def empty_entry_list_sorter(a, b):
    """ Sort a list by looking at first element's emptiness """
    return (1 - a[0].empty()) * b[0].empty()


def bucket_size_sorter(x, y):
    """ Sort buckets by their sizes. Bucket is a list of the form

        [entry_0, entry_1, ..., entry_Z, size],

    where size is a GF(2^n) element with a single 1 in the position
    corresponding to the bucket size """
    Z = len(x) - 1
    xs = x[-1]
    ys = y[-1]
    t = 2**Z * xs / ys
    # xs <= ys if bits 0 to Z of t are 0
    return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z])


def shuffle(x, config=None, value_type=sgf2n, reverse=False):
    """ Simulate secure shuffling with Waksman network for 2 players.

    Returns the network switching config so it may be re-used later.

    :param x: list to shuffle in place; its length must be a power of 2
    :param config: switching config from an earlier call, or None to draw a
        fresh random permutation
    :param value_type: type the freshly generated switch bits are wrapped in
    :param reverse: passed through to permutation.waksman
    :raises CompilerError: if len(x) is not a power of 2
    """
    n = len(x)
    if n & (n-1) != 0:
        raise CompilerError('shuffle requires n a power of 2')
    if config is None:
        config = permutation.configure_waksman(permutation.random_perm(n))
        for i,c in enumerate(config):
            config[i] = [value_type(b) for b in c]
    # NOTE(review): the network is applied twice with the same config,
    # presumably one pass per simulated player -- confirm before changing.
    permutation.waksman(x, config, reverse=reverse)
    permutation.waksman(x, config, reverse=reverse)

    return config


def LT(a, b):
    """ Bitwise less-than helper built on BitLTC1 with parameter 16.

    NOTE(review): BitLTC1 presumably stores its result in ``u``; the
    original code computed ``u`` and dropped it, so it is returned now.
    """
    a_bits = bit_decompose(a)
    b_bits = bit_decompose(b)
    u = cgf2n()
    BitLTC1(u, a_bits, b_bits, 16)
    return u


class PathORAM(TreeORAM):
    """ Path ORAM on top of TreeORAM.

    All buckets live in a single RAM of ``bucket_size * 2**(D+1)`` entries
    with D = log2(size); the root bucket occupies the first ``bucket_size``
    slots and the children of bucket ``index`` are ``2*index + 1`` and
    ``2*index + 2``. Overflow is kept in a TrivialORAM stash.
    """

    def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None,
                 bucket_oram=TrivialORAM, tau=3, sigma=5, stash_size=None,
                 bucket_size=2, init_rounds=-1):
        """
        :param size: number of addressable entries
        :param value_type: secret type used for entries (sgf2n or sint)
        :param value_length: number of values per entry, used when
            entry_size is None
        :param entry_size: bit length(s) of the stored values; takes
            precedence over value_length
        :param bucket_oram: stored on the instance as ``bucket_oram``
        :param tau: how many levels above its current level an entry may be
            placed during eviction
        :param sigma: how many levels below
        :param stash_size: stash capacity; when None, tau, sigma and the
            stash size are all replaced by presets chosen per bucket_size
        :param bucket_size: entries per bucket (2, 3 or 4)
        :param init_rounds: if not -1, memory initialisation is accounted to
            timer 1; also forwarded to the index structure
        :raises CompilerError: for an unsupported bucket_size when
            stash_size is None
        """
        print('create oram of size', size)
        self.bucket_oram = bucket_oram
        self.bucket_size = bucket_size
        self.D = log2(size)
        self.logD = log2(self.D) + 1
        self.value_type = value_type
        if entry_size is not None:
            self.value_length = len(tuplify(entry_size))
            self.entry_size = tuplify(entry_size)
        else:
            self.value_length = value_length
            self.entry_size = [None] * value_length
        self.index_size = log2(size)
        self.index_type = value_type.get_type(self.index_size)
        self.size = size
        self.entry_type = Entry.get_empty(*self.internal_entry_size()).types()

        self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type,
                           self.get_array)
        if init_rounds != -1:
            # put memory initialization in different timer
            stop_timer()
            start_timer(1)
        self.buckets.init_mem(self.empty_entry())
        if init_rounds != -1:
            stop_timer(1)
            start_timer()
        self.index = self.index_structure(size, self.D, value_type,
                                          init_rounds, True)

        # deterministic eviction strategy from Gentry et al.
        self.deterministic_eviction = True
        if stash_size is None:
            if self.deterministic_eviction:
                if self.bucket_size == 2:
                    # Z=2 more efficient without sigma/tau limits
                    tau = 20
                    sigma = 20
                    stash_size = 20
                elif self.bucket_size == 3:
                    tau = 20
                    sigma = 20
                    stash_size = 2
                elif self.bucket_size == 4:
                    tau = 3
                    sigma = 5
                    stash_size = 2
                else:
                    raise CompilerError('Bucket size %d not supported' %
                                        self.bucket_size)
            else:
                tau = 3
                sigma = 5
                stash_size = 48
        self.tau = tau
        self.sigma = sigma

        self.stash_capacity = stash_size
        self.stash = TrivialORAM(stash_size, *self.internal_value_type(),
                                 index_size=self.index_size)

        # temp storage for the path + stash in eviction
        self.temp_size = stash_size + self.bucket_size*(self.D+1)
        self.temp_storage = RAM(self.temp_size, self.entry_type,
                                self.get_array)
        # clear (compile-time) level of every temp slot; all start at 0
        self.temp_levels = [0] * self.temp_size

        # these include a read value from the stash
        self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
                           for l in self.entry_size]
        self.read_empty = Array(self.D + 2, self.value_type.bit_type)

        self.state = MemValue(self.value_type(0))
        self.eviction_count = MemValue(cint(0))

        # stash size counter
        self.stash_size = Counter(0, max_val=stash_size)

        self.read_path = MemValue(value_type.clear_type(0))

        @function_block
        def evict():
            # Evict along one path: move the path and the stash into
            # temp_storage, then redistribute the entries.
            # Both branches of the original type test enabled shuffle
            # eviction, so it is simply always on.
            self.use_shuffle_evict = True

            leaf = random_block(self.D, self.value_type).reveal()
            if oram.use_insecure_randomness:
                leaf = self.value_type(regint.get_random(self.D)).reveal()
            if self.deterministic_eviction:
                ec = self.eviction_count.read()
                # leaf bits already reversed so just use counter
                self.eviction_count.write((ec + 1) % 2**self.D)
                leaf = self.value_type.clear_type(ec)

            self.state.write(self.value_type(leaf))

            print('eviction leaf =', leaf)

            stash_offset = self.bucket_size*(self.D+1)

            # load the path, leaving empty entries behind in the tree
            for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                for j, ram_index in enumerate(ram_indices):
                    self.temp_storage[i*self.bucket_size + j] = self.buckets[ram_index]
                    self.temp_levels[i*self.bucket_size + j] = i
                    ies = self.internal_entry_size()
                    self.buckets[ram_index] = Entry.get_empty(*ies)

            # load the stash, leaving empty entries behind
            for i in range(len(self.stash.ram)):
                self.temp_levels[i + stash_offset] = 0

            @for_range(len(self.stash.ram))
            def f(i):
                entry = self.stash.ram[i]
                self.temp_storage[i + self.bucket_size*(self.D+1)] = entry

                te = Entry.get_empty(*self.internal_entry_size())
                self.stash.ram[i] = te

            # register copies of the (now emptied) path and stash, used by
            # the naive eviction method
            self.path_regs = [None] * self.bucket_size*(self.D+1)
            self.stash_regs = [None] * len(self.stash.ram)

            for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                for j, ram_index in enumerate(ram_indices):
                    self.path_regs[j + i*self.bucket_size] = self.buckets[ram_index]
            for i in range(len(self.stash.ram)):
                self.stash_regs[i] = self.stash.ram[i]

            # per-level bucket occupancy bits, reset for every eviction
            if self.use_shuffle_evict:
                if self.bucket_size == 4:
                    self.size_bits = [[self.value_type.bit_type(i)
                                       for i in (0, 0, 0, 1)]
                                      for j in range(self.D+1)]
                elif self.bucket_size == 2 or self.bucket_size == 3:
                    self.size_bits = [[self.value_type.bit_type(i)
                                       for i in (0, 0)]
                                      for j in range(self.D+1)]
            else:
                self.size_bits = [[self.value_type.bit_type(0)
                                   for i in range(self.bucket_size)]
                                  for j in range(self.D+1)]
            self.stash_size = Counter(0, max_val=len(self.stash.ram))

            leaf = self.state.read().reveal()

            if self.use_shuffle_evict:
                # more efficient eviction using permutation networks
                self.shuffle_evict(leaf)
            else:
                # naive eviction method
                for i,(entry, depth) in enumerate(zip(self.temp_storage,
                                                      self.temp_levels)):
                    self.evict_block(entry, depth, leaf)

                # write the register copies back to stash and tree
                for i, entry in enumerate(self.stash_regs):
                    self.stash.ram[i] = entry
                for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                    for j, ram_index in enumerate(ram_indices):
                        self.buckets[ram_index] = self.path_regs[i*self.bucket_size + j]

        self.evict = evict

    @method_block
    def read_and_remove_levels(self, u):
        """ Read and remove index u from every bucket on the path stored in
        ``read_path`` and from the stash.

        Results go to ``read_value[...][level]`` / ``read_empty[level]`` for
        tree levels 0..D and to position D+1 for the stash.
        """
        leaf = self.read_path.read()
        for level in range(self.D + 1):
            ram_indices = list(self.bucket_indices_on_path_to(leaf))[level]
            empty_entry = self.empty_entry(False)
            skip = 1
            found = Array(self.bucket_size, self.value_type.bit_type)
            entries = [self.buckets[j] for j in ram_indices]
            indices = [e.v for e in entries]
            empty_bits = [e.empty() for e in entries]

            for j in range(self.bucket_size):
                found[j] = indices[j].equal(u, self.index_size) * \
                    (1 - empty_bits[j])

            # at most one 1 in found
            empty = 1 - sum(found)
            prod_entries = list(map(operator.mul, found, entries))
            read_value = sum((entry.x.skip(skip) for entry in prod_entries),
                             empty * empty_entry.x.skip(skip))
            # overwrite the slot that matched with an empty entry
            for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries,
                                                          prod_entries)):
                self.buckets[j] = entry - prod_entry + found[i] * empty_entry

            value, empty = [MemValue(v) for v in read_value], MemValue(empty)

            for v,w in zip(self.read_value, value):
                v[level] = w.read()
            self.read_empty[level] = empty.read()

        self.check()

        value, empty = self.stash.read_and_remove(u, 1)
        for v, w in zip(self.read_value, value):
            v[self.D+1] = w
        self.read_empty[self.D+1] = empty

    def empty_entry(self, apply_type=True):
        """ Empty entry matching this ORAM's internal entry size """
        vtype, entry_size = self.internal_entry_size()
        return Entry.get_empty(vtype, entry_size, apply_type, self.index_size)

    def shuffle_evict(self, leaf):
        """ Evict using oblivious shuffling etc

        Works on ``temp_storage`` as filled by evict(): computes a target
        level for every entry, pairs empty entries with free slots, shuffles,
        reveals the shuffled levels and writes the entries back to the path
        and the stash. The ``leaf`` argument is re-read from ``self.state``.
        """
        evict_debug = False
        levels = [None] * len(self.temp_storage)
        bucket_sizes = Array(self.D + 2, cint)
        for i in range(self.D + 2):
            bucket_sizes[i] = regint(0)
        Program.prog.curr_tape.start_new_basicblock()
        leaf = self.state.read().reveal()

        if evict_debug:
            print_ln('\tEviction leaf: %s', leaf)

        for i,(entry, depth) in enumerate(zip(self.temp_storage,
                                              self.temp_levels)):
            _, cbits = self.compute_lca(entry.x[0], leaf, 1 - entry.empty())

            level_bits = self.adjust_lca(cbits, depth, 1 - entry.empty())
            # last bit indicates stash
            levels[i] = [sum(level_bits[j]*j for j in range(self.D+2)),
                         level_bits[-1]]

            if evict_debug:
                @if_(1 - entry.empty().reveal())
                def f():
                    print_ln('entry (%s, %s) going to level %s', entry.v.reveal(), entry.x[0].reveal(), levels[i][0].reveal())
                    print_ln('%s ' * len(level_bits), *[b.reveal() for b in level_bits])
        if evict_debug:
            print_ln("")

        # sort entries+levels by emptiness: buckets already sorted so just perform a
        # sequence of merges on these and the stash
        buckets = [[[self.temp_storage[j]] + levels[j]
                    for j in range(self.bucket_size*i, self.bucket_size*(i+1))]
                   for i in range(self.D+1)]
        stash = [None] * (self.stash_capacity)
        for i in range(self.stash_capacity):
            j = i+self.bucket_size*(self.D+1)
            stash[i] = [self.temp_storage[j]] + levels[j]

        merged_entries = buckets + [stash]
        merged_entries = [m for sl in merged_entries for m in sl]
        me_len = len(merged_entries)
        # pad with None up to a power of two
        while len(merged_entries) & (len(merged_entries)-1) != 0:
            merged_entries.append(None)
        # sort taking into account stash etc. (GF(2^n) ONLY atm)
        permutation.odd_even_merge_sort(
            merged_entries,
            lambda a,b: a[0].empty() * (a[-1] - 1 + b[-1]) + 1 - a[-1])

        merged_entries = merged_entries[:me_len]

        # and sort assigned positions by emptiness (non-empty first)
        empty_bits_and_levels = [[0]*self.bucket_size for i in range(self.D+1)]

        if evict_debug:
            print_str('Size bits: ')

        def slot(level, occupied):
            # [emptiness flag, clear level] for one bucket position; builds
            # fresh values on every call exactly like the inline lists did
            return [1 - self.value_type.bit_type(occupied),
                    self.value_type.clear_type(level)]

        # convert bucket size bits to bits flagging emptiness for each position
        for j in range(self.D+1):
            s = self.size_bits[j]
            if self.bucket_size == 4:
                c = s[0]*s[1]
                # only the first position differs between the two types
                if self.value_type == sgf2n:
                    empty_bits_and_levels[j][0] = slot(j, s[0] + s[1] + s[2] + c)
                else:
                    empty_bits_and_levels[j][0] = slot(j, s[0] + s[1] - c + s[2])
                empty_bits_and_levels[j][1] = slot(j, s[1] + s[2])
                empty_bits_and_levels[j][2] = slot(j, c + s[2])
                empty_bits_and_levels[j][3] = slot(j, s[2])
            elif self.bucket_size == 2:
                if evict_debug:
                    print_str('%s,%s,', s[0].reveal(), s[1].reveal())
                empty_bits_and_levels[j][0] = slot(j, s[0] + s[1])
                empty_bits_and_levels[j][1] = slot(j, s[1])
            elif self.bucket_size == 3:
                c = s[0]*s[1]
                empty_bits_and_levels[j][0] = slot(j, s[0] + s[1] - c)
                empty_bits_and_levels[j][1] = slot(j, s[1])
                empty_bits_and_levels[j][2] = slot(j, c)
        if evict_debug:
            print_ln()

        empty_bits_and_levels = [x for sl in empty_bits_and_levels for x in sl]
        while len(empty_bits_and_levels) & (len(empty_bits_and_levels)-1) != 0:
            empty_bits_and_levels.append(None)

        permutation.odd_even_merge_sort(empty_bits_and_levels,
                                        permutation.bitwise_list_comparator)

        empty_bits_and_levels = [e for e in empty_bits_and_levels
                                 if e is not None]

        # assign levels to empty positions
        stash_level = self.value_type.clear_type(self.D + 1)

        if evict_debug:
            print_ln('Bits and levels: ')
        for i, entrylev in enumerate(merged_entries):
            entry = entrylev[0]
            level = entrylev[1]

            if i < len(empty_bits_and_levels):
                # an empty entry takes the level of the i-th sorted slot
                new_level = (empty_bits_and_levels[i][1] - level) * entry.empty() + level
                if evict_debug:
                    print_ln('\t(empty pos %s, entry %s: empty lev %s, entry %s: new %s)',
                             empty_bits_and_levels[i][0].reveal(), entry.empty().reveal(),
                             empty_bits_and_levels[i][1].reveal(), level.reveal(),
                             new_level.reveal())
            else:
                # beyond the tree slots an empty entry goes to the stash level
                new_level = level + stash_level * entry.empty()
                if evict_debug:
                    print_ln('\t(entry %s: level %s: new %s)',
                             entry.empty().reveal(),
                             level.reveal(), new_level.reveal())
            merged_entries[i] = [entry, new_level]
        if evict_debug:
            print_ln()

        # shuffle entries and levels
        while len(merged_entries) & (len(merged_entries)-1) != 0:
            merged_entries.append(None)
        permutation.rec_shuffle(merged_entries, value_type=self.value_type)
        merged_entries = [e for e in merged_entries if e is not None]

        # need to copy entries/levels to memory for re-positioning
        entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)
        levels_array = Array(self.temp_size, cint)

        for i,entrylev in enumerate(merged_entries):
            if entrylev is not None:
                entries_ram[i] = entrylev[0]
                levels_array[i] = entrylev[1].reveal()
        Program.prog.curr_tape.start_new_basicblock()

        # reveal shuffled levels
        @for_range(self.temp_size)
        def f(i):
            level = regint(levels_array[i])
            sz = regint(bucket_sizes[level])
            self.temp_storage[level*self.bucket_size + sz] = entries_ram[i]
            bucket_sizes[level] += 1

        if evict_debug:
            for i in range(self.D+1):
                @if_(bucket_sizes[i] != self.bucket_size)
                def f():
                    print_str('Sizes: ')
                    for i in range(self.D+2):
                        print_str('%s,', bucket_sizes[i])
                    print_ln()
                    runtime_error('Incorrect bucket sizes')

        Program.prog.curr_tape.start_new_basicblock()

        # write the re-positioned entries back to the tree path and the stash
        for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
            for j, ram_index in enumerate(ram_indices):
                self.buckets[ram_index] = self.temp_storage[i*self.bucket_size + j]
        for i in range(self.stash_capacity):
            self.stash.ram[i] = self.temp_storage[i + (self.D+1)*self.bucket_size]

    def evict_block(self, entry, level, leaf):
        """ Evict an entry at a given level

        Only called by the naive eviction branch of evict(), which is never
        taken because evict() always enables shuffle eviction.
        """
        lca_lev, cbits = self.compute_lca(entry.x[0], leaf, 1 - entry.empty())
        lev, assigned = self.compute_pos(entry, level, lca_lev, leaf)

    def read_and_remove(self, u):
        """ Read and remove index u.

        Looks up (and renews) the leaf for u, reads all levels and the stash
        and combines the per-level results.

        :return: (value tuple, empty bit)
        """
        self.read_path.write(self.read_and_renew_index(u))
        self.check()
        self.read_and_remove_levels(u)
        values = (ValueTuple(x) for x in zip(*self.read_value))
        not_empty = [1 - x for x in self.read_empty]
        read_empty = 1 - sum(not_empty)
        read_value = sum(list(map(operator.mul, not_empty, values)),
                         ValueTuple(0 for i in range(self.value_length)))
        self.check(u)
        Program.prog.curr_tape.\
            start_new_basicblock(name='read_and_remove-%d-end' % self.size)
        return read_value, read_empty

    def buckets_on_path_to(self, leaf):
        """ Iterator of buckets on the path to a leaf """
        bucket = RefBucket(MemValue(self.root.mem.address), self, True)
        yield bucket
        for i in range(self.D):
            bucket = bucket.ref_children(leaf & 1)
            leaf >>= 1
            yield bucket

    def bucket_indices_on_path_to(self, leaf):
        """ Iterator over the lists of RAM indices of the buckets on the path
        to a leaf, root first; the lowest remaining bit of leaf selects the
        child at every step """
        leaf = regint(leaf)
        yield list(range(self.bucket_size))
        index = 0
        for i in range(self.D):
            index = 2*index + 1 + regint(cint(leaf) & 1)
            leaf >>= 1
            yield [index*self.bucket_size + i for i in range(self.bucket_size)]

    def get_bucket_indices(self, i, l):
        """ Get RAM indices for the i-th bucket on path to leaf l """
        index = 0
        for j in range(i):
            index = 2*index + 1 + (l & 1)
            l >>= 1
        index = regint(index)
        return [index * self.bucket_size + j for j in range(self.bucket_size)]

    def get_bucket(self, i, l):
        """ Get the i-th bucket on the path to leaf l """
        bucket = RefBucket(MemValue(self.root.mem.address), self, True)
        for j in range(i):
            bucket = bucket.ref_children(l & 1)
            l >>= 1
        return bucket

    def get_children(self, i, l):
        """ Get children of the i-th bucket on level l """
        j = 2**l + i - 1
        return self.buckets[2*j+1], self.buckets[2*j+2]

    def adjust_lca(self, lca_bits, lev, not_empty, prnt=False):
        """ Adjust LCA based on bucket capacities (and original clear level, lev)

        Walks from level ``min(lev + sigma, D)`` up to ``max(lev - tau, 0)``
        and updates ``self.size_bits`` for the level that takes the entry.

        :param lca_bits: level bits as returned by compute_lca
        :param lev: clear level the entry currently sits at
        :param not_empty: secret bit, 1 if the entry is not empty
        :param prnt: print the per-level result (debugging)
        :return: list of D+1 level bits followed by an add-to-stash bit
        :raises CompilerError: for bucket sizes other than 2, 3 and 4
        """
        found = self.value_type.bit_type(0)
        assigned = self.value_type.bit_type(0)
        try_add_here = self.value_type.bit_type(0)
        new_lca = [self.value_type.bit_type(0)] * (self.D + 1)

        upper = min(lev + self.sigma, self.D)
        lower = max(lev - self.tau, 0)

        for j in range(upper, lower-1, -1):
            found += lca_bits[j]
            try_add_here += lca_bits[j]
            # place here unless the bucket's "full" bit is set
            if self.bucket_size == 4:
                # (not_empty => lca_bits all 0)
                new_lca[j] = try_add_here * (1 - self.size_bits[j][2])
            elif self.bucket_size == 2 or self.bucket_size == 3:
                new_lca[j] = try_add_here * (1 - self.size_bits[j][1])

            if prnt:
                new_lca[j].reveal().print_reg('nl%d' % j)

            assigned += new_lca[j]
            # try_add_here := try_add_here XOR new_lca[j]
            if self.value_type == sgf2n:
                try_add_here += new_lca[j]
            else:
                try_add_here += new_lca[j] - 2*try_add_here*new_lca[j]

            if self.bucket_size == 4:
                t = new_lca[j] * self.size_bits[j][0]
                t2 = t * self.size_bits[j][1]
                # s_0 := s_0 xor b
                # s_1 := s_1 xor (s_0 & b)
                # s_2 := s_2 xor (s_0 & s_1 & b)
                if self.value_type == sgf2n:
                    self.size_bits[j][0] += new_lca[j]
                    self.size_bits[j][1] += t
                    self.size_bits[j][2] += t2
                else:
                    self.size_bits[j][0] += new_lca[j] - 2*t
                    self.size_bits[j][1] += t - 2*t2
                    self.size_bits[j][2] += t2
            elif self.bucket_size == 2 or self.bucket_size == 3:
                t = new_lca[j] * self.size_bits[j][0]
                if self.value_type == sgf2n:
                    self.size_bits[j][0] += new_lca[j]
                else:
                    self.size_bits[j][0] += new_lca[j] - 2*t
                self.size_bits[j][1] += t
            else:
                raise CompilerError('Bucket size %d not supported' %
                                    self.bucket_size)

        add_to_stash = not_empty - sum(new_lca)
        return new_lca + [add_to_stash]

    def compute_lca(self, a, b, not_empty, limit=None):
        """ Compute depth of the least common ancestor of a and b, upper bounded by limit

        :param a: secret leaf
        :param b: leaf to compare against
        :param not_empty: secret bit that seeds the "not found yet" flag
        :param limit: number of bits compared, defaults to D
        :return: ``(lca, bits)`` where bits has one flag per position plus a
            final "not found" flag; lca is None under shuffle eviction and a
            Counter otherwise
        """
        a_bits = bit_decompose(a, self.D)
        b_bits = bit_decompose(b, self.D)
        found = [None] * self.D
        not_found = self.value_type.bit_type(not_empty)
        if limit is None:
            limit = self.D
        for i in range(self.D)[:limit]:
            # find first position where bits differ (i.e. first 0 in 1 - a XOR b)
            t = 1 - XOR(a_bits[i], b_bits[i])
            prev_nf = not_found
            not_found *= t
            found[i] = prev_nf - not_found

        if self.use_shuffle_evict:
            return None, found + [not_found]
        else:
            one = self.value_type.clear_type(1)
            lca = sum(found[i]*(one << i) for i in range(self.D)[:limit]) + \
                (one << limit) * not_found
            return Counter(lca, max_val=limit, value_type=self.value_type), \
                found + [not_found]

    def compute_pos(self, entry, lev, levstar, leaf):
        """ Clear integer lev, secret gf2n levstar (rep. as power of 2 with Counter object).

        Places the entry into ``path_regs`` or, failing that, ``stash_regs``
        (naive eviction only).

        :return: (levstar, sum of the bits marking a placement on the path)
        """
        a = 0
        not_empty = 1 - entry.empty()
        upper = min(lev + self.sigma, self.D)
        lower = max(lev - self.tau, 0)

        levstar_eq = levstar.equal_range(upper+1)
        e = 0
        b = 0
        for j in range(upper, lower - 1, -1):
            # e = want to place at this level
            e = (1 - b) * ((1 - e)*levstar_eq[j] + e) * not_empty

            # b = can place at this level
            b = e * (1 - self.size_bits[j][-1])

            s = 1 + sgf2n(self.size_bits[j][0])
            t = cgf2n(1)
            for i in range(1, self.bucket_size):
                t <<= 1
                s += t * (self.size_bits[j][i-1] + self.size_bits[j][i])
            size_eq = (s * b).bit_decompose(self.bucket_size)
            a += sum(size_eq)

            for i in range(self.bucket_size):
                c = size_eq[i]
                prev = self.path_regs[i + j*self.bucket_size]
                # entry if c else prev
                new = c * (entry - prev) + prev
                self.path_regs[i + j*self.bucket_size] = new
                self.size_bits[j][i] += c

        add_to_stash = not_empty - a
        stash_eq = Counter(self.stash_size.value * add_to_stash,
                           len(self.stash.ram)).equal_range(self.stash.size)
        for i,s in enumerate(self.stash_regs):
            c = stash_eq[i]
            # entry if c else s
            te = c * (entry - s) + s
            self.stash_regs[i] = te
        self.stash_size.increment(add_to_stash)
        return levstar, a

    def add(self, entry, state=None, evict=True):
        """ Add an entry to the stash, tagged with a leaf, and optionally evict.

        :param entry: entry to add
        :param state: leaf to store with the entry; defaults to the value of
            ``self.state``
        :param evict: run an eviction afterwards
        """
        if state is None:
            state = self.state.read()
        l = state
        x = tuple(i.read() for i in entry.x)
        e = Entry(entry.v.read(), (l,) + x, entry.empty())
        try:
            self.stash.add(e)
        except Exception:
            # show which ORAM failed, then propagate
            print(self)
            raise
        if evict:
            self.evict()


class LocalPathORAM(PathORAM):
    """ Debugging only. Path ORAM using index revealing the access pattern. """
    index_structure = LocalPackedIndexStructure


def OptimalORAM(size, *args, **kwargs):
    """ LinearORAM up to 2**10 entries, RecursivePathORAM above that """
    # threshold set from experiments (lower than in SCSL)
    threshold = 2**10
    if size <= threshold:
        return LinearORAM(size,*args,**kwargs)
    else:
        return RecursivePathORAM(size, *args, **kwargs)


class RecursivePathIndexStructure(PackedIndexStructure):
    """ Packed index structure stored in whatever OptimalORAM picks """
    storage = staticmethod(OptimalORAM)


class RecursivePathORAM(PathORAM):
    """ Path ORAM whose index is a RecursivePathIndexStructure """
    index_structure = RecursivePathIndexStructure


class AtLeastOneRecursionPackedPathORAM(PackedIndexStructure):
    """ Packed index structure always stored in a RecursivePathORAM """
    storage = RecursivePathORAM


class AtLeastOneRecursionPackedPathORAMWithEmpty(PackedORAMWithEmpty):
    """ PackedORAMWithEmpty always stored in a RecursivePathORAM """
    storage = RecursivePathORAM


class OptimalPackedPathORAMWithEmpty(PackedORAMWithEmpty):
    """ PackedORAMWithEmpty stored in whatever OptimalORAM picks """
    storage = staticmethod(OptimalORAM)