mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
816 lines
34 KiB
Python
816 lines
34 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
from abc import abstractmethod
|
|
from typing import Any, Generic, Type, TypeVar
|
|
|
|
from Compiler import library as lib
|
|
from Compiler import util
|
|
from Compiler.GC.types import cbit, sbit, sbitint, sbits
|
|
from Compiler.program import Program
|
|
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
|
|
regint, sint, sintbit)
|
|
from Compiler.oram import demux_array, get_n_threads
|
|
|
|
# When True, progress messages are printed during ORAM construction and on
# refresh (heavy computation steps).
debug = False
# When True, a finer-grained trace of the steps the ORAM performs is printed
# and additional runtime error checks are emitted.
# Warning: tracing reveals secret information and makes the computation
# insecure.
trace = False

# Thread count returned by get_n_threads(); when set to None,
# get_n_threads() chooses based on the number of loop iterations instead.
n_threads = 16
# Passed as the second argument to the lib.map_sum calls in this module.
n_parallel = 1024

# Avoids any memory allocation if set to False.
# Setting to False prevents some optimizations but allows for controlling the
# ORAMs outside of the main tape.
allow_memory_allocation = True
|
|
|
|
|
|
def get_n_threads(n_loops):
    """Return the number of threads to use for a loop of ``n_loops`` iterations.

    The module-level ``n_threads`` setting wins whenever it is not None.
    Otherwise 8 is returned for loops with more than 2048 iterations and
    None for anything smaller.
    """
    if n_threads is not None:
        return n_threads
    return 8 if n_loops > 2048 else None
|
|
|
|
|
|
# Secret value type handled by the ORAM: constrained to sint or sbitint.
T = TypeVar("T", sint, sbitint)
# Secret bit type used for conditions: constrained to sintbit or sbit.
B = TypeVar("B", sintbit, sbit)
|
|
|
|
|
|
class SqrtOram(Generic[T, B]):
    """Oblivious RAM using the "Square-Root" algorithm.

    :param MultiArray data: The data with which to initialize the ORAM. One may
        provide a MultiArray such that every "block" can hold multiple elements
        (an Array). A flat sint vector is also accepted, in which case its
        length must be a multiple of ``entry_length``.
    :param int entry_length: Number of values stored per block.
    :param sint value_type: The secret type to use (sint or sbitint),
        defaults to sint.
    :param int k: Leave at 0, this parameter is used to recursively pass down
        the depth of this ORAM.
    :param int period: Leave at None, this parameter is used to recursively
        pass down the top-level period.
    :param bool initialize: When False the initial shuffle is skipped; the
        caller must call refresh before use, otherwise the ORAM is not secure.
    :param bool empty_data: When True (and ``initialize`` is True) only the
        index array is permuted, not the data itself.
    """
    # TODO: Preferably this is an Array of vectors, but this is currently not
    # supported.
    # One should regard these structures as Arrays where an entry may hold more
    # than one value (which is a nice property to have when using the ORAM in
    # practice).
    shuffle: MultiArray
    stash: MultiArray
    # A block has an index and data:
    # `shuffle` and `stash` store the data,
    # `shufflei` and `stashi` store the index.
    shufflei: Array
    stashi: Array

    shuffle_used: Array
    position_map: PositionMap

    # The size of the ORAM, i.e. how many elements it stores
    n: int
    # The period, i.e. how many calls can be made to the ORAM before it needs
    # to be refreshed
    T: int
    # Keep track of how far we are in the period, and coincidentally how large
    # the stash is (each access results in a fake or real block being put on
    # the stash)
    t: cint

    def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None:
        """Allocate all ORAM state and (optionally) perform the first shuffle.

        :raises Exception: if ``data`` is neither a MultiArray nor a sint, if
            a sint vector is not a multiple of ``entry_length`` long, or if
            ``value_type`` is neither sint nor sbitint.
        """
        global debug, allow_memory_allocation

        # Correctly initialize the shuffle (memory) depending on the type of data
        if isinstance(data, MultiArray):
            self.shuffle = data
            self.n = len(data)
        elif isinstance(data, sint):
            # NOTE(review): only sint vectors are accepted here even though
            # value_type may be sbitint — confirm whether that is intended.
            if len(data) % entry_length != 0:
                raise Exception('Data incorrectly padded.')
            # The length is an exact multiple, so plain integer division
            # gives the number of blocks.
            self.n = len(data) // entry_length
            self.shuffle = MultiArray(
                (self.n, entry_length), value_type=value_type)
            self.shuffle.assign_part_vector(data.get_vector())
        else:
            raise Exception("Incorrect format.")

        # Only sint and sbitint are supported
        if value_type != sint and value_type != sbitint:
            raise Exception("The value_type must be either sint or sbitint")

        # Set derived constants
        self.value_type = value_type
        self.bit_type: Type[B] = value_type.bit_type
        self.index_size = util.log2(self.n) + 1  # +1 because signed
        self.index_type = value_type.get_type(self.index_size)
        self.entry_length = entry_length
        self.size = self.n

        if debug:
            lib.print_ln(
                'Initializing SqrtORAM of size %s at depth %s', self.n, k)

        self.shuffle_used = cint.Array(self.n)
        # Logical index of every block; starts out as the identity 0..n-1 and
        # is permuted together with the data.
        self.shufflei = Array.create_from(
            self.index_type(regint.inc(self.n)))
        # Calculate the period if not given.
        # Upon recursion, the period should stay the same ("in sync"),
        # therefore it can be passed as a constructor parameter.
        self.T = int(math.ceil(
            math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
        if debug and not period:
            lib.print_ln('Period set to %s', self.T)

        # Here we allocate the memory for the permutation.
        # Note that self.shuffle_the_shuffle mutates this field.
        # Why don't we pass it as an argument then? Well, this way we don't
        # have to allocate memory while shuffling, which keeps open the
        # possibility for multithreading.
        self.permutation = Array.create_from(
            self.index_type(regint.inc(self.n)))
        # We allow the caller to postpone the initialization of the shuffle.
        # This is the most expensive operation, and can be done in a thread
        # (only if you know what you're doing).
        # Note that if you do not initialize, the ORAM is insecure.
        if initialize:
            # If the ORAM is not initialized with existing data, we can apply
            # a small optimization by forgoing shuffling the shuffle, as all
            # entries of the shuffle are equal and empty.
            if empty_data:
                random_shuffle = sint.get_secure_shuffle(self.n)
                self.shufflei.secure_permute(random_shuffle)
                self.permutation.assign(self.shufflei[:].inverse_permutation())
                if trace:
                    lib.print_ln('Calculated inverse permutation')
            else:
                self.shuffle_the_shuffle()
        else:
            print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')
        # Initialize position map (recursive oram)
        # NOTE(review): value_type is not forwarded, so the position map is
        # created with PositionMap.create's default (sint) — confirm this is
        # intended when value_type is sbitint.
        self.position_map = PositionMap.create(self.permutation, k + 1,
                                               self.T, initialize=initialize)

        # Initialize stash
        self.stash = MultiArray((self.T, entry_length), value_type=value_type)
        self.stashi = Array(self.T, value_type=value_type)
        self.t = MemValue(cint(0))

        # Initialize temp variables needed during the computation
        self.found_ = self.bit_type.Array(size=self.T)
        self.j = MemValue(cint(0, size=1))

        # To prevent the compiler from recompiling the same code over and over
        # again, we should use @method_block.
        # However, @method_block requires allocation (of return address),
        # which is not allowed when not in the main thread.
        # Therefore, we only conditionally wrap the methods in a @method_block
        # if we are guaranteed to be running in the main thread.
        # NOTE(review): these assignments run on every instantiation, so the
        # class attributes are wrapped again each time an ORAM is constructed
        # — confirm this is intended.
        SqrtOram.shuffle_the_shuffle = lib.method_block(SqrtOram.shuffle_the_shuffle) if allow_memory_allocation else SqrtOram.shuffle_the_shuffle
        SqrtOram.refresh = lib.method_block(SqrtOram.refresh) if allow_memory_allocation else SqrtOram.refresh
        SqrtOram.reinitialize = lib.method_block(SqrtOram.reinitialize) if allow_memory_allocation else SqrtOram.reinitialize

    @lib.method_block
    def access(self, index: T, write: B, *value: T):
        """Obliviously read or write the block at the secret ``index``.

        :param index: secret logical index of the block.
        :param write: secret bit; when set, the block is overwritten with
            ``value``.
        :param value: the values to store (only applied when ``write`` is set).
        :return: the stash entry for ``index`` as it was before any write, or,
            when the index was not on the stash, the shuffle entry at the
            looked-up address (read after the conditional write).
        """
        global trace, n_parallel
        if trace:
            @lib.if_e(write.reveal() == 1)
            def _():
                lib.print_ln('Writing to secret index %s', index.reveal())

            @lib.else_
            def __():
                lib.print_ln('Reading from secret index %s', index.reveal())

        value = self.value_type(value, size=self.entry_length).get_vector(
            0, size=self.entry_length)
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found.
        # Only the first self.t stash slots are considered.
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))

        # Store the stash item into the result if found.
        # If the item is not in the stash, the result will simply remain 0.
        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate:
            # we should loop from 0 to self.t, but t is dynamic thus this is
            # impossible. Therefore we loop till self.T (the max value of
            # self.t).
            # is_in_time = i < self.t

            # If we are writing, we need to add the value
            self.stash[i] += write * access_here * (value - entry)
            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash')

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash')

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # ...we update the item in the shuffle
        self.shuffle[physical_address] += write * \
            found.bit_not() * (value - self.shuffle[physical_address][:])
        # ...and the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            @lib.if_((write * found.bit_not()).reveal())
            def _():
                lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
                ), self.shuffle[physical_address].reveal(), physical_address)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    def write(self, index: T, *value: T):
        """Write ``value`` to the block at the secret ``index``.

        Converts the given values to a ``value_type`` vector of
        ``entry_length`` elements and delegates to :py:meth:`_write`.

        :raises Exception: if the value cannot be handled.
        """
        global trace, n_parallel
        if trace:
            lib.print_ln('Writing to secret index %s', index.reveal())

        # ``value`` is collected via *value and therefore arrives as a tuple;
        # the remaining branches are kept for direct callers.
        if isinstance(value, (tuple, list)):
            value = self.value_type(value, size=self.entry_length)
        elif isinstance(value, self.value_type):
            value = self.value_type(*value, size=self.entry_length)
        else:
            raise Exception("Cannot handle type of value passed")

        self._write(index, *value)

    @lib.method_block
    def _write(self, index: T, *value: T):
        """Obliviously overwrite the block at the secret ``index``.

        Follows the same steps as :py:meth:`access` with the write
        unconditionally enabled.

        :return: the stash entry for ``index`` as it was before the write, or,
            when the index was not on the stash, the shuffle entry at the
            looked-up address (read after the write).
        """
        value = MemValue(self.value_type(value))
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))

        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate:
            # we should loop from 0 to self.t, but t is dynamic thus this is
            # impossible. Therefore we loop till self.T (the max value of
            # self.t).
            # is_in_time = i < self.t

            # We update the stash value
            self.stash[i] += access_here * (value - entry)
            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash')

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash')

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # ...we update the item in the shuffle
        self.shuffle[physical_address] += found.bit_not() * \
            (value - self.shuffle[physical_address][:])
        # ...and the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            @lib.if_(found.bit_not().reveal())
            def _():
                lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
                ), self.shuffle[physical_address].reveal(), physical_address)

            lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
                         self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    @lib.method_block
    def read(self, index: T, *value: T):
        """Obliviously read the block at the secret ``index``.

        :param index: secret logical index of the block.
        :param value: not used for the lookup itself.
        :return: the block content, taken from the stash when present there
            and from the shuffle otherwise.
        """
        global debug, trace, n_parallel
        if trace:
            lib.print_ln('Reading from secret index %s', index.reveal())

        value = self.value_type(value)
        index = MemValue(index)

        # Refresh if we have performed T (period) accesses
        @lib.if_(self.t == self.T)
        def _():
            if debug:
                lib.print_ln('Refreshing SqrtORAM')
                lib.print_ln('t=%s according to me', self.t)

            self.refresh()

        found: B = MemValue(self.bit_type(False))
        result: T = MemValue(self.value_type(0, size=self.entry_length))

        # First we scan the stash for the item
        self.found_.assign_all(0)

        # This will result in a bit array with at most one True,
        # indicating where in the stash 'index' is found
        @lib.multithread(get_n_threads(self.T), self.T)
        def _(base, size):
            self.found_.assign_vector(
                (self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
                self.bit_type(regint.inc(size, base=base) <
                              self.t.expand_to_vector(size)),
                base=base)

        # To determine whether the item is found in the stash, we simply
        # check whether the demuxed array contains a True
        # TODO: What if the index=0?
        found.write(sum(self.found_))
        lib.check_point()

        # Store the stash item into the result if found.
        # If the item is not in the stash, the result will simply remain 0.
        @lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
                     self.entry_length, [self.value_type] * self.entry_length)
        def stash_item(i):
            entry = self.stash[i][:]
            access_here = self.found_[i]
            # This is a bit unfortunate:
            # we should loop from 0 to self.t, but t is dynamic thus this is
            # impossible. Therefore we loop till self.T (the max value of
            # self.t).
            # is_in_time = i < self.t

            return (entry * access_here)[:]
        result += self.value_type(stash_item(), size=self.entry_length)

        if trace:
            # @lib.for_range(self.t)
            # def _(i):
            #     lib.print_ln("stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())

            @lib.if_e(found.reveal() == 1)
            def _():
                lib.print_ln('Found item in stash (found=%s)', found.reveal())

            @lib.else_
            def __():
                lib.print_ln('Did not find item in stash (found=%s)', found.reveal())

        # Possible fake lookup of the item in the shuffle,
        # depending on whether we already found the item in the stash
        physical_address = self.position_map.get_position(index, found)
        # We set shuffle_used to True, to track that this shuffle item needs
        # to be refreshed with its equivalent on the stash once the period is
        # up.
        self.shuffle_used[physical_address] = cbit(True)

        # If the item was not found in the stash
        # the item retrieved from the shuffle is our result
        result += self.shuffle[physical_address] * found.bit_not()
        # We append the newly retrieved item to the stash
        self.stash[self.t].assign(self.shuffle[physical_address][:])
        self.stashi[self.t] = self.shufflei[physical_address]

        if trace:
            lib.print_ln('Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
                         self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)

        # Increase the "time" (i.e. access count in current period)
        self.t.iadd(1)

        return result

    __getitem__ = read
    __setitem__ = write

    def shuffle_the_shuffle(self) -> None:
        """Permute the memory using a newly generated permutation and store
        the permutation that would generate this particular shuffling in
        ``self.permutation``.

        This permutation is needed to know how to map logical addresses to
        physical addresses, and is used as such by the position map."""

        global trace
        # Random permutation on n elements
        random_shuffle = sint.get_secure_shuffle(self.n)
        if trace:
            lib.print_ln('Generated shuffle')
        # Apply the random permutation to the data and, identically, to the
        # logical indexes so both stay aligned.
        self.shuffle.secure_permute(random_shuffle)
        if trace:
            lib.print_ln('Shuffled shuffle')
        self.shufflei.secure_permute(random_shuffle)
        if trace:
            lib.print_ln('Shuffled shuffle indexes')

        lib.check_point()
        # Calculate the permutation that would have produced the newly produced
        # shuffle order. This can be calculated by regarding the logical
        # indexes (shufflei) as a permutation and calculating its inverse,
        # i.e. find P such that P([1,2,3,...]) = shufflei.
        # This is not necessarily equal to the inverse of the above generated
        # random_shuffle, as the shuffle may already be out of order (e.g. when
        # refreshing).
        self.permutation.assign(self.shufflei[:].inverse_permutation())
        # If shufflei does not contain exactly the indices
        # [i for i in range(self.n)],
        # the underlying waksman network of 'inverse_permutation' will hang.
        if trace:
            lib.print_ln('Calculated inverse permutation')

    def refresh(self):
        """Refresh the ORAM by reinserting the stash back into the shuffle, and
        reshuffling the shuffle.

        This must happen on the T'th (period) accesses to the ORAM."""

        self.j.write(0)
        # Empty the stash: every shuffle slot marked as used receives the
        # next stash entry (data and logical index), in stash order.

        @lib.for_range_opt(self.n)
        def _(i):
            @lib.if_(self.shuffle_used[i])
            def _():
                self.shuffle[i] = self.stash[self.j]
                self.shufflei[i] = self.stashi[self.j]
                self.j += 1

        # Reset the clock
        self.t.write(0)
        # Reset shuffle_used
        self._reset_shuffle_used()

        # Reshuffle; this also recomputes self.permutation
        self.shuffle_the_shuffle()
        # Reinitialize position map.
        # Note that we skip here the step of "packing" the permutation.
        # Since the underlying memory of the position map is already aligned in
        # this packed structure, we can simply overwrite the memory while
        # maintaining the structure.
        self.position_map.reinitialize(self.permutation)

    def reinitialize(self, data: T):
        """Overwrite the ORAM content with ``data`` and reshuffle.

        Note that this method is only used during refresh, and as such is
        only called with a permutation as data.
        """
        # The logical addresses of some previous permutation are irrelevant
        # and must be reset
        self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
        # Reset the clock
        self.t.write(0)
        # Reset shuffle_used
        self._reset_shuffle_used()

        # Note that the self.shuffle is actually a MultiArray.
        # This structure is preserved while overwriting the values using
        # assign_vector.
        self.shuffle.assign_vector(self.value_type(
            data[:], size=self.n * self.entry_length))
        # Note that this updates self.permutation (see constructor for
        # explanation)
        self.shuffle_the_shuffle()
        self.position_map.reinitialize(self.permutation)

    def _reset_shuffle_used(self):
        """Clear every entry of ``shuffle_used``.

        Uses ``assign_all`` when memory allocation is allowed and an explicit
        element-wise loop otherwise.
        """
        global allow_memory_allocation
        if allow_memory_allocation:
            self.shuffle_used.assign_all(0)
        else:
            @lib.for_range_opt(self.n)
            def _(i):
                self.shuffle_used[i] = cint(0)
|
|
|
|
|
|
class PositionMap(Generic[T, B]):
    """Base class for the structures mapping logical to physical addresses.

    Use :py:meth:`create` to obtain an instance; it selects between
    LinearPositionMap and RecursivePositionMap.
    """
    # A packed block holds PACK = 2**PACK_LOG entries.
    PACK_LOG: int = 3
    PACK: int = 1 << PACK_LOG

    n: int  # n in the paper
    depth: cint  # k in the paper
    value_type: Type[T]

    def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
        """Record the size, depth and the types derived from ``value_type``."""
        self.value_type = value_type
        self.bit_type = value_type.bit_type
        # +1 because signed
        self.index_type = value_type.get_type(util.log2(n) + 1)
        self.n = n
        self.depth = MemValue(cint(k))

    @abstractmethod
    def get_position(self, logical_address: _secret, fake: B) -> Any:
        """Retrieve the block at the given (secret) logical address.

        The base implementation only emits a trace line (when tracing is
        enabled); subclasses call it before doing the actual lookup.
        """
        global trace
        if not trace:
            return
        print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
                       self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())

    def reinitialize(self, permutation: T):
        """Reinitialize this PositionMap.

        Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
        we cannot simply call __init__ on self. Instead, we must take care to
        reuse and overwrite the same memory.
        """
        ...

    @classmethod
    def create(cls, permutation: Array, k: int, period: int,
               value_type: Type[T] = sint, **kwargs) -> PositionMap:
        """Build the position map appropriate for the given permutation.

        This is the method one should call when needing a new position map.
        A LinearPositionMap is returned when ``len(permutation) / PACK`` does
        not exceed ``period``; otherwise a RecursivePositionMap is returned.
        Extra keyword arguments are only forwarded to the recursive variant.
        """
        global debug
        entries = len(permutation)

        if entries / PositionMap.PACK <= period:
            if debug:
                lib.print_ln(
                    'Initializing LinearPositionMap at depth %s of size %s', k, entries)
            return LinearPositionMap(permutation, value_type, k=k)

        if debug:
            lib.print_ln(
                'Initializing RecursivePositionMap at depth %s of size %s', k, entries)
        return RecursivePositionMap(permutation, period, value_type, k=k,
                                    **kwargs)
|
|
|
|
|
|
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
    """Position map that is itself a SqrtOram over a packed permutation.

    The permutation is stored in blocks of PositionMap.PACK entries; each
    block is one entry of the underlying SqrtOram.
    """

    def __init__(self, permutation: Array, period: int,
                 value_type: Type[T] = sint, k: int = -1, **kwargs) -> None:
        """Pack ``permutation`` into blocks and initialize the underlying ORAM.

        Remaining keyword arguments are forwarded to SqrtOram.__init__.
        """
        # NOTE(review): value_type is not forwarded here, so PositionMap
        # initially records its default; SqrtOram.__init__ below assigns
        # value_type, bit_type and index_type again.
        PositionMap.__init__(self, len(permutation), k=k)
        pack = PositionMap.PACK

        # We pack the permutation into a smaller structure, indexed with a new
        # permutation
        packed_size = int(math.ceil(self.n / pack))
        packed_structure = MultiArray(
            (packed_size, pack), value_type=value_type)
        @lib.for_range(packed_size)
        def _(i):
            packed_structure[i] = Array.create_from(
                permutation.get_vector(base=i * pack, size=pack))

        SqrtOram.__init__(self, packed_structure, value_type=value_type,
                          period=period, entry_length=pack, k=self.depth,
                          **kwargs)

        # Initialize random temp variables needed during the computation
        self.block_index_demux: Array = self.bit_type.Array(self.T)
        self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)

    @lib.method_block
    def get_position(self, logical_address: T, fake: B) -> _clear:
        """Look up the physical address stored for ``logical_address``.

        The address is split into a block index ``h`` and an offset ``l``
        within the block. The stash is scanned first, then one block is
        fetched from the shuffle via the nested position map and appended to
        the stash. The revealed address is returned.
        """
        super().get_position(logical_address, fake)

        pack = PositionMap.PACK
        pack_log = PositionMap.PACK_LOG

        # The item at logical_address
        # will be in block with index h (block.<h>)
        # at position l in block.data (block.data<l>)
        program = Program.prog
        # h = logical_address >> pack_log
        h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
            logical_address).right_shift(pack_log, program.bit_length)))
        # l = logical_address & (pack - 1)
        l = self.value_type.bit_compose(
            sbits.get_type(program.bit_length)(logical_address) & (pack - 1))

        global trace
        if trace:
            print_at_depth(self.depth, '-> logical_address=%s: h=%s, l=%s', logical_address.reveal(), h.reveal(), l.reveal())
            # @lib.for_range(self.t)
            # def _(i):
            #     print_at_depth(self.depth, "stash[%s]=(%s: %s)", i, self.stashi[i].reveal() ,self.stash[i].reveal())

        # The resulting physical address
        p = MemValue(self.index_type(-1))
        found: B = MemValue(self.bit_type(False))

        # First we try and retrieve the item from the stash at position stash[h][l]
        # Since h and l are secret, we do this by scanning the entire stash

        # First we scan the stash for the block we need
        self.block_index_demux.assign_all(0)
        @lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
        def _(i):
            self.block_index_demux[i] = ( self.stashi[i] == h) & self.bit_type(i < self.t)
        # We can determine if the 'index' is in the stash by checking the
        # block_index_demux array
        found = sum(self.block_index_demux)
        # Once a block is found, we use the following condition to pick the correct item from that block
        demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)

        # Finally we use the conditions to conditionally write p
        @lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
        def p_(i):
            # We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum
            # Therefore we include the check (i < self.t)
            return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i // pack< self.t)
        p.write(p_())

        if trace:
            @lib.if_e(found.reveal() == 0)
            def _(): print_at_depth(self.depth, 'Retrieve shuffle[%s]:', h.reveal())
            @lib.else_
            def __():
                print_at_depth(self.depth, 'Retrieve dummy element from shuffle:')

        # Then we try and retrieve the item from the shuffle (the actual memory)
        # Depending on whether we found the item in the stash, we either
        # retrieve block 'h' in which 'index' resides, or a dummy block from
        # the shuffle
        p_prime = self.position_map.get_position(h, found)
        self.shuffle_used[p_prime] = cbit(True)

        # The block retrieved from the shuffle
        block_p_prime: Array = self.shuffle[p_prime]

        if trace:
            @lib.if_e(found.reveal() == 0)
            def _():
                print_at_depth(self.depth, 'Retrieved position from shuffle[%s]=(%s: %s)',
                               p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())

            @lib.else_
            def __():
                print_at_depth(self.depth, 'Retrieved dummy position from shuffle[%s]=(%s: %s)',
                               p_prime.reveal(), self.shufflei[p_prime].reveal(), self.shuffle[p_prime].reveal())

        # We add the retrieved block from the shuffle to the stash
        self.stash[self.t].assign(block_p_prime[:])
        self.stashi[self.t] = self.shufflei[p_prime]
        # Increase t
        self.t += 1

        # NOTE(review): the original comment here read "if found or not
        # fake", but the expression computes fake OR (not found) — confirm
        # which condition is intended.
        condition: B = self.bit_type(fake.bit_or(found.bit_not()))
        # Retrieve l'th item from block
        # l is secret, so we must use linear scan
        hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
            pack)) & condition.expand_to_vector(pack))

        # Overwrite p with the block entry at the position where hit is set
        @lib.for_range_opt(pack)
        def _(i):
            p.write((hit[i]).if_else(block_p_prime[i], p))

        return p.reveal()

    def reinitialize(self, permutation: T):
        """Reinitialize by delegating to SqrtOram.reinitialize."""
        SqrtOram.reinitialize(self, permutation)
|
|
|
|
|
|
class LinearPositionMap(PositionMap):
    """Position map that resolves addresses with a linear scan.

    ``physical`` holds the physical address for every logical address and
    ``used`` marks the entries that have already been handed out.
    """
    physical: Array
    used: Array

    def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
        """Wrap ``data`` (the permutation) and allocate the bookkeeping arrays."""
        PositionMap.__init__(self, len(data), value_type, k=k)
        self.physical = data
        self.used = self.bit_type.Array(self.n)

        # Initialize random temp variables needed during the computation
        self.physical_demux: Array = self.bit_type.Array(self.n)

    @lib.method_block
    def get_position(self, logical_address: T, fake: B) -> _clear:
        """
        This method corresponds to GetPosBase in the paper.

        When ``fake`` is not set, the entry for ``logical_address`` is
        selected; when ``fake`` is set, the first entry not yet marked as used
        is selected instead. The selected entry is marked as used and its
        revealed value is returned.
        """
        super().get_position(logical_address, fake)

        global trace
        if trace:
            # The address is out of range when it is negative OR too large.
            # (A product of the two comparisons could never be true, which
            # silently disabled this check.)
            @lib.if_(((logical_address < 0).bit_or(logical_address >= self.n)).reveal())
            def _():
                lib.runtime_error(
                    'logical_address must lie between 0 and self.n - 1')

        fake = MemValue(self.bit_type(fake))
        logical_address = MemValue(logical_address)

        p: MemValue = MemValue(self.index_type(-1))
        done: B = self.bit_type(False)

        # In order to get an address at secret logical_address,
        # we need to perform a linear scan.
        self.physical_demux.assign_all(0)

        @lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
        def condition_i(i):
            self.physical_demux[i] = \
                (self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
                | (fake & self.used[i].bit_not())

        # In the event that fake=True, there are likely multiple entries in
        # physical_demux set to True (i.e. where self.used[i] = False).
        # We only need one, so we pick the first one we find.
        @lib.for_range_opt(self.n)
        def _(i):
            self.physical_demux[i] &= done.bit_not()
            done.update(done | self.physical_demux[i])

        # Retrieve the value from the physical memory obliviously
        @lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
        def calc_p(i):
            return self.physical[i] * self.physical_demux[i]
        p.write(calc_p())

        # Update self.used
        self.used.assign(self.used[:] | self.physical_demux[:])

        if trace:
            # Valid physical addresses are 0 .. len(self.physical) - 1, so
            # the upper bound itself is already out of range.
            @lib.if_((p.reveal() < 0).bit_or(p.reveal() >= len(self.physical)))
            def _():
                lib.runtime_error(
                    '%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)

        return p.reveal()

    def reinitialize(self, data : T):
        """Overwrite the stored addresses with ``data`` and clear ``used``.

        Uses ``assign_all`` when memory allocation is allowed and an explicit
        element-wise loop otherwise.
        """
        self.physical.assign(data)

        global allow_memory_allocation
        if allow_memory_allocation:
            self.used.assign_all(False)
        else:
            @lib.for_range_opt(self.n)
            def _(i):
                self.used[i] = self.bit_type(0)
|
|
|
|
def print_at_depth(depth: cint, message: str, *kwargs):
    """Print a trace line prefixed with ``depth`` and indented accordingly.

    The depth value is printed first, followed by two spaces per depth level,
    and finally ``message`` formatted with the remaining positional
    arguments.
    """
    lib.print_str('%s', depth)

    # One iteration per depth level, each emitting a two-space indent.
    @lib.for_range(depth)
    def _indent(_level):
        lib.print_char(' ')
        lib.print_char(' ')

    lib.print_ln(message, *kwargs)
|