Improve multithreading and remove non-multithreaded code

This commit is contained in:
Kevin Witlox
2022-08-01 17:58:45 +02:00
parent 33299e78a5
commit 2cd263dad0

View File

@@ -3,10 +3,10 @@ from abc import abstractmethod
import math
from typing import Any, Generic, Type, TypeVar
from Compiler.program import Program
from Compiler import util
from Compiler import library as lib
from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.program import Program
from Compiler.types import (
Array,
MemValue,
@@ -14,16 +14,28 @@ from Compiler.types import (
_clear,
_secret,
cint,
regint,
sint,
sintbit,
regint
)
from oram import get_n_threads
program = Program.prog
debug = False
debug = True
trace = True
n_threads = 8
multithreading = True
n_parallel = 1
def get_n_threads(n_loops):
if n_threads is None:
if n_loops > 2048:
return 8
else:
return None
else:
return n_threads
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
"""Swap two positions in an Array if a condition is met.
@@ -43,9 +55,11 @@ def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond:
array[pos_b] = cond.if_else(array[pos_a], array[pos_b])
array[pos_a] = cond.if_else(temp, array[pos_a])
T = TypeVar("T", sint, sbitint)
B = TypeVar("B", sintbit, sbit)
class SqrtOram(Generic[T, B]):
# TODO: Preferably this is an Array of vectors, but this is currently not supported
# One should regard these structures as Arrays where an entry may hold more
@@ -75,7 +89,7 @@ class SqrtOram(Generic[T, B]):
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
Args:
data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array).
data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
value_type (sint): The secret type to use, defaults to sint.
k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
period (int): Leave at None, this parameter is used to recursively pass down the top-level period.
@@ -87,7 +101,8 @@ class SqrtOram(Generic[T, B]):
self.n = math.ceil(len(data) // entry_length)
if (len(data) % entry_length != 0):
raise Exception('Data incorrectly padded.')
self.shuffle = MultiArray((self.n, entry_length), value_type=value_type)
self.shuffle = MultiArray(
(self.n, entry_length), value_type=value_type)
self.shuffle.assign_part_vector(data.get_vector())
else:
raise Exception("Incorrect format.")
@@ -96,19 +111,23 @@ class SqrtOram(Generic[T, B]):
if value_type != sint and value_type != sbitint:
raise Exception("The value_type must be either sint or sbitint")
self.bit_type: Type[B] = value_type.bit_type
self.index_type = value_type.get_type(util.log2(self.n))
self.index_size = util.log2(self.n)
self.index_type = value_type.get_type(self.index_size)
self.entry_length = entry_length
if debug:
lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k)
lib.print_ln(
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)])
self.shufflei = Array.create_from(
[self.index_type(i) for i in range(self.n)])
permutation = Array.create_from(self.shuffle_the_shuffle())
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
self.T = int(math.ceil(
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
if debug and not period:
lib.print_ln('Period set to %s', self.T)
# Initialize position map (recursive oram)
@@ -119,28 +138,108 @@ class SqrtOram(Generic[T, B]):
self.stashi = Array(self.T, value_type=value_type)
self.t = MemValue(cint(0))
@lib.method_block
def read(self, index: T):
value = self.value_type(0, size=self.entry_length)
return self.access(index, self.bit_type(False), *value)
@lib.method_block
def write(self, index: T, *value: T):
lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length")
self.access(index, self.bit_type(True), *value)
__getitem__ = read
__setitem__ = write
# Initialize temp variables needed during the computation
self.found_ = self.bit_type.Array(size=self.T)
@lib.method_block
def access(self, index: T, write: B, *value: T):
if debug:
if trace:
@lib.if_e(write.reveal() == 1)
def _():
lib.print_ln('Writing to secret index %s', index.reveal())
@lib.else_
def __():
lib.print_ln('Reading from secret index %s', index.reveal())
value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# First we scan the stash for the item
self.found_.assign_all(0)
# This will result in a bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check wheterh the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
# Store the stash item into the result if found
# If the item is not in the stash, the result will simple remain 0
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
# If we are writing, we need to add the value
self.stash[i] += write * access_here * (value - entry)
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
@lib.else_
def __():
lib.print_ln('\tDid not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
@lib.if_((write * found.bit_not()).reveal())
def _():
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
return result
@lib.method_block
def write(self, index: T, *value: T):
if trace:
lib.print_ln('Writing to secret index %s', index.reveal())
value = self.value_type(value)
index = MemValue(index)
@@ -150,87 +249,159 @@ class SqrtOram(Generic[T, B]):
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# Scan through the stash
@lib.if_(self.t > 0)
def _():
nonlocal found
found |= index == self.stashi[0]
# We ensure that if the item is found in stash, it ends up in the first
# position (more importantly, a fixed position) of the stash
# This allows us to keep track of it in an oblivious manner
if multithreading:
found_ = self.bit_type.Array(size=self.T)
@lib.multithread(1, self.T)
def _(base, size):
found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size))
& self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
base=base)
@lib.for_range_opt(self.t - 1)
def _(i):
swap(self.stash, 0, i + 1, found_[i+1])
swap(self.stashi, 0, i + 1, found_[i+1])
found.write(sum(found_))
else:
@lib.for_range_opt(self.t - 1)
def _(i):
nonlocal found
found_: B = index == self.stashi[i + 1]
swap(self.stash, 0, i + 1, found_)
swap(self.stashi, 0, i + 1, found_)
found |= found_
# If the item was not in the stash, we move the unknown and unimportant
# stash[0] out of the way (to the end of the stash)
swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
# First we scan the stash for the item
self.found_.assign_all(0)
if debug:
# This will result in an bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check wheterh the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
# We update the stash value
self.stash[i] += access_here * (value - entry)
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
@lib.else_
def __():
lib.print_ln('\tItem not in stash')
lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
lib.print_ln('\tDid not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was in the stash (thus currently residing in stash[0]),
# we place the random item retrieved from the shuffle at the end of the stash
self.stash[self.t].assign(found.if_else(
self.shuffle[physical_address][:],
self.stash[self.t][:]))
self.stashi[self.t] = found.if_else(
self.shufflei[physical_address],
self.stashi[self.t])
# If the item was not found in the stash, we place the item retrieved
# from the shuffle (the item we are actually looking for) in stash[0]
self.stash[0].assign(found.bit_not().if_else(
self.shuffle[physical_address][:],
self.stash[0][:]))
self.stashi[0] = found.bit_not().if_else(
self.shufflei[physical_address],
self.stashi[0])
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if debug:
@lib.if_e(found.reveal() == 1)
if trace:
@lib.if_(found.bit_not().reveal())
def _():
lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal())
@lib.else_
def __():
lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal())
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
self.stash[0].assign(write.if_else(value, self.stash[0][:]))
value=write.bit_not().if_else(self.stash[0][:], value)
return value
return result
@lib.method_block
def read(self, index: T, *value: T):
if trace:
lib.print_ln('Reading from secret index %s', index.reveal())
value = self.value_type(value)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
self.refresh()
found: B = MemValue(self.bit_type(False))
result: T = MemValue(self.value_type(0, size=self.entry_length))
# First we scan the stash for the item
self.found_.assign_all(0)
# This will result in a bit array with at most one True,
# indicating where in the stash 'index' is found
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
# check wheterh the demuxed array contains a True
# TODO: What if the index=0?
found.write(sum(self.found_))
# Store the stash item into the result if found
# If the item is not in the stash, the result will simple remain 0
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
self.entry_length, [self.value_type] * self.entry_length)
def stash_item(i):
entry = self.stash[i][:]
access_here = self.found_[i]
# This is a bit unfortunate
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
# Therefore we loop till self.T (the max value of self.t)
# is_in_time = i < self.t
return (entry * access_here)[:]
result += self.value_type(stash_item(), size=self.entry_length)
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
@lib.else_
def __():
lib.print_ln('\tDid not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
physical_address = self.position_map.get_position(index, found)
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
# with its equivalent on the stash once the period is up.
self.shuffle_used[physical_address] = cbit(True)
# If the item was not found in the stash
# the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
self.stash[self.t].assign(self.shuffle[physical_address][:])
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
return result
__getitem__ = read
__setitem__ = write
@lib.method_block
def shuffle_the_shuffle(self):
@@ -242,12 +413,15 @@ class SqrtOram(Generic[T, B]):
# Random permutation on n elements
random_shuffle = sint.get_secure_shuffle(self.n)
if debug: lib.print_ln('\tGenerated shuffle')
if trace:
lib.print_ln('\tGenerated shuffle')
# Apply the random permutation
self.shuffle.secure_permute(random_shuffle)
if debug: lib.print_ln('\tShuffled shuffle')
if trace:
lib.print_ln('\tShuffled shuffle')
self.shufflei.secure_permute(random_shuffle)
if debug: lib.print_ln('\tShuffled shuffle indexes')
if trace:
lib.print_ln('\tShuffled shuffle indexes')
# Calculate the permutation that would have produced the newly produced
# shuffle order. This can be calculated by regarding the logical
# indexes (shufflei) as a permutation and calculating its inverse,
@@ -256,7 +430,8 @@ class SqrtOram(Generic[T, B]):
# random_shuffle, as the shuffle may already be out of order (e.g. when
# refreshing).
permutation = MemValue(self.shufflei[:].inverse_permutation())
if debug: lib.print_ln('\tCalculated inverse permutation')
if trace:
lib.print_ln('\tCalculated inverse permutation')
return permutation
@lib.method_block
@@ -266,10 +441,12 @@ class SqrtOram(Generic[T, B]):
This must happen after T (period) accesses to the ORAM."""
if debug: lib.print_ln('Refreshing SqrtORAM')
if trace:
lib.print_ln('Refreshing SqrtORAM')
# Shuffle and emtpy the stash, and store elements back into shuffle
j = MemValue(cint(0,size=1))
j = MemValue(cint(0, size=1))
@lib.for_range_opt(self.n)
def _(i):
@lib.if_(self.shuffle_used[i])
@@ -301,13 +478,14 @@ class SqrtOram(Generic[T, B]):
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
# Reset the clock
self.t.write(0)
# Reset shuffle_used
# Reset shuffle_used
self.shuffle_used.assign_all(0)
# Note that the self.shuffle is actually a MultiArray
# This structure is preserved while overwriting the values using
# assign_vector
self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length))
self.shuffle.assign_vector(self.value_type(
data, size=self.n * self.entry_length))
permutation = self.shuffle_the_shuffle()
self.position_map.reinitialize(*permutation)
@@ -316,13 +494,13 @@ class PositionMap(Generic[T, B]):
PACK_LOG: int = 3
PACK: int = 1 << PACK_LOG
n: int # n in the paper
depth: int # k in the paper
n: int # n in the paper
depth: int # k in the paper
value_type: Type[T]
def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None:
def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
self.n = n
self.depth=MemValue(cint(k))
self.depth = MemValue(cint(k))
self.value_type = value_type
self.bit_type = value_type.bit_type
self.index_type = self.value_type.get_type(util.log2(n))
@@ -330,8 +508,9 @@ class PositionMap(Generic[T, B]):
@abstractmethod
def get_position(self, logical_address: _secret, fake: B) -> Any:
"""Retrieve the block at the given (secret) logical address."""
if debug:
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
if trace:
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth,
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
def reinitialize(self, *permutation: T):
"""Reinitialize this PositionMap.
@@ -352,11 +531,13 @@ class PositionMap(Generic[T, B]):
if n / PositionMap.PACK <= period:
if debug:
lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n)
lib.print_ln(
'Initializing LinearPositionMap at depth %s of size %s', k, n)
res = LinearPositionMap(permutation, value_type, k=k)
else:
if debug:
lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n)
lib.print_ln(
'Initializing RecursivePositionMap at depth %s of size %s', k, n)
res = RecursivePositionMap(permutation, period, value_type, k=k)
return res
@@ -364,7 +545,7 @@ class PositionMap(Generic[T, B]):
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None:
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None:
PositionMap.__init__(self, len(permutation), k=k)
pack = PositionMap.PACK
@@ -377,7 +558,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
permutation[i*pack:(i+1)*pack])
# TODO: Should this be n or packed_size?
SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth)
SqrtOram.__init__(self, packed_structure, value_type=value_type,
period=period, entry_length=pack, k=self.depth)
@lib.method_block
def get_position(self, logical_address: T, fake: B) -> _clear:
@@ -389,7 +571,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
# The item at logical_address
# will be in block with index h (block.<h>)
# at position l in block.data (block.data<l>)
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length)))
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
logical_address).right_shift(pack_log, program.bit_length)))
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
# The resulting physical address
@@ -401,32 +584,37 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
# First we scan the stash for the block we need
condition1 = self.bit_type.Array(self.T)
@lib.for_range_opt_multithread(8, self.T)
def _(i):
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
found = sum(condition1)
# Once a block is found, we use condition2 to pick the correct item from that block
condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack))
condition2 = Array.create_from(
regint.inc(pack) == l.expand_to_vector(pack))
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
condition3 = self.bit_type.Array(self.T * pack)
@lib.for_range_opt_multithread(8, [self.T, pack])
def _(i, j):
condition3[i*pack + j] = condition1[i] & condition2[j]
# Finally we use condition3 to conditionally write p
@lib.for_range(self.t)
def _(i):
@lib.for_range(pack)
def _(j):
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
if debug:
if trace:
@lib.if_(condition1[i].reveal() == 1)
def _():
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal())
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(
), self.stash[i].reveal())
# Then we try and retrieve the item from the shuffle (the actual memory)
if debug:
if trace:
@lib.if_(found.reveal() == 0)
def _():
lib.print_ln('\t%s Position not in stash', self.depth)
@@ -438,15 +626,16 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
# The block retrieved from the shuffle
block_p_prime: Array = self.shuffle[p_prime]
if debug:
if trace:
@lib.if_e(found.reveal() == 0)
def _():
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
@lib.else_
def __():
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
# We add the retrieved block from the shuffle to the stash
self.stash[self.t].assign(block_p_prime[:])
@@ -458,7 +647,9 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
# Retrieve l'th item from block
# l is secret, so we must use linear scan
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack))
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
pack)) & condition.expand_to_vector(pack))
@lib.for_range_opt(pack)
def _(i):
p.write((hit[i]).if_else(block_p_prime[i], p))
@@ -469,67 +660,63 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
def reinitialize(self, *permutation: T):
SqrtOram.reinitialize(self, *permutation)
class LinearPositionMap(PositionMap):
physical: Array
used: Array
def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None:
def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
PositionMap.__init__(self, len(data), value_type, k=k)
self.physical = data
self.used = self.bit_type.Array(self.n)
# Initialize random temp variables needed during the computation
self.physical_demux: Array = self.bit_type.Array(self.n)
@lib.method_block
def get_position(self, logical_address: T, fake: B) -> _clear:
"""
This method corresponds to GetPosBase in the paper.
"""
super().get_position(logical_address, fake)
fake = MemValue(self.bit_type(fake))
logical_address = MemValue(logical_address)
p: MemValue = MemValue(self.index_type(-1))
done: B = self.bit_type(False)
if multithreading:
conditions:Array = self.bit_type.Array(self.n)
conditions.assign_all(0)
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
self.physical_demux.assign_all(0)
@lib.for_range_opt_multithread(8, self.n)
def condition_i(i):
self.physical_demux.assign((self.bit_type(fake).bit_not()
& self.bit_type(logical_address == i)) | (fake
& self.used[i].bit_not()), base=i)
@lib.for_range_opt_multithread(8, self.n)
def condition_i(i):
conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i)
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
# We only need once, so we pick the first one we find
@lib.for_range_opt(self.n)
def _(i):
nonlocal done
self.physical_demux[i] &= done.bit_not()
done |= self.physical_demux[i]
@lib.for_range_opt(self.n)
def _(i):
nonlocal done
conditions[i] &= done.bit_not()
done |= conditions[i]
@lib.map_sum_opt(8, self.n, [self.value_type])
def calc_p(i):
return self.physical[i] * conditions[i]
p.write(calc_p())
# Retrieve the value from the physical memory obliviously
@lib.map_sum_opt(8, self.n, [self.value_type])
def calc_p(i):
return self.physical[i] * self.physical_demux[i]
p.write(calc_p())
self.used.assign(self.used[:] | conditions[:])
else:
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
linear_scan = self.bit_type.Array(self.n)
@lib.for_range_opt(self.n)
def _(i):
linear_scan[i] = logical_address == i
# Update self.used
self.used.assign(self.used[:] | self.physical_demux[:])
@lib.for_range_opt(self.n)
def __(j):
nonlocal done, fake
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
p.write(condition.if_else(self.physical[j], p))
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
done = self.bit_type(condition.if_else(self.bit_type(True), done))
if debug:
if trace:
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))
def _():
lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
lib.runtime_error(
'%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
return p.reveal()