Add allow_memory_allocation option to SqrtORAM

Also remove unused swap function in SqrtORAM
This commit is contained in:
Kevin Witlox
2022-08-11 14:41:20 +02:00
parent 2cd263dad0
commit fb5871a2f8

View File

@@ -1,32 +1,34 @@
from __future__ import annotations
from abc import abstractmethod
import math
from abc import abstractmethod
from typing import Any, Generic, Type, TypeVar
from Compiler import util
from Compiler import library as lib
from Compiler import util
from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.program import Program
from Compiler.types import (
Array,
MemValue,
MultiArray,
_clear,
_secret,
cint,
regint,
sint,
sintbit,
)
from oram import get_n_threads
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
regint, sint, sintbit)
from oram import demux_array, get_n_threads
program = Program.prog
debug = True
trace = True
n_threads = 8
# Adds messages on completion of heavy computation steps
debug = False
# Finer grained trace of steps that the ORAM performs
# + runtime error checks
# Warning: reveals information and makes the computation insecure
trace = False
n_threads = 16
n_parallel = 1
# Avoids any memory allocation
# This prevents some optimizations but allows for using the ORAMs outside of the main tape
allow_memory_allocation = True
def get_n_threads(n_loops):
if n_threads is None:
if n_loops > 2048:
@@ -37,25 +39,6 @@ def get_n_threads(n_loops):
return n_threads
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
"""Swap two positions in an Array if a condition is met.
Args:
array (Array | MultiArray): The array in which to swap the first and second position
pos_a (int | cint): The first position
pos_b (int | cint): The second position
cond (sintbit | sbit): The condition determining whether to swap
"""
if isinstance(array, MultiArray):
temp = array[pos_b][:]
array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:]))
array[pos_a].assign(cond.if_else(temp, array[pos_a][:]))
if isinstance(array, Array):
temp = array[pos_b]
array[pos_b] = cond.if_else(array[pos_a], array[pos_b])
array[pos_a] = cond.if_else(temp, array[pos_a])
T = TypeVar("T", sint, sbitint)
B = TypeVar("B", sintbit, sbit)
@@ -85,7 +68,7 @@ class SqrtOram(Generic[T, B]):
# the stash)
t: cint
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True) -> None:
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
Args:
@@ -94,6 +77,9 @@ class SqrtOram(Generic[T, B]):
k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
period (int): Leave at None, this parameter is used to recursively pass down the top-level period.
"""
global debug, allow_memory_allocation
# Correctly initialize the shuffle (memory) depending on the type of data
if isinstance(data, MultiArray):
self.shuffle = data
self.n = len(data)
@@ -107,9 +93,12 @@ class SqrtOram(Generic[T, B]):
else:
raise Exception("Incorrect format.")
self.value_type = value_type
# Only sint is supported
if value_type != sint and value_type != sbitint:
raise Exception("The value_type must be either sint or sbitint")
# Set derived constants
self.value_type = value_type
self.bit_type: Type[B] = value_type.bit_type
self.index_size = util.log2(self.n)
self.index_type = value_type.get_type(self.index_size)
@@ -118,11 +107,11 @@ class SqrtOram(Generic[T, B]):
if debug:
lib.print_ln(
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shufflei = Array.create_from(
[self.index_type(i) for i in range(self.n)])
permutation = Array.create_from(self.shuffle_the_shuffle())
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
@@ -130,8 +119,21 @@ class SqrtOram(Generic[T, B]):
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
if debug and not period:
lib.print_ln('Period set to %s', self.T)
# Here we allocate the memory for the permutation
# Note that self.shuffle_the_shuffle mutates this field
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
self.permutation = Array.create_from(
[self.index_type(i) for i in range(self.n)])
# We allow the caller to postpone the initialization of the shuffle
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
# Note that if you do not initialize, the ORAM is insecure
if initialize:
self.shuffle_the_shuffle()
else:
print('You are opting out of default initialization for SqrtORAM. Be sure to call refresh before using the SqrtORAM, otherwise the ORAM is not secure.')
# Initialize position map (recursive oram)
self.position_map = PositionMap.create(permutation, k + 1, self.T)
self.position_map = PositionMap.create(self.permutation, k + 1, self.T)
# Initialize stash
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
@@ -140,19 +142,28 @@ class SqrtOram(Generic[T, B]):
# Initialize temp variables needed during the computation
self.found_ = self.bit_type.Array(size=self.T)
self.j = MemValue(cint(0, size=1))
# To prevent the compiler from recompiling the same code over and over again, we should use @method_block
# However, @method_block requires allocation (of return address), which is not allowed when not in the main thread
# Therefore, we only conditionally wrap the methods in a @method_block if we are guaranteed to be running in the main thread
self.shuffle_the_shuffle = lib.method_block(self.shuffle_the_shuffle) if allow_memory_allocation else self.shuffle_the_shuffle
self.refresh = lib.method_block(self.refresh) if allow_memory_allocation else self.refresh
@lib.method_block
def access(self, index: T, write: B, *value: T):
global trace,n_parallel
if trace:
@lib.if_e(write.reveal() == 1)
def _():
lib.print_ln('Writing to secret index %s', index.reveal())
lib.print_ln(' Writing to secret index %s', index.reveal())
@lib.else_
def __():
lib.print_ln('Reading from secret index %s', index.reveal())
lib.print_ln(' Reading from secret index %s', index.reveal())
value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length)
value = self.value_type(value, size=self.entry_length).get_vector(
0, size=self.entry_length)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@@ -171,8 +182,9 @@ class SqrtOram(Generic[T, B]):
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
@@ -200,11 +212,11 @@ class SqrtOram(Generic[T, B]):
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
lib.print_ln(' Found item in stash')
@lib.else_
def __():
lib.print_ln('\tDid not find item in stash')
lib.print_ln(' Did not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
@@ -215,7 +227,8 @@ class SqrtOram(Generic[T, B]):
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:])
self.shuffle[physical_address] += write * \
found.bit_not() * (value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
@@ -225,10 +238,8 @@ class SqrtOram(Generic[T, B]):
if trace:
@lib.if_((write * found.bit_not()).reveal())
def _():
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
), self.shuffle[physical_address].reveal(), physical_address)
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
@@ -237,8 +248,9 @@ class SqrtOram(Generic[T, B]):
@lib.method_block
def write(self, index: T, *value: T):
global trace, n_parallel
if trace:
lib.print_ln('Writing to secret index %s', index.reveal())
lib.print_ln(' Writing to secret index %s', index.reveal())
value = self.value_type(value)
index = MemValue(index)
@@ -259,8 +271,9 @@ class SqrtOram(Generic[T, B]):
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
@@ -286,11 +299,11 @@ class SqrtOram(Generic[T, B]):
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
lib.print_ln(' Found item in stash')
@lib.else_
def __():
lib.print_ln('\tDid not find item in stash')
lib.print_ln(' Did not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
@@ -301,7 +314,8 @@ class SqrtOram(Generic[T, B]):
# If the item was not found in the stash
# ...we update the item in the shuffle
self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:])
self.shuffle[physical_address] += found.bit_not() * \
(value - self.shuffle[physical_address][:])
# ...and the item retrieved from the shuffle is our result
result += self.shuffle[physical_address] * found.bit_not()
# We append the newly retrieved item to the stash
@@ -311,9 +325,10 @@ class SqrtOram(Generic[T, B]):
if trace:
@lib.if_(found.bit_not().reveal())
def _():
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln(' Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(
), self.shuffle[physical_address].reveal(), physical_address)
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
@@ -323,14 +338,20 @@ class SqrtOram(Generic[T, B]):
@lib.method_block
def read(self, index: T, *value: T):
global debug, trace, n_parallel
if trace:
lib.print_ln('Reading from secret index %s', index.reveal())
lib.print_ln(' Reading from secret index %s', index.reveal())
value = self.value_type(value)
index = MemValue(index)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
def _():
if debug:
lib.print_ln(' Refreshing SqrtORAM')
lib.print_ln(' t=%s according to me', self.t)
self.refresh()
found: B = MemValue(self.bit_type(False))
@@ -344,8 +365,9 @@ class SqrtOram(Generic[T, B]):
@lib.multithread(get_n_threads(self.T), self.T)
def _(base, size):
self.found_.assign_vector(
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) &
self.bit_type(regint.inc(size, base=base) <
self.t.expand_to_vector(size)),
base=base)
# To determine whether the item is found in the stash, we simply
@@ -371,11 +393,11 @@ class SqrtOram(Generic[T, B]):
if trace:
@lib.if_e(found.reveal() == 1)
def _():
lib.print_ln('\tFound item in stash')
lib.print_ln(' Found item in stash')
@lib.else_
def __():
lib.print_ln('\tDid not find item in stash')
lib.print_ln(' Did not find item in stash')
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
@@ -392,7 +414,7 @@ class SqrtOram(Generic[T, B]):
self.stashi[self.t] = self.shufflei[physical_address]
if trace:
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
lib.print_ln(' Appended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
# Increase the "time" (i.e. access count in current period)
@@ -403,25 +425,36 @@ class SqrtOram(Generic[T, B]):
__getitem__ = read
__setitem__ = write
@lib.method_block
def shuffle_the_shuffle(self):
def shuffle_the_shuffle(self) -> None:
"""Permute the memory using a newly generated permutation and return
the permutation that would generate this particular shuffling.
This permutation is needed to know how to map logical addresses to
physical addresses, and is used as such by the postition map."""
global trace
# Random permutation on n elements
random_shuffle = sint.get_secure_shuffle(self.n)
if trace:
lib.print_ln('\tGenerated shuffle')
lib.print_ln(' Generated shuffle')
# Apply the random permutation
self.shuffle.secure_permute(random_shuffle)
if trace:
lib.print_ln('\tShuffled shuffle')
lib.print_ln(' Shuffled shuffle')
self.shufflei.secure_permute(random_shuffle)
if trace:
lib.print_ln('\tShuffled shuffle indexes')
lib.print_ln(' Shuffled shuffle indexes')
if trace:
# If shufflei does not contain exactly the indices [i for i in
# range(self.n)], the underlying waksman network of
# 'inverse_permutation' will hang.
tmp_shuffli = Array.create_from(self.shufflei[:])
@lib.if_(sum(lib.sort(tmp_shuffli)[:] == Array.create_from([cint(i) for i in range(self.n)])[:]).reveal() != self.n)
def _():
lib.print_ln(
'Shufflei is corrupted! You have found a bug in the implementation :c\nThe computation will now hang...')
# Calculate the permutation that would have produced the newly produced
# shuffle order. This can be calculated by regarding the logical
# indexes (shufflei) as a permutation and calculating its inverse,
@@ -429,45 +462,45 @@ class SqrtOram(Generic[T, B]):
# this is not necessarily equal to the inverse of the above generated
# random_shuffle, as the shuffle may already be out of order (e.g. when
# refreshing).
permutation = MemValue(self.shufflei[:].inverse_permutation())
self.permutation.assign(self.shufflei[:].inverse_permutation())
if trace:
lib.print_ln('\tCalculated inverse permutation')
return permutation
lib.print_ln(' Calculated inverse permutation')
@lib.method_block
def refresh(self):
"""Refresh the ORAM by reinserting the stash back into the shuffle, and
reshuffling the shuffle.
This must happen after T (period) accesses to the ORAM."""
if trace:
lib.print_ln('Refreshing SqrtORAM')
This must happen on the T'th (period) accesses to the ORAM."""
self.j.write(0)
# Shuffle and emtpy the stash, and store elements back into shuffle
j = MemValue(cint(0, size=1))
@lib.for_range_opt(self.n)
def _(i):
@lib.if_(self.shuffle_used[i])
def _():
nonlocal j
self.shuffle[i] = self.stash[j]
self.shufflei[i] = self.stashi[j]
j += 1
self.shuffle[i] = self.stash[self.j]
self.shufflei[i] = self.stashi[self.j]
self.j += 1
# Reset the clock
self.t.write(0)
# Reset shuffle_used
self.shuffle_used.assign_all(0)
global allow_memory_allocation
if allow_memory_allocation:
self.shuffle_used.assign_all(0)
else:
@lib.for_range_opt(self.n)
def _(i):
self.shuffle_used[i] = cint(0)
# Reinitialize position map
permutation = self.shuffle_the_shuffle()
self.shuffle_the_shuffle()
# Note that we skip here the step of "packing" the permutation.
# Since the underlying memory of the position map is already aligned in
# this packed structure, we can simply overwrite the memory while
# maintaining the structure.
self.position_map.reinitialize(*permutation)
self.position_map.reinitialize(*self.permutation)
@lib.method_block
def reinitialize(self, *data: T):
@@ -478,7 +511,7 @@ class SqrtOram(Generic[T, B]):
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
# Reset the clock
self.t.write(0)
# Reset shuffle_used
# Reset shuffle_used
self.shuffle_used.assign_all(0)
# Note that the self.shuffle is actually a MultiArray
@@ -486,8 +519,10 @@ class SqrtOram(Generic[T, B]):
# assign_vector
self.shuffle.assign_vector(self.value_type(
data, size=self.n * self.entry_length))
permutation = self.shuffle_the_shuffle()
self.position_map.reinitialize(*permutation)
# Note that this updates self.permutation (see constructor for explanation)
self.shuffle_the_shuffle()
self.position_map.reinitialize(*self.permutation)
class PositionMap(Generic[T, B]):
@@ -508,8 +543,9 @@ class PositionMap(Generic[T, B]):
@abstractmethod
def get_position(self, logical_address: _secret, fake: B) -> Any:
"""Retrieve the block at the given (secret) logical address."""
global trace
if trace:
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth,
lib.print_ln(' %s Scanning %s for logical address %s (fake=%s)', self.depth,
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
def reinitialize(self, *permutation: T):
@@ -529,6 +565,7 @@ class PositionMap(Generic[T, B]):
a LinearPositionMap."""
n = len(permutation)
global debug
if n / PositionMap.PACK <= period:
if debug:
lib.print_ln(
@@ -561,6 +598,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
SqrtOram.__init__(self, packed_structure, value_type=value_type,
period=period, entry_length=pack, k=self.depth)
# Initialize random temp variables needed during the computation
self.block_index_demux: Array = self.bit_type.Array(self.T)
self.element_index_demux: Array = self.bit_type.Array(PositionMap.PACK)
@lib.method_block
def get_position(self, logical_address: T, fake: B) -> _clear:
super().get_position(logical_address, fake)
@@ -576,50 +617,42 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
# The resulting physical address
p = MemValue(self.index_type(0))
p = MemValue(self.index_type(-1))
found: B = MemValue(self.bit_type(False))
# First we try and retrieve the item from the stash at position stash[h][l]
# Since h and l are secret, we do this by scanning the entire stash
# First we scan the stash for the block we need
condition1 = self.bit_type.Array(self.T)
self.block_index_demux.assign_all(0)
@lib.for_range_opt_multithread(8, self.T)
@lib.for_range_opt_multithread(get_n_threads(self.T), self.T)
def _(i):
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
found = sum(condition1)
# Once a block is found, we use condition2 to pick the correct item from that block
condition2 = Array.create_from(
regint.inc(pack) == l.expand_to_vector(pack))
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
condition3 = self.bit_type.Array(self.T * pack)
self.block_index_demux[i] = (
self.stashi[i] == h) & self.bit_type(i < self.t)
# We can determine if the 'index' is in the stash by checking the
# block_index_demux array
found = sum(self.block_index_demux)
# Once a block is found, we use the following condition to pick the correct item from that block
demux_array(l.bit_decompose(PositionMap.PACK_LOG), self.element_index_demux)
@lib.for_range_opt_multithread(8, [self.T, pack])
def _(i, j):
condition3[i*pack + j] = condition1[i] & condition2[j]
# Finally we use condition3 to conditionally write p
@lib.for_range(self.t)
def _(i):
@lib.for_range(pack)
def _(j):
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
if trace:
@lib.if_(condition1[i].reveal() == 1)
def _():
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(
), self.stash[i].reveal())
# Then we try and retrieve the item from the shuffle (the actual memory)
# Finally we use the conditions to conditionally write p
@lib.map_sum(get_n_threads(self.T * pack), n_parallel, self.T * pack, 1, [self.value_type])
def p_(i):
# We should loop from 0 through self.t, but runtime loop lengths are not supported by map_sum
# Therefore we include the check (i < self.t)
return self.stash[i // pack][i % pack] * self.block_index_demux[i // pack] * self.element_index_demux[i % pack] * (i < self.t)
p.write(p_())
global trace
if trace:
@lib.if_(found.reveal() == 0)
def _():
lib.print_ln('\t%s Position not in stash', self.depth)
lib.print_ln(' %s Position not in stash', self.depth)
# Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle
# Then we try and retrieve the item from the shuffle (the actual memory)
# Depending on whether we found the item in the stash, we either
# block 'h' in which 'index' resides, or a random block from the shuffle
p_prime = self.position_map.get_position(h, found)
self.shuffle_used[p_prime] = cbit(True)
@@ -629,12 +662,12 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
if trace:
@lib.if_e(found.reveal() == 0)
def _():
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
lib.print_ln(' %s Retrieved stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
@lib.else_
def __():
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
lib.print_ln(' %s Retrieved dummy stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
# We add the retrieved block from the shuffle to the stash
@@ -680,6 +713,13 @@ class LinearPositionMap(PositionMap):
"""
super().get_position(logical_address, fake)
global trace
if trace:
@lib.if_(((logical_address < 0) * (logical_address >= self.n)).reveal())
def _():
lib.runtime_error(
'logical_address must lie between 0 and self.n - 1')
fake = MemValue(self.bit_type(fake))
logical_address = MemValue(logical_address)
@@ -689,11 +729,12 @@ class LinearPositionMap(PositionMap):
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
self.physical_demux.assign_all(0)
@lib.for_range_opt_multithread(8, self.n)
@lib.for_range_opt_multithread(get_n_threads(self.n), self.n)
def condition_i(i):
self.physical_demux.assign((self.bit_type(fake).bit_not()
& self.bit_type(logical_address == i)) | (fake
& self.used[i].bit_not()), base=i)
self.physical_demux[i] = \
(self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) \
| (fake & self.used[i].bit_not())
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
# We only need once, so we pick the first one we find
@@ -704,7 +745,7 @@ class LinearPositionMap(PositionMap):
done |= self.physical_demux[i]
# Retrieve the value from the physical memory obliviously
@lib.map_sum_opt(8, self.n, [self.value_type])
@lib.map_sum_opt(get_n_threads(self.n), self.n, [self.value_type])
def calc_p(i):
return self.physical[i] * self.physical_demux[i]
p.write(calc_p())
@@ -720,7 +761,13 @@ class LinearPositionMap(PositionMap):
return p.reveal()
@lib.method_block
def reinitialize(self, *data: T):
self.physical.assign_vector(data)
self.used.assign_all(False)
global allow_memory_allocation
if allow_memory_allocation:
self.used.assign_all(False)
else:
@lib.for_range_opt(self.n)
def _(i):
self.used[i] = self.bit_type(0)