mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 05:27:56 -05:00
768 lines
31 KiB
Python
768 lines
31 KiB
Python
# Pull in the ORAM building blocks only once: the guard skips the star
# import when a name `_Array` already exists in this namespace.
if '_Array' not in dir():
    from Compiler.oram import *
    from Compiler import permutation
    # keep the Array class reachable under a private alias
    _Array = Array

from Compiler import oram
from functools import reduce

#import pdb

# Cap the program bit length at 64. The update is skipped silently when the
# attribute lookups below raise AttributeError.
try:
    prog = program.Program.prog
    prog.set_bit_length(min(64, prog.bit_length))
except AttributeError:
    pass
|
|
|
|
class Counter(object):
    """ Secret counter that is stepped by secret bits.

    For ``sgf2n`` a clear initial count ``k`` is stored as the element
    ``1 << k``, and incrementing multiplies the value by 2. For ``sint``
    the count is stored directly as an integer.
    """

    def __init__(self, val=0, max_val=None, size=None, value_type=sgf2n):
        """ Create a counter.

        :param val: initial count. A clear int is converted to ``1 << val``
            for sgf2n; any other value is passed to ``value_type`` as is.
        :param max_val: largest count to represent; determines
            ``bit_length``. Optional for sgf2n, required for sint.
        :param size: not used by this class.
        :param value_type: ``sgf2n`` or ``sint``.
        :raises CompilerError: for any other value type, or for a sint
            counter without ``max_val``.
        """
        if value_type is sgf2n:
            if isinstance(val, int):
                val = 1 << val
            if max_val is not None:
                self.bit_length = max_val+1
            else:
                self.bit_length = sgf2n.bit_length
        elif value_type is sint:
            if max_val is None:
                # previously failed with an opaque TypeError on None+1
                raise CompilerError('max_val required for sint Counter')
            self.bit_length = log2(max_val+1)
        else:
            raise CompilerError('Invalid value type for Counter')
        self.value = value_type(val)
        self.value_type = value_type

        # a counter initialised from an sgf2n value counts as already used
        self._used = isinstance(val, sgf2n)

    def used(self):
        """ Whether the counter was stepped or initialised from an sgf2n. """
        return self._used

    def increment(self, b):
        """ Increment counter by a secret bit """
        if self.value_type is sgf2n:
            prod = self.value * b
            # b = 1: value -> 2*value, b = 0: value unchanged
            self.value = (2*prod + self.value - prod)
        else:
            self.value = (self.value + b)
        self._used = True

    def decrement(self, b):
        """ Decrement counter by a secret bit """
        if self.value_type is sgf2n:
            inv_2 = cgf2n(1) / cgf2n(2)
            prod = self.value * b
            # b = 1: value -> value/2, b = 0: value unchanged
            self.value = (inv_2*prod + self.value - prod)
        else:
            # mirror of increment(); this branch was missing, so integer
            # counters silently kept their old value
            self.value = (self.value - b)
        self._used = True

    def reset(self):
        """ Set the count back to zero and clear the used flag. """
        if self.value_type is sgf2n:
            self.value = self.value_type(1)
        else:
            self.value = self.value_type(0)
        self._used = False

    def equal(self, i):
        """ Equality with clear int """
        if self.value_type is sgf2n:
            d = self.value - sgf2n(2**i)
            bits = d.bit_decompose(self.bit_length)
            return 1 - bits[i]
        else:
            return self.value.equal(i, self.bit_length)

    def equal_range(self, i):
        """ Vector of equality bits for 0, 1, ..., i-1 """
        if self.value_type is sgf2n:
            # value is 1 << count, so bit j of the value is 1 iff count == j
            return self.value.bit_decompose(self.bit_length)[:i]
        # integer counters hold the count itself; bit decomposition would
        # give its binary digits, so test each candidate like equal() does
        return [self.value.equal(j, self.bit_length) for j in range(i)]
|
|
|
|
def XOR(a, b):
    """ Exclusive-or of two values.

    Two Python ints use ``^``. If either operand is an ``sgf2n`` the sum
    ``a + b`` is returned. Anything else tries ``^`` first and falls back
    to ``a + b - 2*a*b`` when that raises ``TypeError``.
    """
    both_clear = isinstance(a, int) and isinstance(b, int)
    if both_clear:
        return a ^ b
    if any(isinstance(operand, sgf2n) for operand in (a, b)):
        return a + b
    try:
        return a ^ b
    except TypeError:
        # arithmetic form of XOR for bits without a ^ operator
        return a + b - 2*a*b
|
|
|
|
def pow2_eq(a, i, bit_length=40):
    """ Test for equality with 2**i, when a is a power of 2 (gf2n only)

    Subtracts ``sgf2n(2**i)`` from ``a``, decomposes the difference into
    ``bit_length`` bits and returns one minus bit ``i`` of it.
    """
    target = sgf2n(2**i)
    difference = a - target
    return 1 - difference.bit_decompose(bit_length)[i]
|
|
|
|
def empty_entry_sorter(a, b):
    """ Sort by entry's empty bit (empty <= not empty)

    Returns ``(1 - a.empty()) * b.empty()``, i.e. 1 exactly when ``a`` is
    not empty and ``b`` is empty.
    """
    a_occupied = 1 - a.empty()
    return a_occupied * b.empty()
|
|
|
|
def empty_entry_list_sorter(a, b):
    """ Sort a list by looking at first element's emptiness

    Same comparison as for single entries, applied to ``a[0]`` and ``b[0]``.
    """
    first_a, first_b = a[0], b[0]
    a_occupied = 1 - first_a.empty()
    return a_occupied * first_b.empty()
|
|
|
|
def bucket_size_sorter(x, y):
    """ Sort buckets by their sizes. Bucket is a list of the form
        [entry_0, entry_1, ..., entry_Z, size],

    where size is a GF(2^n) element with a single 1 in the position
    corresponding to the bucket size

    Returns one minus the product of the lowest Z bits of
    ``2**Z * x_size / y_size``.
    """
    capacity = len(x) - 1
    x_size = x[-1]
    y_size = y[-1]
    ratio = 2**capacity * x_size / y_size
    # xs <= ys if bits 0 to Z of the ratio are 0
    low_bits = ratio.bit_decompose(2*capacity)[:capacity]
    return 1 - reduce(lambda acc, bit: acc*bit, low_bits)
|
|
|
|
|
|
def LT(a, b):
    """ Bitwise less-than of a and b.

    Both operands are bit-decomposed and handed to ``BitLTC1`` with the
    fixed last argument 16; the register ``u`` that receives the outcome
    is returned (it used to be dropped, so callers always got ``None``).

    NOTE(review): the operand kinds ``BitLTC1`` expects are not visible
    here -- confirm that ``a``, ``b`` and a ``cgf2n`` result register are
    what it requires.
    """
    a_bits = bit_decompose(a)
    b_bits = bit_decompose(b)
    u = cgf2n()
    BitLTC1(u, a_bits, b_bits, 16)
    return u
|
|
|
|
class PathORAM(TreeORAM):
    """ ORAM that keeps ``bucket_size`` entries per tree node in one flat
    RAM (``self.buckets``), holds further entries in a TrivialORAM stash
    and moves entries along one root-to-leaf path on every ``evict()``.
    """
    def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \
                 bucket_oram=TrivialORAM, tau=3, sigma=5, stash_size=None, \
                 bucket_size=2, init_rounds=-1):
        """ Allocate buckets, stash, index and eviction scratch space.

        :param size: number of entries; tree depth is ``log2(size)``
        :param value_type: secret type used throughout (default sgf2n)
        :param value_length: number of values per entry, used when
            ``entry_size`` is None
        :param entry_size: per-value sizes; overrides ``value_length``
        :param bucket_oram: stored as ``self.bucket_oram``
        :param tau: how many levels above an entry's current level are
            considered when re-placing it (see ``adjust_lca``)
        :param sigma: how many levels below are considered
        :param stash_size: stash capacity; when None, ``tau``, ``sigma``
            and the stash size are all replaced by presets chosen from
            ``bucket_size``
        :param bucket_size: entries per bucket (presets exist for 2, 3, 4)
        :param init_rounds: forwarded to the index structure; any value
            other than -1 also moves memory initialization to timer 1
        :raises CompilerError: if ``stash_size`` is None and there is no
            preset for ``bucket_size``
        """
        #if size <= k:
        #    raise CompilerError('ORAM size too small')
        print('create oram of size', size)
        self.bucket_oram = bucket_oram
        self.bucket_size = bucket_size
        self.D = log2(size)
        self.logD = log2(self.D) + 1
        self.value_type = value_type
        if entry_size is not None:
            self.value_length = len(tuplify(entry_size))
            self.entry_size = tuplify(entry_size)
        else:
            self.value_length = value_length
            self.entry_size = [None] * value_length
        self.index_size = log2(size)
        self.index_type = value_type.get_type(self.index_size)
        self.size = size
        self.entry_type = Entry.get_empty(*self.internal_entry_size()).types()

        # one flat RAM holding bucket_size slots for every tree node
        self.buckets = RAM(self.bucket_size * 2**(self.D+1), self.entry_type,
                           self.get_array)
        if init_rounds != -1:
            # put memory initialization in different timer
            stop_timer()
            start_timer(1)
        self.buckets.init_mem(self.empty_entry())
        if init_rounds != -1:
            stop_timer(1)
            start_timer()
        self.index = self.index_structure(size, self.D, value_type, init_rounds, True)

        # deterministic eviction strategy from Gentry et al.
        self.deterministic_eviction = True
        if stash_size is None:
            # presets overwrite the tau/sigma arguments as well
            if self.deterministic_eviction:
                if self.bucket_size == 2:
                    # Z=2 more efficient without sigma/tau limits
                    tau = 20
                    sigma = 20
                    stash_size = 20
                elif self.bucket_size == 3:
                    tau = 20
                    sigma = 20
                    stash_size = 2
                elif self.bucket_size == 4:
                    tau = 3
                    sigma = 5
                    stash_size = 2
                else:
                    raise CompilerError('Bucket size %d not supported' % self.bucket_size)
            else:
                tau = 3
                sigma = 5
                stash_size = 48

        self.tau = tau
        self.sigma = sigma

        self.stash_capacity = stash_size
        self.stash = TrivialORAM(stash_size, *self.internal_value_type(), \
                                 index_size=self.index_size)

        # temp storage for the path + stash in eviction
        self.temp_size = stash_size + self.bucket_size*(self.D+1)
        self.temp_storage = RAM(self.temp_size, self.entry_type, self.get_array)
        self.temp_levels = [0] * self.temp_size # Array(self.temp_size, 'c')
        for i in range(self.temp_size):
            self.temp_levels[i] = 0

        # these include a read value from the stash
        self.read_value = [Array(self.D + 2, self.value_type.get_type(l))
                           for l in self.entry_size]
        self.read_empty = Array(self.D + 2, self.value_type.bit_type)

        self.state = MemValue(self.value_type(0))
        self.eviction_count = MemValue(cint(0))

        # bucket and stash sizes counter
        #self.sizes = [Counter(0, max_val=4) for i in range(self.D + 1)]
        self.stash_size = Counter(0, max_val=stash_size)

        self.read_path = MemValue(value_type.clear_type(0))

        # Eviction routine, stored below as self.evict. It copies the path
        # to the chosen leaf and the stash into temp storage, blanks the
        # originals and then re-places the entries.
        @function_block
        def evict():
            # both branches select the shuffle-based eviction
            if self.value_type == sgf2n:
                self.use_shuffle_evict = True
            else:
                self.use_shuffle_evict = True

            leaf = random_block(self.D, self.value_type).reveal()
            if oram.use_insecure_randomness:
                leaf = self.value_type(regint.get_random(self.D)).reveal()
            if self.deterministic_eviction:
                # the leaf drawn above is replaced by the eviction counter
                leaf = 0
                ec = self.eviction_count.read()
                # leaf bits already reversed so just use counter
                self.eviction_count.write((ec + 1) % 2**self.D)
                leaf = self.value_type.clear_type(ec)

            self.state.write(self.value_type(leaf))

            print('eviction leaf =', leaf)

            # load the path
            for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                for j, ram_index in enumerate(ram_indices):
                    self.temp_storage[i*self.bucket_size + j] = self.buckets[ram_index]
                    self.temp_levels[i*self.bucket_size + j] = i
                    ies = self.internal_entry_size()
                    self.buckets[ram_index] = Entry.get_empty(*ies)

            # load the stash
            for i in range(len(self.stash.ram)):
                self.temp_levels[i + self.bucket_size*(self.D+1)] = 0
            #for i, entry in enumerate(self.stash.ram):
            @for_range(len(self.stash.ram))
            def f(i):
                entry = self.stash.ram[i]
                self.temp_storage[i + self.bucket_size*(self.D+1)] = entry

                te = Entry.get_empty(*self.internal_entry_size())
                self.stash.ram[i] = te

            self.path_regs = [None] * self.bucket_size*(self.D+1)
            self.stash_regs = [None] * len(self.stash.ram)

            # read back the (now blanked) path and stash slots
            for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                for j, ram_index in enumerate(ram_indices):
                    self.path_regs[j + i*self.bucket_size] = self.buckets[ram_index]
            for i in range(len(self.stash.ram)):
                self.stash_regs[i] = self.stash.ram[i]

            #self.sizes = [Counter(0, max_val=4) for i in range(self.D + 1)]
            # per-level occupancy bits, reset for every eviction
            if self.use_shuffle_evict:
                if self.bucket_size == 4:
                    self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0, 0, 1)] for j in range(self.D+1)]
                elif self.bucket_size == 2 or self.bucket_size == 3:
                    self.size_bits = [[self.value_type.bit_type(i) for i in (0, 0)] for j in range(self.D+1)]
            else:
                self.size_bits = [[self.value_type.bit_type(0) for i in range(self.bucket_size)] for j in range(self.D+1)]
            self.stash_size = Counter(0, max_val=len(self.stash.ram))

            leaf = self.state.read().reveal()

            if self.use_shuffle_evict:
                # more efficient eviction using permutation networks
                self.shuffle_evict(leaf)
            else:
                # naive eviction method
                for i,(entry, depth) in enumerate(zip(self.temp_storage, self.temp_levels)):
                    self.evict_block(entry, depth, leaf)

                # write the re-placed entries back to stash and path
                for i, entry in enumerate(self.stash_regs):
                    self.stash.ram[i] = entry
                for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
                    for j, ram_index in enumerate(ram_indices):
                        self.buckets[ram_index] = self.path_regs[i*self.bucket_size + j]

        self.evict = evict

    @method_block
    def read_and_remove_levels(self, u):
        """ Read and remove index ``u`` from every bucket on the path stored
        in ``self.read_path`` and from the stash.

        Results are written to ``self.read_value`` / ``self.read_empty``:
        slot ``level`` for each tree level, slot ``D+1`` for the stash.
        """
        #print 'reading path to', self.read_path
        leaf = self.read_path.read()
        for level in range(self.D + 1):
            ram_indices = list(self.bucket_indices_on_path_to(leaf))[level]
            #print 'level %d, bucket %d' % (level, ram_indices[0]/self.bucket_size)
            #for j in range(self.bucket_size):
            #    #bucket.bucket.ram[j].v.reveal().print_reg('lev%d' % level)
            #    print str(self.buckets[ram_indices[j]]) + ', ',
            #print '\n'
            #value, empty = bucket.bucket.read_and_remove(u, 1)

            empty_entry = self.empty_entry(False)
            # number of leading entry values left out of the returned value
            skip = 1
            found = Array(self.bucket_size, self.value_type.bit_type)
            entries = [self.buckets[j] for j in ram_indices]
            indices = [e.v for e in entries]
            empty_bits = [e.empty() for e in entries]

            for j in range(self.bucket_size):
                # slot matches if its index equals u and it is not empty
                found[j] = indices[j].equal(u, self.index_size) * \
                    (1 - empty_bits[j])

            # at most one 1 in found
            empty = 1 - sum(found)
            prod_entries = list(map(operator.mul, found, entries))
            read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
                                 empty * empty_entry.x.skip(skip))
            for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries, prod_entries)):
                # replace the matching slot by the empty entry
                self.buckets[j] = entry - prod_entry + found[i] * empty_entry

            value, empty = [MemValue(v) for v in read_value], MemValue(empty)

            for v,w in zip(self.read_value, value):
                v[level] = w.read()
            self.read_empty[level] = empty.read()
            #print 'post-rar from', bucket
            #p_bucket.write(bucket.p_children(self.read_path & 1))
            #self.read_path.irshift(1)
        self.check()

        value, empty = self.stash.read_and_remove(u, 1)
        for v, w in zip(self.read_value, value):
            v[self.D+1] = w
        self.read_empty[self.D+1] = empty

    def empty_entry(self, apply_type=True):
        """ Return ``Entry.get_empty`` for this ORAM's internal entry size
        and index size; ``apply_type`` is forwarded unchanged. """
        vtype, entry_size = self.internal_entry_size()
        return Entry.get_empty(vtype, entry_size, apply_type, self.index_size)

    def shuffle_evict(self, leaf):
        """ Evict using oblivious shuffling etc

        Works on ``self.temp_storage`` (path followed by stash). The
        ``leaf`` argument is overwritten by a fresh read of ``self.state``.
        Steps: compute a target level per entry, sort entries and free
        positions, give empty entries the remaining positions, shuffle,
        reveal the levels and write everything back to buckets and stash.
        """
        evict_debug = False
        levels = [None] * len(self.temp_storage)

        # clear per-level fill counters; index D+1 stands for the stash
        bucket_sizes = Array(self.D + 2, cint)
        for i in range(self.D + 2):
            bucket_sizes[i] = regint(0)

        Program.prog.curr_tape.start_new_basicblock()
        leaf = self.state.read().reveal()

        if evict_debug:
            print_ln('\tEviction leaf: %s', leaf)

        for i,(entry, depth) in enumerate(zip(self.temp_storage, self.temp_levels)):
            lca_lev, cbits = self.compute_lca(entry.x[0], leaf, 1 - entry.empty())

            level_bits = self.adjust_lca(cbits, depth, 1 - entry.empty())
            # last bit indicates stash
            levels[i] = [sum(level_bits[j]*j for j in range(self.D+2)), level_bits[-1]]

            if evict_debug:
                @if_(1 - entry.empty().reveal())
                def f():
                    print_ln('entry (%s, %s) going to level %s', entry.v.reveal(), entry.x[0].reveal(), levels[i][0].reveal())
                    print_ln('%s ' * len(level_bits), *[b.reveal() for b in level_bits])
        if evict_debug:
            print_ln("")

        # sort entries+levels by emptiness: buckets already sorted so just perform a
        # sequence of merges on these and the stash
        buckets = [[[self.temp_storage[j]] + levels[j] for j in range(self.bucket_size*i,self.bucket_size*(i+1))] for i in range(self.D+1)]
        stash = [None] * (self.stash_capacity)

        for i in range(self.stash_capacity):
            j = i+self.bucket_size*(self.D+1)
            stash[i] = [self.temp_storage[j]] + levels[j]

        merged_entries = buckets + [stash]

        merged_entries = [m for sl in merged_entries for m in sl]
        me_len = len(merged_entries)
        # pad with None up to the next power of two for the sorting network
        while len(merged_entries) & (len(merged_entries)-1) != 0:
            merged_entries.append(None)
        # sort taking into account stash etc. (GF(2^n) ONLY atm)
        permutation.odd_even_merge_sort(merged_entries, lambda a,b: a[0].empty() * (a[-1] - 1 + b[-1]) + 1 - a[-1])

        merged_entries = merged_entries[:me_len]

        # and sort assigned positions by emptiness (non-empty first)
        empty_bits_and_levels = [[0]*self.bucket_size for i in range(self.D+1)]
        stash_bits = 0

        if evict_debug:
            print_str('Size bits: ')

        # convert bucket size bits to bits flagging emptiness for each position
        for j in range(self.D+1):
            s = self.size_bits[j]
            #for b in s:
            #    b.reveal().print_reg('u%d' % j)
            if self.bucket_size == 4:
                c = s[0]*s[1]
                if self.value_type == sgf2n:
                    empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] + s[2] + c), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
                else:
                    empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c + s[2]), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1] + s[2]), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c + s[2]), self.value_type.clear_type(j)]
                    empty_bits_and_levels[j][3] = [1 - self.value_type.bit_type(s[2]), self.value_type.clear_type(j)]
            elif self.bucket_size == 2:
                if evict_debug:
                    print_str('%s,%s,', s[0].reveal(), s[1].reveal())
                empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1]), self.value_type.clear_type(j)]
                empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
            elif self.bucket_size == 3:
                c = s[0]*s[1]
                empty_bits_and_levels[j][0] = [1 - self.value_type.bit_type(s[0] + s[1] - c), self.value_type.clear_type(j)]
                empty_bits_and_levels[j][1] = [1 - self.value_type.bit_type(s[1]), self.value_type.clear_type(j)]
                empty_bits_and_levels[j][2] = [1 - self.value_type.bit_type(c), self.value_type.clear_type(j)]

        if evict_debug:
            print_ln()

        empty_bits_and_levels = [x for sl in empty_bits_and_levels for x in sl]
        while len(empty_bits_and_levels) & (len(empty_bits_and_levels)-1) != 0:
            empty_bits_and_levels.append(None)

        permutation.odd_even_merge_sort(empty_bits_and_levels, permutation.bitwise_list_comparator)

        empty_bits_and_levels = [e for e in empty_bits_and_levels if e is not None]

        # assign levels to empty positions
        stash_level = self.value_type.clear_type(self.D + 1)


        if evict_debug:
            print_ln('Bits and levels: ')
        for i, entrylev in enumerate(merged_entries):
            entry = entrylev[0]
            level = entrylev[1]

            if i < len(empty_bits_and_levels):
                # empty entries take the level of the i-th sorted position
                new_level = (empty_bits_and_levels[i][1] - level) * entry.empty() + level
                if evict_debug:
                    print_ln('\t(empty pos %s, entry %s: empty lev %s, entry %s: new %s)', empty_bits_and_levels[i][0].reveal(), entry.empty().reveal(),
                             empty_bits_and_levels[i][1].reveal(), level.reveal(), new_level.reveal())
            else:
                new_level = level + stash_level * entry.empty()
                if evict_debug:
                    print_ln('\t(entry %s: level %s: new %s)', entry.empty().reveal(),
                             level.reveal(), new_level.reveal())
            merged_entries[i] = [entry, new_level]
        if evict_debug:
            print_ln()

        # shuffle entries and levels
        flat = []
        for x in merged_entries:
            flat += list(x[0]) + [x[1]]
        flat = self.value_type(flat)
        assert len(flat) % len(merged_entries) == 0
        # l = number of values per (entry, level) record
        l = len(flat) // len(merged_entries)
        shuffled = flat.secure_shuffle(l)
        merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]]
                          for i in range(len(shuffled) // l)]

        # need to copy entries/levels to memory for re-positioning
        entries_ram = RAM(self.temp_size, self.entry_type, self.get_array)
        levels_array = Array(self.temp_size, cint)

        for i,entrylev in enumerate(merged_entries):
            if entrylev is not None:
                entries_ram[i] = entrylev[0]
                levels_array[i] = entrylev[1].reveal()
        Program.prog.curr_tape.start_new_basicblock()

        # reveal shuffled levels
        @for_range(self.temp_size)
        def f(i):
            level = regint(levels_array[i])
            sz = regint(bucket_sizes[level])
            # next free slot of the revealed level
            self.temp_storage[level*self.bucket_size + sz] = entries_ram[i]
            bucket_sizes[level] += 1

        if evict_debug:
            for i in range(self.D+1):
                @if_(bucket_sizes[i] != self.bucket_size)
                def f():
                    print_str('Sizes: ')
                    for i in range(self.D+2):
                        print_str('%s,', bucket_sizes[i])
                    print_ln()
                    runtime_error('Incorrect bucket sizes')

        Program.prog.curr_tape.start_new_basicblock()
        # write path and stash back from temp storage
        for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
            for j, ram_index in enumerate(ram_indices):
                self.buckets[ram_index] = self.temp_storage[i*self.bucket_size + j]
        for i in range(self.stash_capacity):
            self.stash.ram[i] = self.temp_storage[i + (self.D+1)*self.bucket_size]


    def evict_block(self, entry, level, leaf):
        """ Evict an entry at a given level

        Computes the entry's common-ancestor level with ``leaf`` and lets
        ``compute_pos`` place it. Nothing is returned.
        """
        #leaf = self.state.read().reveal()
        lca_lev, cbits = self.compute_lca(entry.x[0], leaf, 1 - entry.empty()) #, level + self.sigma)

        #new_lca = self.adjust_lca(cbits, level, 1 - entry.empty())
        lev, assigned = self.compute_pos(entry, level, lca_lev, leaf)
        #print 'evicted to lev', lev.value, assigned

    def read_and_remove(self, u):
        """ Read and remove index ``u``.

        Renews the index entry for ``u``, reads all levels and the stash,
        and combines the per-level results.

        :return: ``(read_value, read_empty)``
        """
        self.read_path.write(self.read_and_renew_index(u))
        self.check()
        self.read_and_remove_levels(u)
        values = (ValueTuple(x) for x in zip(*self.read_value))
        not_empty = [1 - x for x in self.read_empty]
        read_empty = 1 - sum(not_empty)
        read_value = sum(list(map(operator.mul, not_empty, values)), \
                             ValueTuple(0 for i in range(self.value_length)))
        self.check(u)
        Program.prog.curr_tape.\
            start_new_basicblock(name='read_and_remove-%d-end' % self.size)
        return read_value, read_empty

    def buckets_on_path_to(self, leaf):
        """ Iterator of buckets on the path to a leaf """
        bucket = RefBucket(MemValue(self.root.mem.address), self, True)
        yield bucket
        for i in range(self.D):
            # lowest remaining leaf bit selects the child
            bucket = bucket.ref_children(leaf & 1)
            leaf >>= 1
            yield bucket

    def bucket_indices_on_path_to(self, leaf):
        """ Iterator over the lists of RAM indices of each bucket on the
        path to a leaf, starting with the root bucket. """
        leaf = regint(leaf)
        yield list(range(self.bucket_size))
        index = 0
        for i in range(self.D):
            # children of node k are 2k+1 and 2k+2
            index = 2*index + 1 + regint(cint(leaf) & 1)
            leaf >>= 1
            yield [index*self.bucket_size + i for i in range(self.bucket_size)]

    def get_bucket_indices(self, i, l):
        """ Get RAM indices for the i-th bucket on path to leaf l """
        index = 0
        for j in range(i):
            index = 2*index + 1 + (l & 1)
            l >>= 1
        index = regint(index)
        return [index * self.bucket_size + j for j in range(self.bucket_size)]

    def get_bucket(self, i, l):
        """ Get the i-th bucket on the path to leaf l """
        bucket = RefBucket(MemValue(self.root.mem.address), self, True)
        for j in range(i):
            bucket = bucket.ref_children(l & 1)
            l >>= 1
        return bucket

    def get_children(self, i, l):
        """ Get children of the i-th bucket on level l """
        j = 2**l + i - 1
        return self.buckets[2*j+1], self.buckets[2*j+2]

    def adjust_lca(self, lca_bits, lev, not_empty, prnt=False):
        """ Adjust LCA based on bucket capacities (and original clear level, lev)

        Scans levels from ``min(lev + sigma, D)`` down to
        ``max(lev - tau, 0)`` and updates ``self.size_bits`` on the way.

        :return: list of D+1 level bits followed by one add-to-stash bit
        :raises CompilerError: for bucket sizes other than 2, 3 or 4
        """
        found = self.value_type.bit_type(0)
        assigned = self.value_type.bit_type(0)
        try_add_here = self.value_type.bit_type(0)
        new_lca = [self.value_type.bit_type(0)] * (self.D + 1)

        upper = min(lev + self.sigma, self.D)
        lower = max(lev - self.tau, 0)

        for j in range(upper, lower-1, -1):
            found += lca_bits[j]
            try_add_here += lca_bits[j]
            if self.bucket_size == 4:
                new_lca[j] = try_add_here * (1 - self.size_bits[j][2]) # (not_empty => lca_bits all 0)
                #new_lca[j] = found * (1 - assigned) * (1 - self.size_bits[j][2]) * not_empty
            elif self.bucket_size == 2 or self.bucket_size == 3:
                new_lca[j] = try_add_here * (1 - self.size_bits[j][1])

            if prnt:
                new_lca[j].reveal().print_reg('nl%d' % j)

            assigned += new_lca[j]
            # XOR new_lca[j] into try_add_here
            if self.value_type == sgf2n:
                try_add_here += new_lca[j]
            else:
                try_add_here += new_lca[j] - 2*try_add_here*new_lca[j]

            if self.bucket_size == 4:
                t = new_lca[j] * self.size_bits[j][0]
                t2 = t * self.size_bits[j][1]
                # s_0 := s_0 \xor b
                # s_1 := s_1 \xor (s_0 & b)
                # s_2 := s_2 \xor (s_0 & s_1 & b)
                if self.value_type == sgf2n:
                    self.size_bits[j][0] += new_lca[j]
                    self.size_bits[j][1] += t
                    self.size_bits[j][2] += t2 #t * self.size_bits[j][1]
                else:
                    self.size_bits[j][0] += new_lca[j] - 2*t
                    self.size_bits[j][1] += t - 2*t2
                    self.size_bits[j][2] += t2
                # '1 if empty' bit
                #self.size_bits[j][3] *= (1 - new_lca[j])
            elif self.bucket_size == 2 or self.bucket_size == 3:
                t = new_lca[j] * self.size_bits[j][0]
                if self.value_type == sgf2n:
                    self.size_bits[j][0] += new_lca[j]
                else:
                    self.size_bits[j][0] += new_lca[j] - 2*t
                self.size_bits[j][1] += t
            else:
                raise CompilerError('Bucket size %d not supported' % self.bucket_size)

        # non-empty entries that got no level go to the stash
        add_to_stash = not_empty - sum(new_lca)

        #final_level = sum(new_lca[i]*i for i in range(self.D+1)) + add_to_stash * (self.D+1)
        #
        #if_then(cint(reveal(not_empty)))
        #final_level.reveal().print_reg('lca')
        #for j in range(2):
        #    for k,b in enumerate(self.size_bits[j]):
        #        b.reveal().print_reg('u%dj%d' % (k,j))
        #end_if()
        return new_lca + [add_to_stash]

    def compute_lca(self, a, b, not_empty, limit=None):
        """ Compute depth of the least common ancestor of a and b, upper bounded by limit

        :return: pair ``(lca, bits)``. ``bits`` holds one bit per level
            plus a final "no difference found" bit. ``lca`` is None when
            shuffle eviction is in use and a Counter otherwise.
        """
        a_bits = bit_decompose(a, self.D)
        b_bits = bit_decompose(b, self.D)
        found = [None] * self.D
        not_found = self.value_type.bit_type(not_empty) #1
        if limit is None:
            limit = self.D

        for i in range(self.D)[:limit]:
            # find first position where bits differ (i.e. first 0 in 1 - a XOR b)
            t = 1 - XOR(a_bits[i], b_bits[i])
            prev_nf = not_found
            not_found *= t
            found[i] = prev_nf - not_found

        if self.use_shuffle_evict:
            return None, found + [not_found]
        else:
            one = self.value_type.clear_type(1)
            lca = sum(found[i]*(one << i) for i in range(self.D)[:limit]) + \
                (one << limit) * not_found
            return Counter(lca, max_val=limit, value_type=self.value_type), found + [not_found]

    def compute_pos(self, entry, lev, levstar, leaf):
        """ Clear integer lev, secret gf2n levstar (rep. as power of 2 with Counter object).

        Writes ``entry`` into ``self.path_regs`` or ``self.stash_regs`` and
        updates ``self.size_bits`` and ``self.stash_size``.

        :return: ``(levstar, a)`` where ``a`` sums the placement bits over
            all path slots
        """
        pos = 0
        a = 0
        b = 0

        not_empty = 1 - entry.empty()

        upper = min(lev + self.sigma, self.D)
        lower = max(lev - self.tau, 0)
        levstar_eq = levstar.equal_range(upper+1)
        e = 0
        b = 0

        for j in range(upper, lower - 1, -1):
            # e = want to place at this level
            e = (1 - b) * ((1 - e)*levstar_eq[j] + e) * not_empty

            # b = can place at this level
            b = e * (1 - self.size_bits[j][-1])
            s = 1 + sgf2n(self.size_bits[j][0])
            t = cgf2n(1)
            for i in range(1, self.bucket_size):
                t <<= 1
                s += t * (self.size_bits[j][i-1] + self.size_bits[j][i])
            size_eq = (s * b).bit_decompose(self.bucket_size)

            a += sum(size_eq)

            #self.sizes[j].value.read().reveal().print_reg('sz%d' % j)
            #self.sizes[j].equal(self.bucket_size).reveal().print_reg('eq')
            #b.reveal().print_reg('b')
            #print 'sz%d:' % j, self.sizes[j].value #, levstar.value, b
            for i in range(self.bucket_size):
                c = size_eq[i]
                #t = cint(c.reveal())
                #def f():
                #    entry.x[1].reveal().print_reg('writ')
                #    t.print_reg('l%di%d' % (j,i))
                #    entry.x[0].reveal().print_reg('w lf')
                #if_statement(t,f)
                #if c.reveal() == 1:
                #    print 'writing block %d at level %d on path to %d' % (i,j,leaf)
                #    print 'writing', entry*c + bucket.ram[i]*(1 - c)
                prev = self.path_regs[i + j*self.bucket_size]
                # slot becomes entry if c is 1, stays prev otherwise
                new = c * (entry - prev) + prev
                self.path_regs[i + j*self.bucket_size] = new

                self.size_bits[j][i] += c

        add_to_stash = not_empty - a # (1-a) * not_empty
        stash_eq = Counter(self.stash_size.value * add_to_stash, len(self.stash.ram)).equal_range(self.stash.size)

        for i,s in enumerate(self.stash_regs):
            c = stash_eq[i] #* add_to_stash
            te = c * (entry - s) + s # entry*c + s*(1 - c)
            self.stash_regs[i] = te
        self.stash_size.increment(add_to_stash)

        #if add_to_stash.reveal() == 1:
        #    print 'stash', self.stash_size.value

        return levstar, a

    def add(self, entry, state=None, evict=True):
        """ Add ``entry`` to the stash, prefixing its values with ``state``
        (read from ``self.state`` when None), then run ``self.evict()``
        unless ``evict`` is false.
        """
        if state is None:
            state = self.state.read()
        l = state
        x = tuple(i.read() for i in entry.x)

        e = Entry(entry.v.read(), (l,) + x, entry.empty())

        #self.temp_storage[self.temp_size-1] = e * 1
        #self.temp_levels[self.temp_size-1] = 0
        #print 'adding', self.temp_storage[-1][0]
        try:
            self.stash.add(e)
        except Exception:
            # show which ORAM failed, then let the error propagate
            print(self)
            raise
        if evict:
            self.evict()
|
|
|
|
class LocalPathORAM(PathORAM):
    """ Debugging only. Path ORAM using index revealing the access
    pattern. """
    # class instantiated by PathORAM.__init__ to build self.index
    index_structure = LocalPackedIndexStructure
|
|
|
|
def OptimalORAM(size, *args, **kwargs):
    """ Build a LinearORAM for sizes up to 2**10 and a RecursivePathORAM
    for anything larger; all arguments are forwarded unchanged. """
    # threshold set from experiments (lower than in SCSL)
    threshold = 2**10
    oram_class = LinearORAM if size <= threshold else RecursivePathORAM
    return oram_class(size, *args, **kwargs)
|
|
|
|
class RecursivePathIndexStructure(PackedIndexStructure):
    """ Packed index structure whose ``storage`` factory is OptimalORAM. """
    # staticmethod keeps the function from being bound as a method.
    # NOTE(review): how PackedIndexStructure uses `storage` is defined
    # outside this section -- presumably as the backing ORAM constructor.
    storage = staticmethod(OptimalORAM)
|
|
|
|
class RecursivePathORAM(PathORAM):
    """ Path ORAM whose index is a RecursivePathIndexStructure. """
    # class instantiated by PathORAM.__init__ to build self.index
    index_structure = RecursivePathIndexStructure
|
|
|
|
class AtLeastOneRecursionPackedPathORAM(PackedIndexStructure):
    """ Packed index structure whose ``storage`` is always
    RecursivePathORAM, with no size-based fallback. """
    # NOTE(review): `storage` is consumed by the base class, which is not
    # visible in this section.
    storage = RecursivePathORAM
|
|
|
|
class AtLeastOneRecursionPackedPathORAMWithEmpty(PackedORAMWithEmpty):
    """ PackedORAMWithEmpty variant whose ``storage`` is always
    RecursivePathORAM. """
    # NOTE(review): `storage` is consumed by the base class, which is not
    # visible in this section.
    storage = RecursivePathORAM
|
|
|
|
class OptimalPackedPathORAMWithEmpty(PackedORAMWithEmpty):
    """ PackedORAMWithEmpty variant whose ``storage`` factory is
    OptimalORAM. """
    # staticmethod keeps the function from being bound as a method
    storage = staticmethod(OptimalORAM)
|