mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 05:27:56 -05:00
1912 lines
76 KiB
Python
1912 lines
76 KiB
Python
"""
|
|
This module contains an implementation of the tree-based oblivious
|
|
RAM as proposed by `Shi et al. <https://eprint.iacr.org/2011/407>`_ as
|
|
well as the straight-forward construction using linear scanning.
|
|
Unlike :py:class:`~Compiler.types.Array`, this allows access by a
|
|
secret index::
|
|
|
|
a = OptimalORAM(1000)
|
|
i = sint.get_input_from(0)
|
|
a[i] = sint.get_input_from(1)
|
|
|
|
`The introductory book by Evans et
|
|
al. <https://securecomputation.org>`_ contains `a chapter dedicated to
|
|
oblivious RAM
|
|
<https://securecomputation.org/docs/ch5-obliviousdata.pdf>`_.
|
|
|
|
"""
|
|
|
|
import random
|
|
import math
|
|
import collections
|
|
import itertools
|
|
import operator
|
|
import sys
|
|
from functools import reduce
|
|
|
|
from Compiler.types import *
|
|
from Compiler.types import _secret, _register
|
|
from Compiler.library import *
|
|
from Compiler.program import Program
|
|
from Compiler import floatingpoint,comparison,permutation
|
|
|
|
from Compiler.util import *
|
|
|
|
print_access = False
|
|
sint_bit_length = 6
|
|
max_demux_bits = 3
|
|
debug = False
|
|
use_binary_search = False
|
|
n_parallel = 1024
|
|
n_threads = None
|
|
detailed_timing = False
|
|
optimal_threshold = None
|
|
n_threads_for_tree = None
|
|
debug_online = False
|
|
crash_on_overflow = False
|
|
use_insecure_randomness = False
|
|
debug_ram_size = False
|
|
single_thread = False
|
|
|
|
def maybe_start_timer(n):
|
|
if detailed_timing:
|
|
start_timer(n)
|
|
|
|
def maybe_stop_timer(n):
|
|
if detailed_timing:
|
|
stop_timer(n)
|
|
|
|
class Block(object):
|
|
def __init__(self, value, lengths):
|
|
self.value = self.value_type.hard_conv(value)
|
|
self.lengths = tuplify(lengths)
|
|
def get_slice(self):
|
|
res = []
|
|
for length,start in zip(self.lengths, series(self.lengths)):
|
|
res.append(util.bit_compose((self.bits[start:start+length])))
|
|
return res
|
|
def __repr__(self):
|
|
return '<' + str(self.value) + '>'
|
|
|
|
class intBlock(Block):
|
|
""" Bit slicing for modp. """
|
|
value_type = sint
|
|
def __init__(self, value, start, lengths, entries_per_block):
|
|
Block.__init__(self, value, lengths)
|
|
length = sum(self.lengths)
|
|
self.n_bits = length * entries_per_block
|
|
self.start = self.value_type.hard_conv(start * length)
|
|
if Program.prog.options.ring:
|
|
self.lower, trunc, self.shift = floatingpoint.SplitInRing(
|
|
self.value, self.n_bits, self.start)
|
|
else:
|
|
self.lower, self.shift = \
|
|
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
|
|
Program.prog.security, True)
|
|
trunc = (self.value - self.lower).field_div(self.shift)
|
|
self.slice = trunc.mod2m(length, self.n_bits, signed=False)
|
|
self.upper = (trunc - self.slice) * self.shift
|
|
def get_slice(self):
|
|
total_length = sum(self.lengths)
|
|
if len(self.lengths) == 1:
|
|
self.bits = self.slice.bit_decompose(total_length)
|
|
return super(intBlock, self).get_slice()
|
|
else:
|
|
res = []
|
|
remainder = self.slice
|
|
for length,start in zip(self.lengths[:-1],series(self.lengths)):
|
|
res.append(remainder.mod2m(length, total_length - start,
|
|
signed=False))
|
|
remainder -= res[-1]
|
|
remainder = remainder.trunc_zeros(length,
|
|
total_length - start, False)
|
|
res.append(remainder)
|
|
return res
|
|
def set_slice(self, value):
|
|
value = sum(v << start for v,start in zip(value, series(self.lengths)))
|
|
self.value = self.upper + self.lower + value * self.shift
|
|
return self
|
|
|
|
class gf2nBlock(Block):
|
|
""" Bit slicing for GF2n. """
|
|
value_type = sgf2n
|
|
def __init__(self, value, start, lengths, entries_per_block):
|
|
Block.__init__(self, value, lengths)
|
|
length = sum(self.lengths)
|
|
Program.prog.curr_tape.\
|
|
start_new_basicblock(name='gf2n-block-init-%d' % entries_per_block)
|
|
used_bits = entries_per_block * length
|
|
if entries_per_block == 2:
|
|
value_bits = bit_decompose(self.value, used_bits)
|
|
prod_bits = [start * bit for bit in value_bits]
|
|
anti_bits = [v - p for v,p in zip(value_bits,prod_bits)]
|
|
self.lower = sum(bit << i for i,bit in enumerate(prod_bits[:length]))
|
|
self.bits = list(map(operator.add, anti_bits[:length], prod_bits[length:])) + \
|
|
anti_bits[length:]
|
|
self.adjust = if_else(start, 1 << length, cgf2n(1))
|
|
elif entries_per_block < 4:
|
|
value_bits = bit_decompose(self.value, used_bits)
|
|
l = log2(entries_per_block)
|
|
start_bits = bit_decompose(start, l)
|
|
choice_bits = demux(start_bits)
|
|
inv_bits = [1 - bit for bit in floatingpoint.PreOR(choice_bits, None)]
|
|
mask_bits = sum(([x] * length for x in inv_bits), [])
|
|
lower_bits = list(map(operator.mul, value_bits, mask_bits))
|
|
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
|
|
self.bits = [sum(map(operator.mul, choice_bits, value_bits[i::length])) \
|
|
for i in range(length)]
|
|
self.adjust = sum(bit << (i * length) \
|
|
for i,bit in enumerate(choice_bits))
|
|
else:
|
|
value_bits = bit_decompose(self.value, used_bits)
|
|
l = log2(entries_per_block)
|
|
start_bits = bit_decompose(start, l)
|
|
powers = [2**(2**i) for i in range(l)]
|
|
selected = [power * bit + (1 - bit) \
|
|
for bit,power in zip(start_bits,powers)]
|
|
power_start = floatingpoint.KOpL(operator.mul, selected)
|
|
bits = bit_decompose(power_start, entries_per_block)
|
|
adjust = sum(bit << (i * length) for i,bit in enumerate(bits))
|
|
pre_bits = floatingpoint.PreOpL(lambda x,y,z=None: x + y, bits)
|
|
inv_bits = [1 - bit for bit in pre_bits]
|
|
mask_bits = sum(([x] * length for x in inv_bits), [])
|
|
lower_bits = list(map(operator.mul, value_bits, mask_bits))
|
|
masked = self.value - sum(bit << i for i,bit in enumerate(lower_bits))
|
|
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
|
|
self.bits = (masked / adjust).bit_decompose(used_bits)
|
|
self.adjust = adjust
|
|
Program.prog.curr_tape.\
|
|
start_new_basicblock(name='gf2n-block-init-end-%d' % entries_per_block)
|
|
def set_slice(self, value):
|
|
upper_bits = self.bits[sum(self.lengths):]
|
|
upper = (sum(b << i for i,b in enumerate(upper_bits)) * \
|
|
self.adjust) << sum(self.lengths)
|
|
value = sum(v << start for v,start in zip(value, series(self.lengths)))
|
|
self.value = self.lower + value * self.adjust + upper
|
|
return self
|
|
|
|
block_types = { sint: intBlock,
|
|
sgf2n: gf2nBlock,
|
|
}
|
|
|
|
def get_block(x, y, *args):
|
|
for t in block_types:
|
|
if isinstance(x, t):
|
|
return block_types[t](x, y, *args)
|
|
elif isinstance(y, t):
|
|
return block_types[t](x, y, *args)
|
|
raise CompilerError('appropiate block type not found')
|
|
|
|
def get_bit(x, index, bit_length):
|
|
if isinstance(x, sgf2n):
|
|
bits = x.bit_decompose(bit_length)
|
|
choice_bits = cgf2n(1 << index).bit_decompose(bit_length)
|
|
return sum(map(operator.mul, bits, choice_bits))
|
|
else:
|
|
return get_block(x, index, 1, bit_length).get_slice()[0]
|
|
|
|
def demux(x):
|
|
""" Demuxing like in the Galois paper. """
|
|
# res = Array(2**len(x), x[0].reg_type)
|
|
# for i,v in enumerate(demux_list(x)):
|
|
# res[i] = v
|
|
# return res
|
|
|
|
if 2**len(x) <= n_parallel:
|
|
return demux_list(x)
|
|
else:
|
|
return demux_array(x)
|
|
|
|
def demux_list(x):
|
|
n = len(x)
|
|
if n == 0:
|
|
return [1]
|
|
elif n == 1:
|
|
return [1 - x[0], x[0]]
|
|
a = demux_list(x[:n//2])
|
|
b = demux_list(x[n//2:])
|
|
n_a = len(a)
|
|
a *= len(b)
|
|
b = reduce(operator.add, ([i] * n_a for i in b))
|
|
res = list(map(operator.mul, a, b))
|
|
return res
|
|
|
|
def demux_array(x, res=None):
|
|
tmp = demux_matrix(x).array
|
|
if res:
|
|
try:
|
|
assert issubclass(x.value_type, _register)
|
|
res[:] = tmp[:]
|
|
except:
|
|
@for_range(len(res))
|
|
def _(i):
|
|
res[i] = tmp[i]
|
|
else:
|
|
res = tmp
|
|
return res
|
|
|
|
def demux_matrix(x, n_threads=None):
|
|
n = len(x)
|
|
if n == 0:
|
|
return [1]
|
|
m = len(x[0])
|
|
t = type(x[0])
|
|
res = Matrix(2**n, m, type(x[0]))
|
|
if n == 1:
|
|
res[0] = 1 - x[0]
|
|
res[1] = x[0]
|
|
else:
|
|
a = Matrix(2**(n//2), m, type(x[0]))
|
|
a.assign(demux(x[:n//2]))
|
|
b = Matrix(2**(n-n//2), m, type(x[0]))
|
|
b.assign(demux(x[n//2:]))
|
|
@for_range_opt_multithread(n_threads, len(a))
|
|
def f(i):
|
|
@for_range_opt(len(b))
|
|
def f(j):
|
|
res[j * len(a) + i][:] = a[i][:] * b[j][:]
|
|
return res
|
|
|
|
def get_first_one(x):
|
|
prefix_list = [0] + floatingpoint.PreOR(x, Program.prog.security)
|
|
return [prefix_list[i+1] - prefix_list[i] for i in range(len(x))]
|
|
|
|
class Value(object):
|
|
def __init__(self, value=None, empty=None):
|
|
if value is None:
|
|
self.empty = 1
|
|
self.value = 0
|
|
else:
|
|
try:
|
|
self.value = next(value)
|
|
self.empty = next(value)
|
|
except TypeError:
|
|
self.empty = 0 if empty is None else empty
|
|
self.value = value
|
|
def __iter__(self):
|
|
yield self.value
|
|
yield self.empty
|
|
def __add__(self, other):
|
|
return Value(self.value + other.value, self.empty + other.empty)
|
|
def __sub__(self, other):
|
|
return Value(self.value - other.value, self.empty - other.empty)
|
|
def __xor__(self, other):
|
|
return Value(self.value ^ other.value, self.empty ^ other.empty)
|
|
def __mul__(self, other):
|
|
return Value(other * self.value, other * self.empty)
|
|
__rmul__ = __mul__
|
|
def equal(self, other, length=None):
|
|
if isinstance(other, int) and isinstance(self.value, int):
|
|
return (1 - self.empty) * (other == self.value)
|
|
return (1 - self.empty) * self.value.equal(other, length)
|
|
def reveal(self):
|
|
return Value(reveal(self.value), reveal(self.empty))
|
|
def output(self):
|
|
# @if_e(self.empty)
|
|
# def f():
|
|
# print_str('<>')
|
|
# @else_
|
|
# def f():
|
|
print_str('<%s:%s>', self.empty, self.value)
|
|
def __index__(self):
|
|
return int(self.value)
|
|
def __repr__(self):
|
|
try:
|
|
value = self.empty
|
|
while True:
|
|
if value == 1:
|
|
return '<>'
|
|
if value == 0:
|
|
return '<%s>' % str(self.value)
|
|
value = value.value
|
|
except:
|
|
pass
|
|
return '<%s:%s>' % (str(self.value), str(self.empty))
|
|
|
|
class ValueTuple(tuple):
|
|
""" Works like a vector. """
|
|
def skip(self, skip):
|
|
return ValueTuple(self[skip:])
|
|
def __add__(self, other):
|
|
return ValueTuple(i + j for i,j in zip(self, other))
|
|
def __sub__(self, other):
|
|
return ValueTuple(i - j for i,j in zip(self, other))
|
|
def __xor__(self, other):
|
|
return ValueTuple(i ^ j for i,j in zip(self, other))
|
|
def __mul__(self, other):
|
|
return ValueTuple(other * i for i in self)
|
|
__rmul__ = __mul__
|
|
__rxor__ = __xor__
|
|
def output(self):
|
|
print_str('(' + ', '.join('%s' for i in range(len(self))) + ')', *self)
|
|
|
|
class Entry(object):
|
|
""" An (O)RAM entry with empty bit, index, and value. """
|
|
@staticmethod
|
|
def get_empty(value_type, entry_size, apply_type=True, index_size=None):
|
|
res = {}
|
|
for i,tt in enumerate((value_type, value_type.default_type)):
|
|
if apply_type:
|
|
apply = lambda length, x: value_type.get_type(length)(x)
|
|
else:
|
|
apply = lambda length, x: x
|
|
res[i] = Entry(apply(index_size, 0), \
|
|
tuple(apply(l, 0) for l in entry_size), \
|
|
apply(1, True), value_type)
|
|
res[0].defaults = res[1]
|
|
return res[0]
|
|
def __init__(self, v, x=None, empty=None, value_type=None):
|
|
self.created_non_empty = False
|
|
if x is None:
|
|
v = iter(v)
|
|
self.is_empty = next(v)
|
|
self.v = next(v)
|
|
self.x = ValueTuple(v)
|
|
else:
|
|
if empty is None:
|
|
self.created_non_empty = True
|
|
empty = value_type.bit_type(False)
|
|
self.is_empty = empty
|
|
self.v = v
|
|
if not isinstance(x, (tuple, list)):
|
|
x = (x,)
|
|
self.x = ValueTuple(x)
|
|
def empty(self):
|
|
return self.is_empty
|
|
def types(self):
|
|
return tuple(type(i) for i in self)
|
|
def values(self):
|
|
yield self.is_empty
|
|
yield self.v
|
|
for i in self.x:
|
|
yield i
|
|
def __iter__(self):
|
|
yield self.is_empty
|
|
yield self.v
|
|
for i in self.x:
|
|
yield i
|
|
def __len__(self):
|
|
return 2 + len(self.x)
|
|
def __repr__(self):
|
|
return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \
|
|
else '{%s: %s}' % (self.v, self.x)
|
|
def __add__(self, other):
|
|
try:
|
|
return Entry(i + j for i,j in zip(self, other))
|
|
except:
|
|
print(self, other)
|
|
raise
|
|
def __sub__(self, other):
|
|
return Entry(i - j for i,j in zip(self, other))
|
|
def __xor__(self, other):
|
|
return Entry(i ^ j for i,j in zip(self, other))
|
|
def __mul__(self, other):
|
|
try:
|
|
return Entry(other * i for i in self)
|
|
except:
|
|
print(self, other)
|
|
raise
|
|
__rmul__ = __mul__
|
|
def reveal(self):
|
|
return Entry(x.reveal() for x in self)
|
|
def output(self):
|
|
# @if_e(self.is_empty)
|
|
# def f():
|
|
# print_str('{empty=%s}', self.is_empty)
|
|
# @else_
|
|
# def f():
|
|
# print_str('{%s: %s}', self.v, self.x)\
|
|
print_str('{%s: %s,empty=%s}', self.v, self.x, self.is_empty)
|
|
|
|
class RefRAM(object):
|
|
""" RAM reference. """
|
|
def __init__(self, index, oram):
|
|
if debug_ram_size:
|
|
@if_(index >= oram.n_buckets())
|
|
def f():
|
|
print_ln('invalid bucket index %s for %s buckets', \
|
|
index, oram.n_buckets())
|
|
crash()
|
|
self.size = oram.bucket_size
|
|
self.entry_type = oram.entry_type
|
|
self.l = [oram.get_array(self.size, t, array.address + \
|
|
index * oram.bucket_size) \
|
|
for t,array in zip(self.entry_type,oram.ram.l)]
|
|
self.index = index
|
|
def init_mem(self, empty_entry):
|
|
print('init ram')
|
|
for a,value in zip(self.l, list(empty_entry.defaults.values())):
|
|
# don't use threads if n_threads explicitly set to 1
|
|
a.assign_all(value, n_threads=n_threads, conv=False)
|
|
def get_empty_bits(self):
|
|
return self.l[0]
|
|
def get_indices(self):
|
|
return self.l[1]
|
|
def get_values(self, skip=0):
|
|
return [ValueTuple(x) for x in zip(*self.l[2+skip:])]
|
|
def get_value(self, index, skip=0):
|
|
return ValueTuple(a[index] for a in self.l[2+skip:])
|
|
def get_value_length(self):
|
|
return len(self.l) - 2
|
|
def get_value_arrays(self):
|
|
return self.l[2:]
|
|
def get_value_array(self, index):
|
|
return [Value(self.l[2+index][i], self.l[0][i]) for i in range(self.size)]
|
|
def __getitem__(self, index):
|
|
if print_access:
|
|
print('get', id(self), index)
|
|
return Entry(a[index] for a in self.l)
|
|
def __setitem__(self, index, value):
|
|
if print_access:
|
|
print('set', id(self), index)
|
|
if not isinstance(value, Entry):
|
|
raise Exception('entries only please: %s' % str(value))
|
|
for i,(a,v) in enumerate(zip(self.l, list(value.values()))):
|
|
a[index] = v
|
|
def __len__(self):
|
|
return self.size
|
|
def has_empty_entry(self):
|
|
return 1 - tree_reduce(operator.mul, [1 - bit for bit in self.get_empty_bits()])
|
|
def is_empty(self):
|
|
return tree_reduce(operator.mul, list(self.get_empty_bits()))
|
|
def reveal(self):
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
res = RAM(self.size, [t.clear_type for t in self.entry_type], \
|
|
lambda *args: Array(*args), self.index)
|
|
for i,a in enumerate(self.l):
|
|
for j,x in enumerate(a):
|
|
res.l[i][j] = x.reveal()
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
return res
|
|
def output(self):
|
|
print_ln('%s', [x.reveal() for x in self])
|
|
def print_reg(self):
|
|
print_ln('listing of RAM at index %s', self.index)
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
for i,array in enumerate(self.l):
|
|
for j,reg in enumerate(array):
|
|
print_str('%s:%s ', j, reg)
|
|
print_ln()
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
def __repr__(self):
|
|
return repr(self.l)
|
|
|
|
class RAM(RefRAM):
|
|
""" List of entries in memory. """
|
|
def __init__(self, size, entry_type, get_array, index=0):
|
|
#print_reg(cint(0), 'r in')
|
|
self.size = size
|
|
self.entry_type = entry_type
|
|
self.l = [get_array(self.size, t) for t in entry_type]
|
|
self.index = index
|
|
|
|
class AbstractORAM(object):
|
|
""" Implements reading and writing using read_and_remove and add. """
|
|
@staticmethod
|
|
def get_array(size, t, *args, **kwargs):
|
|
return t.dynamic_array(size, t, *args, **kwargs)
|
|
def read(self, index):
|
|
res = self._read(self.index_type.hard_conv(index))
|
|
res = [self.value_type._new(x) for x in res]
|
|
return res
|
|
def write(self, index, value):
|
|
value = util.tuplify(value)
|
|
value = [self.value_type.conv(x) for x in value]
|
|
new_value = [self.value_type.get_type(length).hard_conv(v) \
|
|
for length,v in zip(self.entry_size, value)]
|
|
return self._write(self.index_type.hard_conv(index), *new_value)
|
|
def access(self, index, new_value, write, new_empty=False):
|
|
return self._access(self.index_type.hard_conv(index),
|
|
self.value_type.bit_type.hard_conv(write),
|
|
self.value_type.bit_type.hard_conv(new_empty),
|
|
*[self.value_type.get_type(length).hard_conv(v) \
|
|
for length,v in zip(self.entry_size, \
|
|
tuplify(new_value))])
|
|
def read_and_maybe_remove(self, index):
|
|
return self.read_and_remove(self.index_type.hard_conv(index)), \
|
|
self.state.read()
|
|
@method_block
|
|
def _read(self, index):
|
|
return self.access(index, tuple(self.value_type.get_type(l)(0) \
|
|
for l in self.entry_size), \
|
|
False)
|
|
@method_block
|
|
def _write(self, index, *value):
|
|
self.access(index, value, True)
|
|
@method_block
|
|
def _access(self, index, write, new_empty, *new_value):
|
|
Program.prog.curr_tape.\
|
|
start_new_basicblock(name='abstract-access-remove-%d' % self.size)
|
|
index = MemValue(self.index_type.hard_conv(index))
|
|
read_value, read_empty = self.read_and_remove(index)
|
|
if len(read_value) != self.value_length:
|
|
raise Exception('read_and_remove() of %s returns wrong length of ' \
|
|
'read value: %d, should be %d' % \
|
|
(type(self), len(read_value), \
|
|
self.value_length))
|
|
Program.prog.curr_tape.\
|
|
start_new_basicblock(name='abstract-access-add-%d' % self.size)
|
|
new_value = ValueTuple(new_value) \
|
|
if isinstance(new_value, (tuple, list)) \
|
|
else ValueTuple((new_value,))
|
|
if len(new_value) != self.value_length:
|
|
raise Exception('wrong length of new value')
|
|
value = tuple(MemValue(i) for i in if_else(write, new_value, read_value))
|
|
empty = self.value_type.bit_type.hard_conv(new_empty)
|
|
self.add(Entry(index, value, if_else(write, empty, read_empty), \
|
|
value_type=self.value_type), evict=False)
|
|
self.recursive_evict()
|
|
return read_value, read_empty
|
|
@method_block
|
|
def delete(self, index, for_real=True):
|
|
self.access(index, (self.value_type(0),) * self.value_length, \
|
|
for_real, True)
|
|
def __getitem__(self, index):
|
|
res, empty = self.read(index)
|
|
if len(res) == 1:
|
|
res = res[0]
|
|
return res
|
|
__setitem__ = write
|
|
|
|
class EmptyException(Exception):
|
|
pass
|
|
|
|
class EndRecursiveEviction(object):
|
|
recursive_evict = lambda self: None
|
|
recursive_evict_rounds = lambda self: itertools.repeat([None])
|
|
|
|
class RefTrivialORAM(EndRecursiveEviction):
|
|
""" Trivial ORAM reference. """
|
|
contiguous = False
|
|
def empty_entry(self, apply_type=True):
|
|
return Entry.get_empty(self.value_type, self.entry_size, \
|
|
apply_type, self.index_size)
|
|
def __init__(self, index, oram):
|
|
self.ram = RefRAM(index, oram)
|
|
self.index_size = oram.index_size
|
|
self.value_type, self.value_length = oram.internal_value_type()
|
|
self.value_type, self.entry_size = oram.internal_entry_size()
|
|
self.size = oram.bucket_size
|
|
def init_mem(self):
|
|
print('init trivial oram')
|
|
self.ram.init_mem(self.empty_entry(apply_type=False))
|
|
def search(self, read_index):
|
|
if use_binary_search and self.value_type == sgf2n:
|
|
return self.binary_search(read_index)
|
|
else:
|
|
indices = self.ram.get_indices()
|
|
empty_bits = self.ram.get_empty_bits()
|
|
parallel = 1024
|
|
if comparison.const_rounds:
|
|
parallel /= 4
|
|
if self.size >= 128:
|
|
#n_threads = 8 if self.size >= 8 * parallel else 1
|
|
found = Array(self.size, self.value_type)
|
|
read_index = MemValue(read_index)
|
|
@for_range_multithread(n_threads, parallel, self.size)
|
|
def f(j):
|
|
found[j] = indices[j].equal(read_index, self.index_size) * \
|
|
(1 - empty_bits[j])
|
|
else:
|
|
found = [indices[j].equal(read_index, self.index_size) * \
|
|
(1 - empty_bits[j]) for j in range(self.size)]
|
|
# at most one 1 in found
|
|
empty = 1 - sum(found)
|
|
return found, empty
|
|
def read_and_remove(self, read_index, skip=0):
|
|
empty_entry = self.empty_entry(False)
|
|
self.last_index = read_index
|
|
found, empty = self.search(read_index)
|
|
entries = [entry for entry in self.ram]
|
|
prod_entries = list(map(operator.mul, found, entries))
|
|
read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
|
|
empty * empty_entry.x.skip(skip))
|
|
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
|
self.ram[i] = entry - prod_entry + found[i] * empty_entry
|
|
self.check(index=read_index, op='rar')
|
|
return read_value, empty
|
|
def read_and_maybe_remove(self, index):
|
|
return self.read_and_remove(index), 0
|
|
def read_and_remove_by_public(self, index):
|
|
empty_entry = self.empty_entry(False)
|
|
entries = [entry for entry in self.ram]
|
|
prod_entries = list(map(operator.mul, index, entries))
|
|
read_entry = reduce(operator.add, prod_entries)
|
|
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
|
self.ram[i] = entry - prod_entry + index[i] * empty_entry
|
|
return read_entry
|
|
@method_block
|
|
def _read(self, index):
|
|
found, empty = self.search(index)
|
|
read_value = sum(list(map(operator.mul, found, self.ram.get_values())), \
|
|
empty * self.empty_entry(False).x)
|
|
return read_value, empty
|
|
@method_block
|
|
def _access(self, index, write, new_empty, *new_value):
|
|
empty_entry = self.empty_entry(False)
|
|
found, not_found = self.search(index)
|
|
add_here = self.find_first_empty()
|
|
entries = [entry for entry in self.ram]
|
|
prod_values = list(map(operator.mul, found, \
|
|
(entry.x for entry in entries)))
|
|
read_value = sum(prod_values, not_found * empty_entry.x)
|
|
new_value = ValueTuple(new_value) \
|
|
if isinstance(new_value, (tuple, list)) \
|
|
else ValueTuple((new_value,))
|
|
for i,(entry,prod_value) in enumerate(zip(entries, prod_values)):
|
|
access_here = found[i] + not_found * add_here[i]
|
|
delta_entry = Entry(access_here * (index - entry.v), \
|
|
access_here * (new_value - entry.x), \
|
|
found[i] - \
|
|
if_else(new_empty, 0, access_here))
|
|
self.ram[i] = entry + write * delta_entry
|
|
return read_value, not_found
|
|
def check(self, found=None, index=None, new_entry=None, op=''):
|
|
if debug:
|
|
if found is None:
|
|
found = set()
|
|
for i,entry in enumerate(self.ram):
|
|
if not entry.empty():
|
|
if entry.v in found:
|
|
raise Exception('found double %s in %s' % (str(entry.v), str(self.ram.l)))
|
|
found.add(entry.v)
|
|
if index is not None:
|
|
for i,entry in enumerate(self.ram):
|
|
if not entry.empty() and index == entry.v:
|
|
raise Exception('not removed %s in %s' % \
|
|
(str(index), str(self.ram.l)))
|
|
if debug_online or debug:
|
|
#cint(0).print_reg(op)
|
|
entries = self.ram.reveal()
|
|
if index is not None:
|
|
index = index.reveal()
|
|
if new_entry is not None:
|
|
new_entry = Entry(x.reveal() for x in new_entry)
|
|
n_found = MemValue(0)
|
|
@for_range(self.size)
|
|
def f(i):
|
|
entry = entries[i]
|
|
@if_(entry.empty() != 1)
|
|
def f():
|
|
@if_e(entry.empty() == 0)
|
|
def f():
|
|
if index is not None:
|
|
@if_(entry.v == index)
|
|
def f():
|
|
entries.print_reg()
|
|
cint(0).print_reg(op)
|
|
cint(i).print_reg('trre')
|
|
entry.empty().print_reg('empt')
|
|
entry.v.print_reg('v')
|
|
index.print_reg('idx')
|
|
crash()
|
|
if new_entry is not None:
|
|
@if_(regint(1 - new_entry.empty()))
|
|
def f():
|
|
comps = Entry(x == y for x,y in \
|
|
zip(entry,new_entry))
|
|
@if_(reduce(operator.mul, comps))
|
|
def f():
|
|
n_found.iadd(1)
|
|
@else_
|
|
def f():
|
|
entries.print_reg()
|
|
cint(0).print_reg(op)
|
|
cint(i).print_reg('trem')
|
|
entry.empty().print_reg('empt')
|
|
crash()
|
|
if new_entry is not None:
|
|
@if_((n_found != 1) * (1 - new_entry.empty()))
|
|
def f():
|
|
entries.print_reg()
|
|
cint(0).print_reg(op)
|
|
cint(0).print_reg('trad')
|
|
cint(n_found).print_reg('n')
|
|
new_entry.v.print_reg('v')
|
|
for i,x in enumerate(new_entry.x):
|
|
x.print_reg('x%d' % i)
|
|
crash()
|
|
|
|
def binary_search(self, index):
|
|
if (self.size & (self.size-1)) != 0:
|
|
n = 2**(int(math.log(self.size,2)) + 1)
|
|
else:
|
|
n = self.size
|
|
|
|
indices = [i for i in self.ram.get_indices()]
|
|
if self.contiguous and n <= 256:
|
|
logn = int(math.log(n,2))
|
|
expand = 5
|
|
for i,x in enumerate(indices):
|
|
indices[i] = sum(y << (j * expand) for j,y in \
|
|
enumerate(x.bit_decompose(logn)))
|
|
index = sum(y << (j * expand) for j,y in \
|
|
enumerate(index.bit_decompose(logn)))
|
|
else:
|
|
expand = 1
|
|
|
|
# now search for zero
|
|
logn = int(round(math.log(n,2)))
|
|
mult_tree = [1] * 2*n
|
|
bit_prods = [None] * 2*n
|
|
for i in range(n-1, n-1 + self.size):
|
|
mult_tree[i] = indices[i - n + 1] - index
|
|
for i in range(n-2, -1, -1):
|
|
mult_tree[i] = mult_tree[2*i+1] * mult_tree[2*i+2]
|
|
|
|
b = 1 - mult_tree[0].equal(0, 40, expand)
|
|
|
|
bit_prods[0] = 1 - b
|
|
|
|
for j in range(1,logn+1):
|
|
M = 0
|
|
for k in range(2**(j)):
|
|
t = k + 2**(j) - 1
|
|
if k % 2 == 0:
|
|
M += bit_prods[(t-1)//2] * mult_tree[t]
|
|
|
|
b = 1 - M.equal(0, 40, expand)
|
|
|
|
for k in range(2**j):
|
|
t = k + 2**j - 1
|
|
if k % 2 == 0:
|
|
v = bit_prods[(t-1)//2] * b
|
|
bit_prods[t] = bit_prods[(t-1)//2] - v
|
|
else:
|
|
bit_prods[t] = v
|
|
return bit_prods[n-1:n-1+self.size], 1 - bit_prods[0]
|
|
|
|
def find_first_empty(self):
|
|
prefix_empty = [0] + \
|
|
floatingpoint.PreOR([empty for empty in self.ram.get_empty_bits()], \
|
|
Program.prog.security)
|
|
return [prefix_empty[i+1] - prefix_empty[i] \
|
|
for i in range(len(self.ram))]
|
|
def add(self, new_entry, state=None, evict=None):
|
|
# if self.last_index != new_entry.v:
|
|
# raise Exception('index mismatch: %s / %s' %
|
|
# (str(self.last_index), str(new_entry.v)))
|
|
add_here = self.find_first_empty()
|
|
for i,entry in enumerate(self.ram):
|
|
self.ram[i] = if_else(add_here[i], new_entry, entry)
|
|
if crash_on_overflow:
|
|
@if_(or_op(sum(add_here), new_entry.is_empty).reveal() == 0)
|
|
def f():
|
|
self.output()
|
|
print_ln('New entry: %s:%s (empty: %s)', new_entry.v.reveal(),
|
|
new_entry.x[0].reveal(), new_entry.is_empty.reveal())
|
|
print_ln('Bucket overflow')
|
|
crash()
|
|
if debug and not sum(add_here) and not new_entry.empty():
|
|
print(self.empty_entry())
|
|
raise Exception('no space for %s in %s' % (str(new_entry), str(self)))
|
|
self.check(new_entry=new_entry, op='add')
|
|
def pop(self):
|
|
self.last_index = None
|
|
empty_entry = self.empty_entry(False)
|
|
prefix_empty = [0] + \
|
|
floatingpoint.PreOR([1 - empty for empty in self.ram.get_empty_bits()], \
|
|
Program.prog.security)
|
|
pop_here = [prefix_empty[i+1] - prefix_empty[i] \
|
|
for i in range(len(self.ram))]
|
|
entries = [entry for entry in self.ram]
|
|
prod_entries = list(map(operator.mul, pop_here, self.ram))
|
|
result = (1 - sum(pop_here)) * empty_entry
|
|
result = sum(prod_entries, result)
|
|
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
|
self.ram[i] = entry - prod_entry + pop_here[i] * empty_entry
|
|
self.check(index=result.v, op='pop')
|
|
if debug_online:
|
|
entry = Entry(x.reveal() for x in result)
|
|
@if_(entry.empty())
|
|
def f():
|
|
for i,x in enumerate((entry.v,) + entry.x):
|
|
@if_(x != 0)
|
|
def f():
|
|
print_ln('pop error:' + ' %s' * len(entry), *entry)
|
|
print_ln('%s ' * len(pop_here), \
|
|
*(x.reveal() for x in pop_here))
|
|
crash()
|
|
return result
|
|
def output(self):
|
|
self.ram.output()
|
|
def __repr__(self):
|
|
return repr(self.ram)
|
|
|
|
def batch_init(self, values):
|
|
for i,value in enumerate(values):
|
|
index = MemValue(self.value_type.hard_conv(i))
|
|
new_value = [MemValue(self.value_type.hard_conv(v)) \
|
|
for v in (value if isinstance(
|
|
value, (tuple, list, Array)) \
|
|
else (value,))]
|
|
self.ram[i] = Entry(index, new_value, value_type=self.value_type)
|
|
|
|
class TrivialORAM(RefTrivialORAM, AbstractORAM):
|
|
""" Trivial ORAM (obviously). """
|
|
ref_type = RefTrivialORAM
|
|
def __init__(self, size, value_type=None, value_length=1, index_size=None, \
|
|
entry_size=None, contiguous=True, init_rounds=-1):
|
|
self.index_size = index_size or log2(size)
|
|
self.value_type = value_type or sint
|
|
self.index_type = self.value_type.get_type(self.index_size)
|
|
if entry_size is None:
|
|
self.value_length = value_length
|
|
self.entry_size = [None] * value_length
|
|
else:
|
|
self.value_length = len(tuplify(entry_size))
|
|
self.entry_size = tuplify(entry_size)
|
|
self.contiguous = contiguous
|
|
entry_type = self.empty_entry().types()
|
|
self.size = size
|
|
self.ram = RAM(size, entry_type, self.get_array)
|
|
if init_rounds != -1:
|
|
# put memory initialization in different timer
|
|
stop_timer()
|
|
start_timer(1)
|
|
self.init_mem()
|
|
if init_rounds != -1:
|
|
stop_timer(1)
|
|
start_timer()
|
|
get_program().reading('ORAM', 'KS14')
|
|
|
|
def get_n_threads(n_loops):
|
|
if n_threads is None and not single_thread:
|
|
if n_loops > 2048:
|
|
return 8
|
|
else:
|
|
return None
|
|
else:
|
|
return n_threads
|
|
|
|
class LinearORAM(TrivialORAM):
|
|
""" Contiguous ORAM that stores entries in order and accesses the
|
|
entire array for reading and writing in order to hide the address.
|
|
|
|
:param size: number of entries
|
|
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
|
|
:py:class:`sfix`
|
|
:param value_length: number of values per entry (default: 1)
|
|
|
|
"""
|
|
@staticmethod
|
|
def get_array(size, t, *args, **kwargs):
|
|
return Array(size, t, *args, **kwargs)
|
|
def __init__(self, *args, **kwargs):
|
|
TrivialORAM.__init__(self, *args, **kwargs)
|
|
self.index_vector = self.get_array(2 ** self.index_size, \
|
|
self.index_type.bit_type)
|
|
def read_and_maybe_remove(self, index):
|
|
return self.read(index), 0
|
|
def add(self, entry, state=None, evict=None):
|
|
if entry.created_non_empty is True:
|
|
self.write(entry.v, entry.x)
|
|
else:
|
|
self.access(entry.v, entry.x, True, entry.empty())
|
|
def read_and_remove(self, *args):
|
|
raise CompilerError('not implemented')
|
|
@method_block
|
|
def _read(self, index):
|
|
maybe_start_timer(6)
|
|
empty_entry = self.empty_entry(False)
|
|
demux_array(bit_decompose(index, self.index_size), \
|
|
self.index_vector)
|
|
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
|
|
@map_sum(get_n_threads(self.size), None, self.size, \
|
|
self.value_length + 1, t)
|
|
def f(i):
|
|
entry = self.ram[i]
|
|
access_here = self.index_vector[i]
|
|
return access_here * ValueTuple((entry.empty(),) + entry.x)
|
|
not_found = self.value_type.bit_type(f()[0])
|
|
read_value = ValueTuple(self.value_type.get_type(l)(x) for l, x in zip(self.entry_size, f()[1:])) + \
|
|
not_found * empty_entry.x
|
|
maybe_stop_timer(6)
|
|
return read_value, not_found
|
|
@method_block
|
|
def _write(self, index, *new_value):
|
|
maybe_start_timer(7)
|
|
empty_entry = self.empty_entry(False)
|
|
demux_array(bit_decompose(index, self.index_size), \
|
|
self.index_vector)
|
|
new_value = make_array(
|
|
new_value, self.value_type.get_type(
|
|
max(x or 0 for x in self.entry_size)))
|
|
@for_range_multithread(get_n_threads(self.size), None, self.size)
|
|
def f(i):
|
|
entry = self.ram[i]
|
|
access_here = self.index_vector[i]
|
|
nv = ValueTuple(new_value)
|
|
delta_entry = \
|
|
Entry(0, access_here * (nv - entry.x), \
|
|
- access_here * entry.empty())
|
|
self.ram[i] = entry + delta_entry
|
|
maybe_stop_timer(7)
|
|
@method_block
|
|
def _access(self, index, write, new_empty, *new_value):
|
|
empty_entry = self.empty_entry(False)
|
|
index_vector = \
|
|
demux_array(bit_decompose(index, self.index_size))
|
|
new_value = make_array(
|
|
new_value, self.value_type.get_type(
|
|
max(x or 0 for x in self.entry_size)))
|
|
new_empty = MemValue(new_empty)
|
|
write = MemValue(write)
|
|
@map_sum(get_n_threads(self.size), None, self.size, \
|
|
self.value_length + 1, [self.value_type.bit_type] + \
|
|
[self.value_type] * self.value_length)
|
|
def f(i):
|
|
entry = self.ram[i]
|
|
access_here = index_vector[i]
|
|
nv = ValueTuple(new_value)
|
|
delta_entry = \
|
|
Entry(0, access_here * (nv - entry.x), \
|
|
access_here * (new_empty - entry.empty()))
|
|
self.ram[i] = entry + write * delta_entry
|
|
return access_here * ValueTuple((entry.empty(),) + entry.x)
|
|
not_found = f()[0]
|
|
read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x
|
|
return read_value, not_found
|
|
|
|
class RefBucket(object):
|
|
""" Bucket for tree ORAM. Contains an ORAM of some type and
|
|
possibly two children. """
|
|
def __init__(self, index, oram):
|
|
self.bucket = oram.bucket_oram.ref_type(index, oram)
|
|
self.p_children = lambda i: regint.conv((index << 1) + i)
|
|
self.ref_children = lambda i: RefBucket(self.p_children(i), oram)
|
|
self.oram = oram
|
|
def check(self, depth, found=None, index=None):
|
|
if found is None:
|
|
found = set()
|
|
self.bucket.check(found, index)
|
|
if depth:
|
|
for i in (0,1):
|
|
self.ref_children(i).check(depth - 1, found, index)
|
|
def __repr__(self, depth=0):
|
|
result = ' ' * depth + repr(self.bucket) + '\n'
|
|
if depth < self.oram.D:
|
|
result += self.ref_children(0).__repr__(depth + 1) + \
|
|
self.ref_children(1).__repr__(depth + 1)
|
|
return result
|
|
def output(self):
|
|
print_reg(cint(self.depth), 'buck')
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
self.bucket.output()
|
|
print_reg(cint(self.depth), 'dep')
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
@if_(self.p_children(1) < oram.n_buckets())
|
|
def f():
|
|
for i in (0,1):
|
|
child = self.ref_children(i)
|
|
print_reg(cint(i), 'chil')
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
child.output()
|
|
|
|
def random_block(length, value_type):
|
|
return bit_compose(value_type.bit_type.get_random_bit() for i in range(length))
|
|
|
|
class List(EndRecursiveEviction):
|
|
""" Debugging only. List which accepts secret values as indices
|
|
and *reveals* them. """
|
|
def __init__(self, size, value_type, value_length=1, \
|
|
init_rounds=None, entry_size=None):
|
|
self.value_type = value_type
|
|
self.index_type = value_type.get_type(log2(size))
|
|
self.value_length = value_length
|
|
if entry_size is None:
|
|
self.l = [value_type.dynamic_array(size, value_type) \
|
|
for i in range(value_length)]
|
|
else:
|
|
self.l = [value_type.dynamic_array(size, \
|
|
value_type.get_type(length)) \
|
|
for length in entry_size]
|
|
self.value_length = len(entry_size)
|
|
for l in self.l:
|
|
l.assign_all(0)
|
|
__getitem__ = lambda self,index: [self.l[i][regint(reveal(index))] \
|
|
for i in range(self.value_length)]
|
|
def __setitem__(self, index, value):
|
|
# print 'set', index, value, cint(reveal(index))
|
|
# print self.l
|
|
Program.prog.curr_tape.start_new_basicblock(name='List-pre-write')
|
|
for i in range(self.value_length):
|
|
self.l[i][regint(reveal(index))] = value[i]
|
|
Program.prog.curr_tape.start_new_basicblock(name='List-post-write')
|
|
read_and_remove = lambda self,i: (self[i], None)
|
|
def read_and_maybe_remove(self, *args, **kwargs):
|
|
return self.read_and_remove(*args, **kwargs), 0
|
|
add = lambda self,entry,**kwargs: self.__setitem__(entry.v.read(), \
|
|
[v.read() for v in entry.x])
|
|
recursive_evict = lambda *args,**kwargs: None
|
|
def batch_init(self, values):
|
|
for i,value in enumerate(values):
|
|
index = self.value_type.hard_conv(i)
|
|
new_value = [self.value_type.hard_conv(v) \
|
|
for v in (value if isinstance(
|
|
value, (tuple, list, Array)) \
|
|
else (value,))]
|
|
self.__setitem__(index, new_value)
|
|
def __repr__(self):
|
|
return repr(self.l)
|
|
|
|
class LocalIndexStructure(List):
|
|
""" Debugging only. Implements a tree ORAM index as list of
|
|
values, *revealing* which elements are accessed. """
|
|
def __init__(self, size, entry_size, value_type=sint, init_rounds=-1, \
|
|
random_init=False):
|
|
List.__init__(self, size, value_type)
|
|
if init_rounds:
|
|
@for_range(init_rounds if init_rounds > 0 else size)
|
|
def f(i):
|
|
self.l[0][i] = random_block(entry_size, value_type)
|
|
print('index size:', size)
|
|
def update(self, index, value, evict=None):
|
|
read_value = self[index]
|
|
#print 'read', index, read_value
|
|
#print self.l
|
|
self[index] = (value,)
|
|
return self.value_type(read_value)
|
|
def output(self):
|
|
for i,v in enumerate(self):
|
|
print_reg(v.reveal(), 'i %d' % i)
|
|
__getitem__ = lambda self,index: List.__getitem__(self, index)[0]
|
|
|
|
def get_n_threads_for_tree(size):
|
|
if n_threads_for_tree is None and not single_thread:
|
|
if size >= 2**13:
|
|
return 8
|
|
else:
|
|
return 1
|
|
else:
|
|
return n_threads_for_tree
|
|
|
|
class TreeORAM(AbstractORAM):
|
|
""" Tree ORAM. """
|
|
def __init__(self, size, value_type=None, value_length=1, entry_size=None, \
|
|
bucket_oram=TrivialORAM, init_rounds=-1):
|
|
value_type = value_type or sint
|
|
print('create oram of size', size)
|
|
self.bucket_oram = bucket_oram
|
|
# heuristic bucket size
|
|
delta = 3
|
|
k = (math.log(size * size * log2(size) * 100, 2) + 21) / (1 + delta)
|
|
# size + 1 for bucket overflow check
|
|
self.bucket_size = min(int(math.ceil((1 + delta) * k)), size + 1)
|
|
self.D = log2(max(size / k, 2))
|
|
print('bucket size:', self.bucket_size)
|
|
print('depth:', self.D)
|
|
print('complexity:', self.bucket_size * (self.D + 1))
|
|
self.value_type = value_type
|
|
if entry_size is not None:
|
|
self.value_length = len(tuplify(entry_size))
|
|
self.entry_size = tuplify(entry_size)
|
|
else:
|
|
self.value_length = value_length
|
|
self.entry_size = [None] * value_length
|
|
self.index_size = log2(size)
|
|
self.index_type = value_type.get_type(self.index_size)
|
|
self.size = size
|
|
empty_entry = Entry.get_empty(*self.internal_entry_size(), \
|
|
index_size=self.D)
|
|
self.entry_type = empty_entry.types()
|
|
self.ram = RAM(self.n_buckets() * self.bucket_size, self.entry_type, \
|
|
self.get_array)
|
|
if init_rounds != -1:
|
|
# put memory initialization in different timer
|
|
stop_timer()
|
|
start_timer(1)
|
|
self.ram.init_mem(empty_entry)
|
|
if init_rounds != -1:
|
|
stop_timer(1)
|
|
start_timer()
|
|
self.root = RefBucket(1, self)
|
|
self.index = self.index_structure(size, self.D, self.index_type,
|
|
init_rounds, True)
|
|
|
|
self.read_value = Array(self.value_length, value_type.default_type)
|
|
self.read_non_empty = MemValue(self.value_type.bit_type(0))
|
|
self.state = MemValue(self.value_type.default_type(0))
|
|
@method_block
|
|
def add_to_root(self, state, is_empty, v, *x):
|
|
if len(x) != self.value_length:
|
|
raise CompilerError('value length mismatch: %s, should be %s' % \
|
|
(len(x), self.value_length))
|
|
l = state
|
|
self.root.bucket.add(Entry(v, (l,) + x, is_empty))
|
|
def evict_bucket(self, bucket, d):
|
|
#print_reg(cint(0), 'evb')
|
|
#print 'pre', bucket
|
|
entry = bucket.bucket.pop()
|
|
#print 'evict', entry
|
|
#print 'from', bucket
|
|
b = if_else(entry.empty(), self.value_type.bit_type.get_random_bit(), \
|
|
get_bit(entry.x[0], self.D - 1 - d, self.D))
|
|
block = cond_swap(b, entry, self.root.bucket.empty_entry())
|
|
#print 'empty', entry.empty()
|
|
#print 'b', b
|
|
for b in (0,1):
|
|
# not sure if secure other than with trivial ORAM
|
|
bucket.ref_children(b).bucket.add(block[b])
|
|
#print 'block', block
|
|
#print 'post', bucket
|
|
if debug_online:
|
|
secret_entry = entry
|
|
entry = Entry(x.reveal() for x in entry)
|
|
@if_(1 - entry.empty())
|
|
def f():
|
|
b = regint((entry.x[0] >> self.D - 1 - d) & 1)
|
|
bucket.ref_children(b).bucket.check(new_entry=secret_entry, \
|
|
op='evic')
|
|
bucket.ref_children(1-b).bucket.check(index=secret_entry.v, \
|
|
op='evic')
|
|
@method_block
|
|
def evict2(self, p_bucket1, p_bucket2, d):
|
|
self.evict_bucket(RefBucket(p_bucket1, self), d)
|
|
self.evict_bucket(RefBucket(p_bucket2, self), d)
|
|
@method_block
|
|
def read_and_renew_index(self, u):
|
|
l_star = random_block(self.D, self.index_type)
|
|
if use_insecure_randomness:
|
|
new_path = regint.get_random(self.D)
|
|
l_star = self.index_type(new_path)
|
|
self.state.write(l_star)
|
|
res = self.index.update(u, l_star, evict=False).reveal()
|
|
if isinstance(res, types._clear):
|
|
res = regint(cint.conv(res))
|
|
return res
|
|
@method_block
|
|
def read_and_remove_levels(self, u, read_path):
|
|
u = MemValue(u)
|
|
read_path = MemValue(read_path)
|
|
levels = self.D + 1
|
|
parallel = get_parallel(self.index_size, *self.internal_value_type())
|
|
@map_sum(get_n_threads_for_tree(self.size), parallel, levels, \
|
|
self.value_length + 1, [self.value_type.bit_type] + \
|
|
[self.value_type.default_type] * self.value_length)
|
|
def process(level):
|
|
b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level))
|
|
bucket = RefBucket(b_index, self)
|
|
#print 'pre-rar level', i, 'from', bucket
|
|
value, empty = bucket.bucket.read_and_remove(u, 1)
|
|
self.check()
|
|
return (1 - empty,) + value
|
|
self.read_non_empty.write(process()[0])
|
|
self.read_value.assign(process()[1:])
|
|
if debug_online:
|
|
n_found = self.read_non_empty.reveal()
|
|
@if_((n_found != 0) * (n_found != 1))
|
|
def f():
|
|
cint(0).print_reg('rere')
|
|
u.reveal().print_reg('u')
|
|
n_found.print_reg('n')
|
|
for i,x in enumerate(self.read_value):
|
|
x.reveal().print_reg('x%d' % i)
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
crash()
|
|
def internal_value_type(self):
|
|
return self.value_type.default_type, self.value_length + 1
|
|
def internal_entry_size(self):
|
|
return self.value_type.default_type, [self.D] + list(self.entry_size)
|
|
def n_buckets(self):
|
|
return 2**(self.D+1)
|
|
@method_block
|
|
def read_and_remove(self, u):
|
|
#print 'rar', id(self)
|
|
#print 'pre-rar', self
|
|
read_path = self.read_and_renew_index(u)
|
|
#print 'rar for', u, self.read_path
|
|
self.check()
|
|
maybe_start_timer(3)
|
|
self.read_and_remove_levels(u, read_path)
|
|
read_empty = 1 - self.read_non_empty
|
|
read_value = self.read_value
|
|
maybe_stop_timer(3)
|
|
self.check(u)
|
|
#print 'rar result', u, read_value, read_empty
|
|
#print 'post-rar', self
|
|
# if empty:
|
|
# raise EmptyException('read empty value %s at %s, path %s' % \
|
|
# (str(res), str(u), str(l)))
|
|
Program.prog.curr_tape.\
|
|
start_new_basicblock(name='read_and_remove-%d-end' % self.size)
|
|
return [MemValue(v) for v in read_value], MemValue(read_empty)
|
|
def add(self, entry, state=None, evict=True):
|
|
if state is None:
|
|
state = self.state.read()
|
|
#print_reg(cint(0), 'add')
|
|
#print 'add', id(self)
|
|
#print 'pre-add', self
|
|
maybe_start_timer(4)
|
|
self.add_to_root(state, entry.empty(), \
|
|
self.index_type(entry.v.read()), \
|
|
*(self.value_type.default_type(i.read())
|
|
for i in entry.x))
|
|
maybe_stop_timer(4)
|
|
#print 'pre-evict', self
|
|
if evict:
|
|
maybe_start_timer(5)
|
|
self.evict()
|
|
maybe_stop_timer(5)
|
|
#print 'post-evict', self
|
|
def evict(self):
|
|
#print 'evict root', id(self)
|
|
#print_reg(cint(0), 'ev_r')
|
|
self.evict_bucket(self.root, 0)
|
|
self.check()
|
|
if self.D > 1:
|
|
#print 'evict 1', id(self)
|
|
#print_reg(cint(0), 'ev1')
|
|
self.evict2(self.root.p_children(0), self.root.p_children(1), 1)
|
|
self.check()
|
|
if self.D > 2:
|
|
#print_reg(cint(self.D), 'D')
|
|
@for_range(2, self.D)
|
|
def f(d):
|
|
#print_reg(d, 'ev2')
|
|
#print 'evict 2', id(self)
|
|
#print_reg(d, 'evl2')
|
|
s1 = regint.get_random(d)
|
|
s2 = MemValue(regint(0))
|
|
@do_while
|
|
def f():
|
|
s2.write(regint.get_random(d))
|
|
return s2 == s1
|
|
#print 's1, s2', s1, s2
|
|
#print 'S', S
|
|
#print 'd, 2^d', d, 1 << d
|
|
self.evict2(s1 + (1 << d), s2 + (1 << d), d)
|
|
self.check()
|
|
def recursive_evict(self):
|
|
self.evict()
|
|
self.index.recursive_evict()
|
|
|
|
def batch_init(self, values):
|
|
""" Batch initalization. Obliviously shuffles and adds N entries to
|
|
random leaf buckets. """
|
|
m = len(values)
|
|
if m != self.size:
|
|
raise CompilerError('Batch initialization must have N values.')
|
|
if self.value_type != sint:
|
|
raise CompilerError('Batch initialization only possible with sint.')
|
|
|
|
depth = log2(m)
|
|
leaves = self.value_type.Array(m)
|
|
indexed_values = \
|
|
self.value_type.Matrix(m, len(values[0]) + 1)
|
|
|
|
# assign indices 0, ..., m-1
|
|
@for_range(m)
|
|
def _(i):
|
|
value = values[i]
|
|
index = MemValue(self.value_type.hard_conv(i))
|
|
new_value = [MemValue(self.value_type.hard_conv(v)) \
|
|
for v in value]
|
|
indexed_values[i] = [index] + new_value
|
|
|
|
entries = sint.Matrix(self.bucket_size * 2 ** self.D,
|
|
len(Entry(0, list(indexed_values[0]), False)))
|
|
|
|
# assign leaves
|
|
@for_range(len(indexed_values))
|
|
def _(i):
|
|
index_value = list(indexed_values[i])
|
|
leaves[i] = random_block(self.D, self.value_type)
|
|
|
|
index = index_value[0]
|
|
value = [leaves[i]] + index_value[1:]
|
|
entries[i] = Entry(index, value, \
|
|
self.value_type.hard_conv(False), value_type=self.value_type)
|
|
|
|
# save unsorted leaves for position map
|
|
unsorted_leaves = leaves
|
|
|
|
# add all possible leaves to ensure appearance in B
|
|
leaves = self.value_type.Array(m + 2 ** self.D)
|
|
leaves[:] = unsorted_leaves
|
|
leaves.assign(regint.inc(2 ** self.D), base=m)
|
|
leaves.sort()
|
|
|
|
bucket_sz = 0
|
|
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
|
|
B = sint.Matrix(len(leaves), 3)
|
|
B[0] = [0, leaves[0], 0]
|
|
B[-1] = [0, 0, sint(1)]
|
|
s = MemValue(sint(0))
|
|
|
|
@for_range_opt(len(B) - 1)
|
|
def _(j):
|
|
i = j + 1
|
|
eq = leaves[i].equal(leaves[i-1])
|
|
s.write((s + eq) * eq)
|
|
B[i][0] = s
|
|
B[i][1] = leaves[i]
|
|
B[i-1][2] = 1 - eq
|
|
#pos[i] = [s, leaves[i]]
|
|
#last_in_bucket[i-1] = 1 - eq
|
|
|
|
# delete to avoid further usage
|
|
del leaves
|
|
# shuffle
|
|
B.secure_shuffle()
|
|
#cint(0).print_reg('shuf')
|
|
|
|
sz = MemValue(0) #cint(0)
|
|
nleaves = 2**self.D
|
|
empty_positions = Array(nleaves, self.value_type)
|
|
empty_leaves = Array(nleaves, self.value_type)
|
|
|
|
@for_range(len(B))
|
|
def _(i):
|
|
if_then(reveal(B[i][2]))
|
|
#if B[i][2] == 1:
|
|
#cint(i).print_reg('last')
|
|
if isinstance(sz, int):
|
|
szval = sz
|
|
else:
|
|
szval = sz.read()
|
|
#szval.print_reg('sz')
|
|
# subtract one to undo adding above
|
|
empty_positions[szval] = B[i][0] - 1 #pos[i][0]
|
|
#empty_positions[szval].reveal().print_reg('ps0')
|
|
empty_leaves[szval] = B[i][1] #pos[i][1]
|
|
sz.iadd(1)
|
|
end_if()
|
|
|
|
pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2)
|
|
|
|
@for_range_opt(nleaves)
|
|
def _(i):
|
|
leaf = empty_leaves[i]
|
|
# split into 2 if bucket size can't fit into one field elem
|
|
if self.bucket_size + Program.prog.security > 128:
|
|
parity = (empty_positions[i]+1) % 2
|
|
half = (empty_positions[i]+1 - parity) // 2
|
|
half_max = self.bucket_size // 2
|
|
|
|
bits = floatingpoint.B2U(half, half_max)[0]
|
|
bits2 = floatingpoint.B2U(half+parity, half_max)[0]
|
|
# (doesn't work)
|
|
#bits2 = [0] * half_max
|
|
## second half with parity bit
|
|
#for j in range(half_max-1, 0, -1):
|
|
# bits2[j] = bits[j] + (bits[j-1] - bits[j]) * parity
|
|
#bits2[0] = (1 - bits[0]) * parity
|
|
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
|
|
else:
|
|
bucket_bits = floatingpoint.B2U(empty_positions[i]+1,
|
|
self.bucket_size)[0]
|
|
assert len(bucket_bits) == self.bucket_size
|
|
for j, b in enumerate(bucket_bits):
|
|
pos_bits[i * self.bucket_size + j] = [b, leaf]
|
|
|
|
# sort to get empty positions first
|
|
pos_bits.sort(n_bits=1)
|
|
|
|
# now assign positions to empty entries
|
|
@for_range(len(entries) - m)
|
|
def _(i):
|
|
vtype, vlength = self.internal_value_type()
|
|
leaf = vtype(pos_bits[i][1])
|
|
# set leaf in empty entry for assigning after shuffle
|
|
value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)])
|
|
entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype)
|
|
entries[m + i] = entry
|
|
|
|
# now shuffle, reveal positions and place entries
|
|
entries.secure_shuffle()
|
|
clear_leaves = Array.create_from(
|
|
Entry(entries.get_columns()).x[0].reveal())
|
|
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
|
|
bucket_sizes = Array(2**self.D, regint)
|
|
bucket_sizes.assign_all(0)
|
|
|
|
@for_range_opt(len(entries))
|
|
def _(k):
|
|
leaf = clear_leaves[k]
|
|
bucket = RefBucket(leaf + (1 << self.D), self)
|
|
bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k])
|
|
bucket_sizes[leaf] += 1
|
|
|
|
self.index.batch_init(unsorted_leaves)
|
|
|
|
def check(self, index=None):
|
|
if debug:
|
|
self.root.check(self.D, index=index)
|
|
def __repr__(self):
|
|
return repr(self.root) + '\n' + repr(self.index)
|
|
def output(self):
|
|
self.root.output()
|
|
self.index.output()
|
|
|
|
class BaseORAM(TreeORAM):
|
|
""" Debugging only. Tree ORAM revealing the access pattern. """
|
|
index_structure = LocalIndexStructure
|
|
|
|
def put_in_new_block(function):
|
|
def wrapper(*args, **kwargs):
|
|
class BlockCall(object):
|
|
def start(self):
|
|
Program.prog.curr_tape.start_new_basicblock()
|
|
function(*args, **kwargs)
|
|
return self
|
|
def join(self):
|
|
pass
|
|
return BlockCall()
|
|
return wrapper
|
|
|
|
def get_log_value_size(value_type):
|
|
""" Return log of element size. """
|
|
if value_type == sgf2n:
|
|
return 5
|
|
else:
|
|
return sint_bit_length
|
|
|
|
def get_value_size(value_type):
|
|
""" Return element size. """
|
|
if value_type == sgf2n:
|
|
return Program.prog.galois_length
|
|
elif value_type == sint:
|
|
ring = Program.prog.options.ring
|
|
if ring:
|
|
return int(ring)
|
|
else:
|
|
return 127 - Program.prog.security
|
|
else:
|
|
return value_type.max_length
|
|
|
|
def get_parallel(index_size, value_type, value_length):
|
|
""" Returning the number of parallel readings feasible, based on
|
|
experiments. """
|
|
value_size = get_value_size(value_type)
|
|
if value_type == sint:
|
|
value_size *= 2
|
|
res = max(1, min(50 * 32 // (value_length * value_size), \
|
|
800 * 32 // (value_length * index_size)))
|
|
if comparison.const_rounds:
|
|
res = max(1, res // 2)
|
|
print('Reading %d buckets in parallel' % res)
|
|
return res
|
|
|
|
class PackedIndexStructure(object):
|
|
""" Abstract class for ORAM using bit packing. """
|
|
def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1, \
|
|
random_init=False):
|
|
self.size = size
|
|
if entry_size is None:
|
|
self.entry_size = (log2(size),)
|
|
else:
|
|
self.entry_size = tuplify(entry_size)
|
|
self.value_type = value_type
|
|
for demux_bits in range(max_demux_bits + 1):
|
|
self.log_entries_per_element = min(log2(size), \
|
|
int(math.floor(math.log(float(get_value_size(value_type)) / \
|
|
sum(self.entry_size), 2))))
|
|
self.log_elements_per_block = \
|
|
max(0, min(demux_bits, log2(size) - \
|
|
self.log_entries_per_element))
|
|
if self.log_entries_per_element < 0:
|
|
self.entries_per_block = 1
|
|
max_bits = get_value_size(value_type)
|
|
self.split_sizes = [[]]
|
|
for s in self.entry_size:
|
|
if s > max_bits:
|
|
raise CompilerError('Inadequate entry size %d, ' \
|
|
'maximum %d' % \
|
|
(s, max_bits))
|
|
if sum(self.split_sizes[-1]) + s > max_bits:
|
|
self.split_sizes.append([])
|
|
self.split_sizes[-1].append(s)
|
|
self.elements_per_entry = len(self.split_sizes)
|
|
self.log_elements_per_block = log2(self.elements_per_entry)
|
|
self.log_entries_per_element = -self.log_elements_per_block
|
|
print('split sizes:', self.split_sizes)
|
|
self.log_entries_per_block = \
|
|
self.log_elements_per_block + self.log_entries_per_element
|
|
self.elements_per_block = 2**self.log_elements_per_block
|
|
self.entries_per_element = 2**self.log_entries_per_element
|
|
self.entries_per_block = 2**self.log_entries_per_block
|
|
self.used_bits = self.entries_per_element * sum(self.entry_size)
|
|
real_size = -(-size // self.entries_per_block)
|
|
print('packed size:', real_size)
|
|
print('index size:', size)
|
|
print('entry size:', self.entry_size)
|
|
print('log(entries per element):', self.log_entries_per_element)
|
|
print('entries per element:', self.entries_per_element)
|
|
print('log(entries per block):', self.log_entries_per_block)
|
|
print('entries per block:', self.entries_per_block)
|
|
print('log(elements per block):', self.log_elements_per_block)
|
|
print('elements per block:', self.elements_per_block)
|
|
print('used bits:', self.used_bits)
|
|
entry_size = [self.used_bits] * self.elements_per_block
|
|
if real_size > 1:
|
|
# no need to init underlying ORAM, will be initialized implicitely
|
|
self.l = self.storage(real_size, value_type, \
|
|
entry_size=entry_size, init_rounds=0)
|
|
self.small = False
|
|
else:
|
|
self.l = List(1, value_type, self.elements_per_block, \
|
|
entry_size=entry_size)
|
|
self.small = True
|
|
self.index_type = self.l.index_type
|
|
if init_rounds:
|
|
if init_rounds > 0:
|
|
real_init_rounds = init_rounds * real_size // size
|
|
else:
|
|
real_init_rounds = real_size
|
|
print('packed init rounds:', real_init_rounds)
|
|
@for_range(real_init_rounds)
|
|
def f(i):
|
|
if random_init:
|
|
self.l[i] = [random_block(self.used_bits, self.value_type) \
|
|
for j in range(self.elements_per_block)]
|
|
else:
|
|
self.l[i] = [0] * self.elements_per_block
|
|
time()
|
|
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
|
|
print_ln('packed ORAM init done')
|
|
print('index initialized, size', size)
|
|
def translate_index(self, index):
|
|
""" Bit slicing *index* according parameters. Output is tuple
|
|
(storage address, index with storage cell, index within
|
|
element). """
|
|
if self.value_type == sint:
|
|
rem = mod2m(index, self.log_entries_per_block, log2(self.size), False)
|
|
c = mod2m(rem, self.log_entries_per_element, \
|
|
self.log_entries_per_block, False)
|
|
b = trunc_zeros(rem - c, self.log_entries_per_element,
|
|
self.log_entries_per_block)
|
|
if self.small:
|
|
return 0, b, c
|
|
else:
|
|
return trunc_zeros(index - rem, self.log_entries_per_block,
|
|
log2(self.size)), b, c
|
|
else:
|
|
index_bits = bit_decompose(index, log2(self.size))
|
|
l1 = self.log_entries_per_element
|
|
l2 = self.log_entries_per_block
|
|
c = bit_compose(index_bits[:l1])
|
|
b = bit_compose(index_bits[l1:l2])
|
|
if self.small:
|
|
return 0, b, c
|
|
else:
|
|
a = bit_compose(index_bits[l2:])
|
|
return a, b, c
|
|
raise CompilerError('Cannot process indices of type', self.value_type)
|
|
class Slicer(object):
|
|
def __init__(self, pack, index):
|
|
self.pack = pack
|
|
self.a, self.b, self.c = pack.translate_index(index)
|
|
def read(self, block):
|
|
self.block = block
|
|
self.index_vector = \
|
|
demux(bit_decompose(self.b, self.pack.log_elements_per_block))
|
|
self.vector = list(map(operator.mul, self.index_vector, block))
|
|
self.element = get_block(sum(self.vector), self.c, \
|
|
self.pack.entry_size, \
|
|
self.pack.entries_per_element)
|
|
return tuple(self.element.get_slice())
|
|
def write(self, value):
|
|
self.element.set_slice(value)
|
|
anti_vector = list(map(operator.sub, self.block, self.vector))
|
|
updated_vector = [self.element.value * i for i in self.index_vector]
|
|
updated_block = list(map(operator.add, anti_vector, updated_vector))
|
|
return updated_block
|
|
class MultiSlicer(object):
|
|
def __init__(self, pack, index):
|
|
self.pack = pack
|
|
self.a = index
|
|
def read(self, block):
|
|
res = []
|
|
for element,sizes in zip(block,self.pack.split_sizes):
|
|
bits = element.bit_decompose(sum(sizes))
|
|
for size in sizes:
|
|
res.append(sum(bit << i \
|
|
for i,bit in enumerate(bits[-size:])))
|
|
del bits[-size:]
|
|
return tuple(res)
|
|
def write(self, value):
|
|
res = []
|
|
i = 0
|
|
for sizes in self.pack.split_sizes:
|
|
res.append(0)
|
|
for size in sizes:
|
|
res[-1] <<= size
|
|
res[-1] += value[i]
|
|
i += 1
|
|
return res
|
|
def get_slicer(self, index):
|
|
if self.log_entries_per_element < 0:
|
|
return self.MultiSlicer(self, index)
|
|
else:
|
|
return self.Slicer(self, index)
|
|
def update(self, index, value, evict=True):
|
|
""" Updating index return current value. Has to be done in one
|
|
step to avoid exponential blow-up in ORAM recursion. """
|
|
return self.access(index, value, True, evict=evict)
|
|
def access(self, index, value, write, evict=True):
|
|
slicer = self.get_slicer(index)
|
|
block = self.l.read_and_maybe_remove(slicer.a)[0][0]
|
|
read_value = slicer.read(block)
|
|
value = if_else(write, ValueTuple(tuplify(value)), \
|
|
ValueTuple(read_value))
|
|
self.l.add(Entry(MemValue(self.l.index_type(slicer.a)), \
|
|
ValueTuple(MemValue(v) \
|
|
for v in slicer.write(value)), \
|
|
value_type=self.value_type), evict=evict)
|
|
return untuplify(read_value)
|
|
def __getitem__(self, index):
|
|
slicer = self.get_slicer(index)
|
|
return untuplify(slicer.read(self.l[slicer.a]))
|
|
def __setitem__(self, index, value):
|
|
if self.log_entries_per_element < 0:
|
|
# no need for reading first
|
|
self.l[index] = self.get_slicer(index).write(value)
|
|
else:
|
|
self.access(index, value, True, False)
|
|
self.l.recursive_evict()
|
|
recursive_evict = lambda self: self.l.recursive_evict()
|
|
|
|
def batch_init(self, values):
|
|
""" Initialize m values with indices 0, ..., m-1 """
|
|
m = len(values)
|
|
n_entries = int(math.ceil(m / self.entries_per_block))
|
|
new_values = sint.Matrix(n_entries, self.elements_per_block)
|
|
values = Array.create_from(values)
|
|
|
|
@for_range(n_entries)
|
|
def _(i):
|
|
block = Array.create_from([sint(0)] * self.elements_per_block)
|
|
for j in range(self.elements_per_block):
|
|
base = i * self.entries_per_block + j * self.entries_per_element
|
|
for k in range(self.entries_per_element):
|
|
@if_(base + k < m)
|
|
def _():
|
|
block[j] += \
|
|
values[base + k] << (k * sum(self.entry_size))
|
|
|
|
new_values[i] = block
|
|
|
|
self.l.batch_init(new_values)
|
|
|
|
def __repr__(self):
|
|
return repr(self.l)
|
|
def output(self):
|
|
if self.small:
|
|
print_reg(self.l[0].reveal(), 'i0')
|
|
print_reg(self.l[1].reveal(), 'i1')
|
|
|
|
class PackedORAMWithEmpty(AbstractORAM, PackedIndexStructure):
|
|
def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1):
|
|
if entry_size is None:
|
|
entry_size = log2(size)
|
|
PackedIndexStructure.__init__(self, size, (1,) + tuplify(entry_size), \
|
|
value_type, init_rounds=init_rounds)
|
|
self.value_length = len(self.entry_size)
|
|
@method_block
|
|
def _read(self, index):
|
|
res = PackedIndexStructure.__getitem__(self, index)
|
|
return res[1:], 1 - res[0]
|
|
def access(self, index, new_value, write, new_empty=False, evict=True):
|
|
res = PackedIndexStructure.access(self, index, (1 - new_empty,) + \
|
|
tuplify(new_value), write, \
|
|
evict=evict)
|
|
return res[1:], 1 - res[0]
|
|
def read_and_maybe_remove(self, index):
|
|
return self.read(index), 0
|
|
def add(self, entry, state=None, evict=True):
|
|
self.access(entry.v, entry.x, True, entry.empty(), evict=evict)
|
|
|
|
class LocalPackedIndexStructure(PackedIndexStructure):
|
|
""" Debugging only. Packed tree ORAM index revealing the access
|
|
pattern. """
|
|
storage = staticmethod(lambda *args,**kwargs: List(*args,**kwargs))
|
|
|
|
class LocalPackedORAM(TreeORAM):
|
|
""" Debugging only. Tree ORAM using index revealing the access
|
|
pattern. """
|
|
index_structure = LocalPackedIndexStructure
|
|
|
|
class BaseORAMIndexStructure(PackedIndexStructure):
|
|
""" Debugging only. Tree ORAM index revealing the access
|
|
pattern after one recursion. """
|
|
storage = BaseORAM
|
|
|
|
class OneLevelORAM(TreeORAM):
|
|
""" Debugging only. Tree ORAM using index revealing the access
|
|
pattern after one recursion. """
|
|
index_structure = BaseORAMIndexStructure
|
|
|
|
class BinaryORAM:
|
|
def __init__(self, size, value_type=None, **kwargs):
|
|
from Compiler import circuit_oram
|
|
from Compiler.GC import types
|
|
n_bits = int(get_program().options.binary)
|
|
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
|
|
self.index_type = self.value_type
|
|
oram_value_type = types.sbits.get_type(64)
|
|
if 'entry_size' not in kwargs:
|
|
kwargs['entry_size'] = n_bits
|
|
self.oram = circuit_oram.OptimalCircuitORAM(
|
|
size, value_type=oram_value_type, **kwargs)
|
|
self.size = size
|
|
def get_index(self, index):
|
|
return self.oram.value_type(self.index_type.conv(index).elements()[0])
|
|
def __setitem__(self, index, value):
|
|
value = list(self.oram.value_type(
|
|
self.value_type.conv(v).elements()[0]) for v in tuplify(value))
|
|
self.oram[self.get_index(index)] = value
|
|
def __getitem__(self, index):
|
|
value = self.oram[self.get_index(index)]
|
|
return untuplify(tuple(self.value_type(v) for v in tuplify(value)))
|
|
def read(self, index):
|
|
return self.oram.read(index)
|
|
def read_and_maybe_remove(self, index):
|
|
return self.oram.read_and_maybe_remove(index)
|
|
def access(self, *args):
|
|
return self.oram.access(*args)
|
|
def add(self, *args, **kwargs):
|
|
return self.oram.add(*args, **kwargs)
|
|
def delete(self, *args, **kwargs):
|
|
return self.oram.delete(*args, **kwargs)
|
|
|
|
def OptimalORAM(size,*args,**kwargs):
|
|
""" Create an ORAM instance suitable for the size based on
|
|
experiments. This uses :py:class:`LinearORAM` for sizes up to a
|
|
few thousand and :py:class:`RecursiveORAM` above that.
|
|
|
|
:param size: number of entries
|
|
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
|
|
:py:class:`sfix`
|
|
:param value_length: number of values per entry (default: 1)
|
|
|
|
"""
|
|
if not util.is_constant(size):
|
|
raise CompilerError('ORAM size has be a compile-time constant')
|
|
if get_program().options.binary:
|
|
return BinaryORAM(size, *args, **kwargs)
|
|
if optimal_threshold is None:
|
|
if n_threads == 1:
|
|
threshold = 2**11
|
|
else:
|
|
threshold = 2**13
|
|
else:
|
|
threshold = optimal_threshold
|
|
if size <= threshold:
|
|
return LinearORAM(size,*args,**kwargs)
|
|
else:
|
|
return RecursiveORAM(size,*args,**kwargs)
|
|
|
|
class RecursiveIndexStructure(PackedIndexStructure):
|
|
""" Secure index using secure tree ORAM. """
|
|
storage = lambda self,*args,**kwargs: OptimalORAM(*args,**kwargs)
|
|
|
|
class RecursiveORAM(TreeORAM):
|
|
""" Secure tree ORAM using secure index. This uses the approach by
|
|
`Keller and Scholl <https://eprint.iacr.org/2014/137>`_.
|
|
|
|
:param size: number of entries
|
|
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
|
|
:py:class:`sfix`
|
|
:param value_length: number of values per entry (default: 1)
|
|
|
|
"""
|
|
index_structure = RecursiveIndexStructure
|
|
|
|
class TrivialORAMIndexStructure(PackedIndexStructure):
|
|
""" Secure index using trivial ORAM. """
|
|
storage = TrivialORAM
|
|
|
|
class TrivialIndexORAM(TreeORAM):
|
|
""" Secure tree ORAM using index using trivial ORAM. """
|
|
index_structure = TrivialORAMIndexStructure
|
|
|
|
class AtLeastOneRecursionIndexStructure(PackedIndexStructure):
|
|
storage = RecursiveORAM
|
|
|
|
OptimalPackedORAM = RecursiveIndexStructure
|
|
|
|
class LinearPackedORAM(PackedIndexStructure):
|
|
storage = LinearORAM
|
|
|
|
class LinearPackedORAMWithEmpty(PackedORAMWithEmpty):
|
|
storage = LinearORAM
|
|
|
|
class AtLeastOneRecursionPackedORAMWithEmpty(PackedORAMWithEmpty):
|
|
storage = RecursiveORAM
|
|
|
|
class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
|
|
storage = staticmethod(OptimalORAM)
|
|
|
|
def test_oram(oram_type, N, value_type=sint, iterations=100):
|
|
stop_grind()
|
|
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
|
|
test_oram_initialized(oram, iterations)
|
|
return oram
|
|
|
|
def test_oram_initialized(oram, iterations=100):
|
|
N = oram.size
|
|
value_type = oram.value_type
|
|
value_type = value_type.get_type(32)
|
|
index_type = value_type.get_type(log2(N))
|
|
start_grind()
|
|
print('initialized')
|
|
print_ln('initialized')
|
|
stop_timer()
|
|
# synchronize
|
|
start_timer(2)
|
|
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
|
value_type(0).reveal()
|
|
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
|
stop_timer(2)
|
|
start_timer()
|
|
#oram[value_type(0)] = -1
|
|
#iterations = N
|
|
@for_range(iterations)
|
|
def f(i):
|
|
time()
|
|
oram[index_type(i % N)] = value_type(i % N)
|
|
#value, empty = oram.read_and_remove(value_type(i))
|
|
#print 'first write'
|
|
time()
|
|
oram[index_type(i % N)].reveal().print_reg('writ')
|
|
#print 'first read'
|
|
@for_range(iterations)
|
|
def f(i):
|
|
time()
|
|
x = oram[index_type(i % N)]
|
|
x.reveal().print_reg('read')
|
|
# print 'second read'
|
|
print_ln('%s accesses', 3 * iterations)
|
|
return oram
|
|
|
|
def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100):
|
|
oram = oram_type(N, value_type=value_type, entry_size=32, \
|
|
init_rounds=0)
|
|
print('initialized')
|
|
print_reg(cint(0), 'init')
|
|
stop_timer()
|
|
# synchronize
|
|
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
|
sint(0).reveal()
|
|
Program.prog.curr_tape.start_new_basicblock(name='sync')
|
|
start_timer()
|
|
#oram[value_type(0)] = -1
|
|
@for_range(iterations)
|
|
def f(i):
|
|
oram.access(value_type(i % N), value_type(0), value_type(True))
|
|
oram.access(value_type(i % N), value_type(i % N), value_type(True))
|
|
print('first write')
|
|
time()
|
|
x = oram.access(value_type(i % N), value_type(0), value_type(False))
|
|
x[0][0].reveal().print_reg('writ')
|
|
print('first read')
|
|
# @for_range(iterations)
|
|
# def f(i):
|
|
# x = oram.access(value_type(i % N), value_type(0), value_type(False), \
|
|
# value_type(True))
|
|
# x[0][0].reveal().print_reg('read')
|
|
# print 'second read'
|
|
return oram
|
|
|
|
def test_batch_init(oram_type, N):
|
|
value_type = sint
|
|
oram = oram_type(N, value_type)
|
|
print('initialized')
|
|
print_reg(cint(0), 'init')
|
|
oram.batch_init(Array.create_from(sint(regint.inc(N))))
|
|
print_reg(cint(0), 'done')
|
|
@for_range(N)
|
|
def f(i):
|
|
x = oram[value_type(i)]
|
|
x.reveal().print_reg('read')
|
|
return oram
|
|
|
|
def oram_delete(oram, iterations=100):
|
|
@for_range(iterations)
|
|
def f(i):
|
|
x = oram.access(oram.value_type(i % oram.size), oram.value_type(0), \
|
|
oram.value_type(True), oram.value_type(True))
|