mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
Add allow_memory_allocation option to SqrtORAM
Also remove unused swap function in SqrtORAM
This commit is contained in:
@@ -1,32 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
from Compiler import util
|
||||
from Compiler import library as lib
|
||||
from Compiler import util
|
||||
from Compiler.GC.types import cbit, sbit, sbitint, sbits
|
||||
from Compiler.program import Program
|
||||
from Compiler.types import (
|
||||
Array,
|
||||
MemValue,
|
||||
MultiArray,
|
||||
_clear,
|
||||
_secret,
|
||||
cint,
|
||||
regint,
|
||||
sint,
|
||||
sintbit,
|
||||
)
|
||||
from oram import get_n_threads
|
||||
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
|
||||
regint, sint, sintbit)
|
||||
from oram import demux_array, get_n_threads
|
||||
|
||||
program = Program.prog
|
||||
|
||||
debug = True
|
||||
trace = True
|
||||
n_threads = 8
|
||||
# Adds messages on completion of heavy computation steps
|
||||
debug = False
|
||||
# Finer grained trace of steps that the ORAM performs
|
||||
# + runtime error checks
|
||||
# Warning: reveals information and makes the computation insecure
|
||||
trace = False
|
||||
|
||||
n_threads = 16
|
||||
n_parallel = 1
|
||||
|
||||
# Avoids any memory allocation
|
||||
# This prevents some optimizations but allows for using the ORAMs outside of the main tape
|
||||
allow_memory_allocation = True
|
||||
|
||||
|
||||
def get_n_threads(n_loops):
|
||||
if n_threads is None:
|
||||
if n_loops > 2048:
|
||||
@@ -37,25 +39,6 @@ def get_n_threads(n_loops):
|
||||
return n_threads
|
||||
|
||||
|
||||
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
|
||||
"""Swap two positions in an Array if a condition is met.
|
||||
|
||||
Args:
|
||||
array (Array | MultiArray): The array in which to swap the first and second position
|
||||
pos_a (int | cint): The first position
|
||||
pos_b (int | cint): The second position
|
||||
cond (sintbit | sbit): The condition determining whether to swap
|
||||
"""
|
||||
if isinstance(array, MultiArray):
|
||||
temp = array[pos_b][:]
|
||||
array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:]))
|
||||
array[pos_a].assign(cond.if_else(temp, array[pos_a][:]))
|
||||
if isinstance(array, Array):
|
||||
temp = array[pos_b]
|
||||
array[pos_b] = cond.if_else(array[pos_a], array[pos_b])
|
||||
array[pos_a] = cond.if_else(temp, array[pos_a])
|
||||
|
||||
|
||||
T = TypeVar("T", sint, sbitint)
|
||||
B = TypeVar("B", sintbit, sbit)
|
||||
|
||||
@@ -85,7 +68,7 @@ class SqrtOram(Generic[T, B]):
|
||||
# the stash)
|
||||
t: cint
|
||||
|
||||
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
|
||||
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True) -> None:
|
||||
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
|
||||
|
||||
Args:
|
||||
@@ -94,6 +77,9 @@ class SqrtOram(Generic[T, B]):
|
||||
k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
|
||||
period (int): Leave at None, this parameter is used to recursively pass down the top-level period.
|
||||
"""
|
||||
global debug, allow_memory_allocation
|
||||
|
||||
# Correctly initialize the shuffle (memory) depending on the type of data
|
||||
if isinstance(data, MultiArray):
|
||||
self.shuffle = data
|
||||
self.n = len(data)
|
||||
@@ -107,9 +93,12 @@ class SqrtOram(Generic[T, B]):
|
||||
else:
|
||||
raise Exception("Incorrect format.")
|
||||
|
||||
self.value_type = value_type
|
||||
# Only sint is supported
|
||||
if value_type != sint and value_type != sbitint:
|
||||
raise Exception("The value_type must be either sint or sbitint")
|
||||
|
||||
# Set derived constants
|
||||
self.value_type = value_type
|
||||
self.bit_type: Type[B] = value_type.bit_type
|
||||
self.index_size = util.log2(self.n)
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
@@ -118,11 +107,11 @@ class SqrtOram(Generic[T, B]):
|
||||
if debug:
|
||||
lib.print_ln(
|
||||
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
|
||||
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shufflei = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
permutation = Array.create_from(self.shuffle_the_shuffle())
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
@@ -130,8 +119,21 @@ class SqrtOram(Generic[T, B]):
|
||||
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
|
||||
if debug and not period:
|
||||
lib.print_ln('Period set to %s', self.T)
|
||||
|
||||
# Here we allocate the memory for the permutation
|
||||
# Note that self.shuffle_the_shuffle mutates this field
|
||||
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
|
||||
self.permutation = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
# We allow the caller to postpone the initialization of the shuffle
|
||||
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
|
||||
# Note that if you do not initialize, the ORAM is insecure
|
||||
if initialize:
|
||||
self.shuffle_the_shuffle()
|
||||
else:
|
||||
print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')
|
||||
# Initialize position map (recursive oram)
|
||||
self.position_map = PositionMap.create(permutation, k + 1, self.T)
|
||||
self.position_map = PositionMap.create(self.permutation, k + 1, self.T)
|
||||
|
||||
# Initialize stash
|
||||
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
|
||||
@@ -140,19 +142,28 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
# Initialize temp variables needed during the computation
|
||||
self.found_ = self.bit_type.Array(size=self.T)
|
||||
self.j = MemValue(cint(0, size=1))
|
||||
|
||||
# To prevent the compiler from recompiling the same code over and over again, we should use @method_block
|
||||
# However, @method_block requires allocation (of return address), which is not allowed when not in the main thread
|
||||
# Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread
|
||||
self.shuffle_the_shuffle = lib.method_block(self.shuffle_the_shuffle) if allow_memory_allocation else self.shuffle_the_shuffle
|
||||
self.refresh = lib.method_block(self.refresh) if allow_memory_allocation else self.refresh
|
||||
|
||||
@lib.method_block
|
||||
def access(self, index: T, write: B, *value: T):
|
||||
global trace,n_parallel
|
||||
if trace:
|
||||
@lib.if_e(write.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
lib.print_ln(' Writing to secret index %s', index.reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
lib.print_ln(' Reading from secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length)
|
||||
value = self.value_type(value, size=self.entry_length).get_vector(
|
||||
0, size=self.entry_length)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@@ -171,8 +182,9 @@ class SqrtOram(Generic[T, B]):
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
@@ -200,11 +212,11 @@ class SqrtOram(Generic[T, B]):
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
lib.print_ln(' Found item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
lib.print_ln(' Did not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
@@ -215,7 +227,8 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
self.shuffle[physical_address] += write * \
|
||||
found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
@@ -225,10 +238,8 @@ class SqrtOram(Generic[T, B]):
|
||||
if trace:
|
||||
@lib.if_((write * found.bit_not()).reveal())
|
||||
def _():
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
|
||||
), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
@@ -237,8 +248,9 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
global trace, n_parallel
|
||||
if trace:
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
lib.print_ln(' Writing to secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value)
|
||||
index = MemValue(index)
|
||||
@@ -259,8 +271,9 @@ class SqrtOram(Generic[T, B]):
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
@@ -286,11 +299,11 @@ class SqrtOram(Generic[T, B]):
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
lib.print_ln(' Found item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
lib.print_ln(' Did not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
@@ -301,7 +314,8 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
self.shuffle[physical_address] += found.bit_not() * \
|
||||
(value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
@@ -311,9 +325,10 @@ class SqrtOram(Generic[T, B]):
|
||||
if trace:
|
||||
@lib.if_(found.bit_not().reveal())
|
||||
def _():
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
|
||||
lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
|
||||
), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
@@ -323,14 +338,20 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
@lib.method_block
|
||||
def read(self, index: T, *value: T):
|
||||
global debug, trace, n_parallel
|
||||
if trace:
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
lib.print_ln(' Reading from secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
if debug:
|
||||
lib.print_ln(' Refreshing SqrtORAM')
|
||||
lib.print_ln(' t=%s according to me', self.t)
|
||||
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
@@ -344,8 +365,9 @@ class SqrtOram(Generic[T, B]):
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
|
||||
self.bit_type(regint.inc(size, base=base) <
|
||||
self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
@@ -371,11 +393,11 @@ class SqrtOram(Generic[T, B]):
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
lib.print_ln(' Found item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
lib.print_ln(' Did not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
@@ -392,7 +414,7 @@ class SqrtOram(Generic[T, B]):
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
@@ -403,25 +425,36 @@ class SqrtOram(Generic[T, B]):
|
||||
__getitem__ = read
|
||||
__setitem__ = write
|
||||
|
||||
@lib.method_block
|
||||
def shuffle_the_shuffle(self):
|
||||
def shuffle_the_shuffle(self) -> None:
|
||||
"""Permute the memory using a newly generated permutation and return
|
||||
the permutation that would generate this particular shuffling.
|
||||
|
||||
This permutation is needed to know how to map logical addresses to
|
||||
physical addresses, and is used as such by the postition map."""
|
||||
|
||||
global trace
|
||||
# Random permutation on n elements
|
||||
random_shuffle = sint.get_secure_shuffle(self.n)
|
||||
if trace:
|
||||
lib.print_ln('\tGenerated shuffle')
|
||||
lib.print_ln(' Generated shuffle')
|
||||
# Apply the random permutation
|
||||
self.shuffle.secure_permute(random_shuffle)
|
||||
if trace:
|
||||
lib.print_ln('\tShuffled shuffle')
|
||||
lib.print_ln(' Shuffled shuffle')
|
||||
self.shufflei.secure_permute(random_shuffle)
|
||||
if trace:
|
||||
lib.print_ln('\tShuffled shuffle indexes')
|
||||
lib.print_ln(' Shuffled shuffle indexes')
|
||||
|
||||
if trace:
|
||||
# If shufflei does not contain exactly the indices [i for i in
|
||||
# range(self.n)], the underlying waksman network of
|
||||
# 'inverse_permutation' will hang.
|
||||
tmp_shuffli = Array.create_from(self.shufflei[:])
|
||||
@lib.if_(sum(lib.sort(tmp_shuffli)[:] == Array.create_from([cint(i) for i in range(self.n)])[:]).reveal() != self.n)
|
||||
def _():
|
||||
lib.print_ln(
|
||||
'Shufflei is corrupted! You have found a bug in the implementation :c\nThe computation will now hang...')
|
||||
|
||||
# Calculate the permutation that would have produced the newly produced
|
||||
# shuffle order. This can be calculated by regarding the logical
|
||||
# indexes (shufflei) as a permutation and calculating its inverse,
|
||||
@@ -429,45 +462,45 @@ class SqrtOram(Generic[T, B]):
|
||||
# this is not necessarily equal to the inverse of the above generated
|
||||
# random_shuffle, as the shuffle may already be out of order (e.g. when
|
||||
# refreshing).
|
||||
permutation = MemValue(self.shufflei[:].inverse_permutation())
|
||||
self.permutation.assign(self.shufflei[:].inverse_permutation())
|
||||
if trace:
|
||||
lib.print_ln('\tCalculated inverse permutation')
|
||||
return permutation
|
||||
lib.print_ln(' Calculated inverse permutation')
|
||||
|
||||
@lib.method_block
|
||||
def refresh(self):
|
||||
"""Refresh the ORAM by reinserting the stash back into the shuffle, and
|
||||
reshuffling the shuffle.
|
||||
|
||||
This must happen after T (period) accesses to the ORAM."""
|
||||
|
||||
if trace:
|
||||
lib.print_ln('Refreshing SqrtORAM')
|
||||
This must happen on the T'th (period) accesses to the ORAM."""
|
||||
|
||||
self.j.write(0)
|
||||
# Shuffle and emtpy the stash, and store elements back into shuffle
|
||||
j = MemValue(cint(0, size=1))
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
@lib.if_(self.shuffle_used[i])
|
||||
def _():
|
||||
nonlocal j
|
||||
self.shuffle[i] = self.stash[j]
|
||||
self.shufflei[i] = self.stashi[j]
|
||||
j += 1
|
||||
self.shuffle[i] = self.stash[self.j]
|
||||
self.shufflei[i] = self.stashi[self.j]
|
||||
self.j += 1
|
||||
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
self.shuffle_used.assign_all(0)
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
self.shuffle_used.assign_all(0)
|
||||
else:
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
self.shuffle_used[i] = cint(0)
|
||||
|
||||
# Reinitialize position map
|
||||
permutation = self.shuffle_the_shuffle()
|
||||
self.shuffle_the_shuffle()
|
||||
# Note that we skip here the step of "packing" the permutation.
|
||||
# Since the underlying memory of the position map is already aligned in
|
||||
# this packed structure, we can simply overwrite the memory while
|
||||
# maintaining the structure.
|
||||
self.position_map.reinitialize(*permutation)
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
|
||||
@lib.method_block
|
||||
def reinitialize(self, *data: T):
|
||||
@@ -478,7 +511,7 @@ class SqrtOram(Generic[T, B]):
|
||||
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
# Reset shuffle_used
|
||||
self.shuffle_used.assign_all(0)
|
||||
|
||||
# Note that the self.shuffle is actually a MultiArray
|
||||
@@ -486,8 +519,10 @@ class SqrtOram(Generic[T, B]):
|
||||
# assign_vector
|
||||
self.shuffle.assign_vector(self.value_type(
|
||||
data, size=self.n * self.entry_length))
|
||||
permutation = self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*permutation)
|
||||
# Note that this updates self.permutation (see constructor for explanation)
|
||||
self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
|
||||
|
||||
|
||||
class PositionMap(Generic[T, B]):
|
||||
@@ -508,8 +543,9 @@ class PositionMap(Generic[T, B]):
|
||||
@abstractmethod
|
||||
def get_position(self, logical_address: _secret, fake: B) -> Any:
|
||||
"""Retrieve the block at the given (secret) logical address."""
|
||||
global trace
|
||||
if trace:
|
||||
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth,
|
||||
lib.print_ln(' %s Scanning %s for logical address %s (fake=%s)', self.depth,
|
||||
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
@@ -529,6 +565,7 @@ class PositionMap(Generic[T, B]):
|
||||
a LinearPositionMap."""
|
||||
n = len(permutation)
|
||||
|
||||
global debug
|
||||
if n / PositionMap.PACK <= period:
|
||||
if debug:
|
||||
lib.print_ln(
|
||||
@@ -561,6 +598,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type,
|
||||
period=period, entry_length=pack, k=self.depth)
|
||||
|
||||
# Initialize random temp variables needed during the computation
|
||||
self.block_index_demux: Array = self.bit_type.Array(self.T)
|
||||
self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)
|
||||
|
||||
@lib.method_block
|
||||
def get_position(self, logical_address: T, fake: B) -> _clear:
|
||||
super().get_position(logical_address, fake)
|
||||
@@ -576,50 +617,42 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
|
||||
|
||||
# The resulting physical address
|
||||
p = MemValue(self.index_type(0))
|
||||
p = MemValue(self.index_type(-1))
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
|
||||
# First we try and retrieve the item from the stash at position stash[h][l]
|
||||
# Since h and l are secret, we do this by scanning the entire stash
|
||||
|
||||
# First we scan the stash for the block we need
|
||||
condition1 = self.bit_type.Array(self.T)
|
||||
self.block_index_demux.assign_all(0)
|
||||
|
||||
@lib.for_range_opt_multithread(8, self.T)
|
||||
@lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
|
||||
def _(i):
|
||||
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
|
||||
found = sum(condition1)
|
||||
# Once a block is found, we use condition2 to pick the correct item from that block
|
||||
condition2 = Array.create_from(
|
||||
regint.inc(pack) == l.expand_to_vector(pack))
|
||||
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
|
||||
condition3 = self.bit_type.Array(self.T * pack)
|
||||
self.block_index_demux[i] = (
|
||||
self.stashi[i] == h) & self.bit_type(i < self.t)
|
||||
# We can determine if the 'index' is in the stash by checking the
|
||||
# block_index_demux array
|
||||
found = sum(self.block_index_demux)
|
||||
# Once a block is found, we use the following condition to pick the correct item from that block
|
||||
demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)
|
||||
|
||||
@lib.for_range_opt_multithread(8, [self.T, pack])
|
||||
def _(i, j):
|
||||
condition3[i*pack + j] = condition1[i] & condition2[j]
|
||||
# Finally we use condition3 to conditionally write p
|
||||
|
||||
@lib.for_range(self.t)
|
||||
def _(i):
|
||||
@lib.for_range(pack)
|
||||
def _(j):
|
||||
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
|
||||
|
||||
if trace:
|
||||
@lib.if_(condition1[i].reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(
|
||||
), self.stash[i].reveal())
|
||||
|
||||
# Then we try and retrieve the item from the shuffle (the actual memory)
|
||||
# Finally we use the conditions to conditionally write p
|
||||
@lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
|
||||
def p_(i):
|
||||
# We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum
|
||||
# Therefore we include the check (i < self.t)
|
||||
return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i < self.t)
|
||||
p.write(p_())
|
||||
|
||||
global trace
|
||||
if trace:
|
||||
@lib.if_(found.reveal() == 0)
|
||||
def _():
|
||||
lib.print_ln('\t%s Position not in stash', self.depth)
|
||||
lib.print_ln(' %s Position not in stash', self.depth)
|
||||
|
||||
# Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle
|
||||
# Then we try and retrieve the item from the shuffle (the actual memory)
|
||||
# Depending on whether we found the item in the stash, we either
|
||||
# block 'h' in which 'index' resides, or a random block from the shuffle
|
||||
p_prime = self.position_map.get_position(h, found)
|
||||
self.shuffle_used[p_prime] = cbit(True)
|
||||
|
||||
@@ -629,12 +662,12 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 0)
|
||||
def _():
|
||||
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
|
||||
lib.print_ln(' %s Retrieved stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
|
||||
lib.print_ln(' %s Retrieved dummy stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
|
||||
# We add the retrieved block from the shuffle to the stash
|
||||
@@ -680,6 +713,13 @@ class LinearPositionMap(PositionMap):
|
||||
"""
|
||||
super().get_position(logical_address, fake)
|
||||
|
||||
global trace
|
||||
if trace:
|
||||
@lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal())
|
||||
def _():
|
||||
lib.runtime_error(
|
||||
'logical_address must lie between 0 and self.n - 1')
|
||||
|
||||
fake = MemValue(self.bit_type(fake))
|
||||
logical_address = MemValue(logical_address)
|
||||
|
||||
@@ -689,11 +729,12 @@ class LinearPositionMap(PositionMap):
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
self.physical_demux.assign_all(0)
|
||||
@lib.for_range_opt_multithread(8, self.n)
|
||||
|
||||
@lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
|
||||
def condition_i(i):
|
||||
self.physical_demux.assign((self.bit_type(fake).bit_not()
|
||||
& self.bit_type(logical_address == i)) | (fake
|
||||
& self.used[i].bit_not()), base=i)
|
||||
self.physical_demux[i] = \
|
||||
(self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
|
||||
| (fake & self.used[i].bit_not())
|
||||
|
||||
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
|
||||
# We only need once, so we pick the first one we find
|
||||
@@ -704,7 +745,7 @@ class LinearPositionMap(PositionMap):
|
||||
done |= self.physical_demux[i]
|
||||
|
||||
# Retrieve the value from the physical memory obliviously
|
||||
@lib.map_sum_opt(8, self.n, [self.value_type])
|
||||
@lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
|
||||
def calc_p(i):
|
||||
return self.physical[i] * self.physical_demux[i]
|
||||
p.write(calc_p())
|
||||
@@ -720,7 +761,13 @@ class LinearPositionMap(PositionMap):
|
||||
|
||||
return p.reveal()
|
||||
|
||||
@lib.method_block
|
||||
def reinitialize(self, *data: T):
|
||||
self.physical.assign_vector(data)
|
||||
self.used.assign_all(False)
|
||||
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
self.used.assign_all(False)
|
||||
else:
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
self.used[i] = self.bit_type(0)
|
||||
|
||||
Reference in New Issue
Block a user