mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
Improve multithreading and remove non-multithreaded code
This commit is contained in:
@@ -3,10 +3,10 @@ from abc import abstractmethod
|
||||
import math
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
from Compiler.program import Program
|
||||
from Compiler import util
|
||||
from Compiler import library as lib
|
||||
from Compiler.GC.types import cbit, sbit, sbitint, sbits
|
||||
from Compiler.program import Program
|
||||
from Compiler.types import (
|
||||
Array,
|
||||
MemValue,
|
||||
@@ -14,16 +14,28 @@ from Compiler.types import (
|
||||
_clear,
|
||||
_secret,
|
||||
cint,
|
||||
regint,
|
||||
sint,
|
||||
sintbit,
|
||||
regint
|
||||
)
|
||||
from oram import get_n_threads
|
||||
|
||||
program = Program.prog
|
||||
|
||||
debug = False
|
||||
debug = True
|
||||
trace = True
|
||||
n_threads = 8
|
||||
multithreading = True
|
||||
n_parallel = 1
|
||||
|
||||
def get_n_threads(n_loops):
|
||||
if n_threads is None:
|
||||
if n_loops > 2048:
|
||||
return 8
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return n_threads
|
||||
|
||||
|
||||
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
|
||||
"""Swap two positions in an Array if a condition is met.
|
||||
@@ -43,9 +55,11 @@ def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond:
|
||||
array[pos_b] = cond.if_else(array[pos_a], array[pos_b])
|
||||
array[pos_a] = cond.if_else(temp, array[pos_a])
|
||||
|
||||
|
||||
T = TypeVar("T", sint, sbitint)
|
||||
B = TypeVar("B", sintbit, sbit)
|
||||
|
||||
|
||||
class SqrtOram(Generic[T, B]):
|
||||
# TODO: Preferably this is an Array of vectors, but this is currently not supported
|
||||
# One should regard these structures as Arrays where an entry may hold more
|
||||
@@ -75,7 +89,7 @@ class SqrtOram(Generic[T, B]):
|
||||
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
|
||||
|
||||
Args:
|
||||
data (MultiArray): The data with which to initialize the ORAM. For all intents and purposes, data is regarded as a one-dimensional Array. However, one may provide a MultiArray such that every "block" can hold multiple elements (an Array).
|
||||
data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
|
||||
value_type (sint): The secret type to use, defaults to sint.
|
||||
k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
|
||||
period (int): Leave at None, this parameter is used to recursively pass down the top-level period.
|
||||
@@ -87,7 +101,8 @@ class SqrtOram(Generic[T, B]):
|
||||
self.n = math.ceil(len(data) // entry_length)
|
||||
if (len(data) % entry_length != 0):
|
||||
raise Exception('Data incorrectly padded.')
|
||||
self.shuffle = MultiArray((self.n, entry_length), value_type=value_type)
|
||||
self.shuffle = MultiArray(
|
||||
(self.n, entry_length), value_type=value_type)
|
||||
self.shuffle.assign_part_vector(data.get_vector())
|
||||
else:
|
||||
raise Exception("Incorrect format.")
|
||||
@@ -96,19 +111,23 @@ class SqrtOram(Generic[T, B]):
|
||||
if value_type != sint and value_type != sbitint:
|
||||
raise Exception("The value_type must be either sint or sbitint")
|
||||
self.bit_type: Type[B] = value_type.bit_type
|
||||
self.index_type = value_type.get_type(util.log2(self.n))
|
||||
self.index_size = util.log2(self.n)
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.entry_length = entry_length
|
||||
|
||||
if debug:
|
||||
lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k)
|
||||
lib.print_ln(
|
||||
'Initializing SqrtORAM of size %s at depth %s', self.n, k)
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)])
|
||||
self.shufflei = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
permutation = Array.create_from(self.shuffle_the_shuffle())
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
|
||||
self.T = int(math.ceil(
|
||||
math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
|
||||
if debug and not period:
|
||||
lib.print_ln('Period set to %s', self.T)
|
||||
# Initialize position map (recursive oram)
|
||||
@@ -119,28 +138,108 @@ class SqrtOram(Generic[T, B]):
|
||||
self.stashi = Array(self.T, value_type=value_type)
|
||||
self.t = MemValue(cint(0))
|
||||
|
||||
@lib.method_block
|
||||
def read(self, index: T):
|
||||
value = self.value_type(0, size=self.entry_length)
|
||||
return self.access(index, self.bit_type(False), *value)
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
lib.runtime_error_if(len(value) != self.entry_length, "A block must be of size entry_length")
|
||||
self.access(index, self.bit_type(True), *value)
|
||||
|
||||
__getitem__ = read
|
||||
__setitem__ = write
|
||||
# Initialize temp variables needed during the computation
|
||||
self.found_ = self.bit_type.Array(size=self.T)
|
||||
|
||||
@lib.method_block
|
||||
def access(self, index: T, write: B, *value: T):
|
||||
if debug:
|
||||
if trace:
|
||||
@lib.if_e(write.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value, size=self.entry_length).get_vector(0, size=self.entry_length)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
# This will result in a bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check wheterh the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
|
||||
# Store the stash item into the result if found
|
||||
# If the item is not in the stash, the result will simple remain 0
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
# If we are writing, we need to add the value
|
||||
self.stash[i] += write * access_here * (value - entry)
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += write * found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
@lib.if_((write * found.bit_not()).reveal())
|
||||
def _():
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
if trace:
|
||||
lib.print_ln('Writing to secret index %s', index.reveal())
|
||||
|
||||
value = self.value_type(value)
|
||||
index = MemValue(index)
|
||||
|
||||
@@ -150,87 +249,159 @@ class SqrtOram(Generic[T, B]):
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# Scan through the stash
|
||||
@lib.if_(self.t > 0)
|
||||
def _():
|
||||
nonlocal found
|
||||
found |= index == self.stashi[0]
|
||||
# We ensure that if the item is found in stash, it ends up in the first
|
||||
# position (more importantly, a fixed position) of the stash
|
||||
# This allows us to keep track of it in an oblivious manner
|
||||
if multithreading:
|
||||
found_ = self.bit_type.Array(size=self.T)
|
||||
@lib.multithread(1, self.T)
|
||||
def _(base, size):
|
||||
found_.assign_vector((self.stashi.get_vector(base, size) == index.expand_to_vector(size))
|
||||
& self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
@lib.for_range_opt(self.t - 1)
|
||||
def _(i):
|
||||
swap(self.stash, 0, i + 1, found_[i+1])
|
||||
swap(self.stashi, 0, i + 1, found_[i+1])
|
||||
found.write(sum(found_))
|
||||
else:
|
||||
@lib.for_range_opt(self.t - 1)
|
||||
def _(i):
|
||||
nonlocal found
|
||||
found_: B = index == self.stashi[i + 1]
|
||||
swap(self.stash, 0, i + 1, found_)
|
||||
swap(self.stashi, 0, i + 1, found_)
|
||||
found |= found_
|
||||
# If the item was not in the stash, we move the unknown and unimportant
|
||||
# stash[0] out of the way (to the end of the stash)
|
||||
swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
|
||||
swap(self.stashi, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
if debug:
|
||||
# This will result in an bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check wheterh the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
# We update the stash value
|
||||
self.stash[i] += access_here * (value - entry)
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tItem not in stash')
|
||||
lib.print_ln('\tMoved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was in the stash (thus currently residing in stash[0]),
|
||||
# we place the random item retrieved from the shuffle at the end of the stash
|
||||
self.stash[self.t].assign(found.if_else(
|
||||
self.shuffle[physical_address][:],
|
||||
self.stash[self.t][:]))
|
||||
self.stashi[self.t] = found.if_else(
|
||||
self.shufflei[physical_address],
|
||||
self.stashi[self.t])
|
||||
# If the item was not found in the stash, we place the item retrieved
|
||||
# from the shuffle (the item we are actually looking for) in stash[0]
|
||||
self.stash[0].assign(found.bit_not().if_else(
|
||||
self.shuffle[physical_address][:],
|
||||
self.stash[0][:]))
|
||||
self.stashi[0] = found.bit_not().if_else(
|
||||
self.shufflei[physical_address],
|
||||
self.stashi[0])
|
||||
# If the item was not found in the stash
|
||||
# ...we update the item in the shuffle
|
||||
self.shuffle[physical_address] += found.bit_not() * (value - self.shuffle[physical_address][:])
|
||||
# ...and the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if debug:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
if trace:
|
||||
@lib.if_(found.bit_not().reveal())
|
||||
def _():
|
||||
lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[t]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal())
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tMoved shuffle[%s]=(%s: %s) to stash[0]', physical_address, self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal())
|
||||
lib.print_ln('Wrote (%s: %s) to shuffle[%s]', self.stashi[self.t].reveal(), self.shuffle[physical_address].reveal(), physical_address)
|
||||
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
self.stash[0].assign(write.if_else(value, self.stash[0][:]))
|
||||
value=write.bit_not().if_else(self.stash[0][:], value)
|
||||
return value
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def read(self, index: T, *value: T):
|
||||
if trace:
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
value = self.value_type(value)
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
def _():
|
||||
self.refresh()
|
||||
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
result: T = MemValue(self.value_type(0, size=self.entry_length))
|
||||
|
||||
# First we scan the stash for the item
|
||||
self.found_.assign_all(0)
|
||||
|
||||
# This will result in a bit array with at most one True,
|
||||
# indicating where in the stash 'index' is found
|
||||
@lib.multithread(get_n_threads(self.T), self.T)
|
||||
def _(base, size):
|
||||
self.found_.assign_vector(
|
||||
(self.stashi.get_vector(base, size) == index.expand_to_vector(size)) & \
|
||||
self.bit_type(regint.inc(size, base=base) < self.t.expand_to_vector(size)),
|
||||
base=base)
|
||||
|
||||
# To determine whether the item is found in the stash, we simply
|
||||
# check wheterh the demuxed array contains a True
|
||||
# TODO: What if the index=0?
|
||||
found.write(sum(self.found_))
|
||||
|
||||
# Store the stash item into the result if found
|
||||
# If the item is not in the stash, the result will simple remain 0
|
||||
@lib.map_sum(get_n_threads(self.T), n_parallel, self.T,
|
||||
self.entry_length, [self.value_type] * self.entry_length)
|
||||
def stash_item(i):
|
||||
entry = self.stash[i][:]
|
||||
access_here = self.found_[i]
|
||||
# This is a bit unfortunate
|
||||
# We should loop from 0 to self.t, but t is dynamic thus this is impossible.
|
||||
# Therefore we loop till self.T (the max value of self.t)
|
||||
# is_in_time = i < self.t
|
||||
|
||||
return (entry * access_here)[:]
|
||||
result += self.value_type(stash_item(), size=self.entry_length)
|
||||
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\tFound item in stash')
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\tDid not find item in stash')
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
physical_address = self.position_map.get_position(index, found)
|
||||
# We set shuffle_used to True, to track that this shuffle item needs to be refreshed
|
||||
# with its equivalent on the stash once the period is up.
|
||||
self.shuffle_used[physical_address] = cbit(True)
|
||||
|
||||
# If the item was not found in the stash
|
||||
# the item retrieved from the shuffle is our result
|
||||
result += self.shuffle[physical_address] * found.bit_not()
|
||||
# We append the newly retrieved item to the stash
|
||||
self.stash[self.t].assign(self.shuffle[physical_address][:])
|
||||
self.stashi[self.t] = self.shufflei[physical_address]
|
||||
|
||||
if trace:
|
||||
lib.print_ln('\tAppended shuffle[%s]=(%s: %s) to stash at position t=%s', physical_address,
|
||||
self.shufflei[physical_address].reveal(), self.shuffle[physical_address].reveal(), self.t)
|
||||
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
return result
|
||||
|
||||
__getitem__ = read
|
||||
__setitem__ = write
|
||||
|
||||
@lib.method_block
|
||||
def shuffle_the_shuffle(self):
|
||||
@@ -242,12 +413,15 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
# Random permutation on n elements
|
||||
random_shuffle = sint.get_secure_shuffle(self.n)
|
||||
if debug: lib.print_ln('\tGenerated shuffle')
|
||||
if trace:
|
||||
lib.print_ln('\tGenerated shuffle')
|
||||
# Apply the random permutation
|
||||
self.shuffle.secure_permute(random_shuffle)
|
||||
if debug: lib.print_ln('\tShuffled shuffle')
|
||||
if trace:
|
||||
lib.print_ln('\tShuffled shuffle')
|
||||
self.shufflei.secure_permute(random_shuffle)
|
||||
if debug: lib.print_ln('\tShuffled shuffle indexes')
|
||||
if trace:
|
||||
lib.print_ln('\tShuffled shuffle indexes')
|
||||
# Calculate the permutation that would have produced the newly produced
|
||||
# shuffle order. This can be calculated by regarding the logical
|
||||
# indexes (shufflei) as a permutation and calculating its inverse,
|
||||
@@ -256,7 +430,8 @@ class SqrtOram(Generic[T, B]):
|
||||
# random_shuffle, as the shuffle may already be out of order (e.g. when
|
||||
# refreshing).
|
||||
permutation = MemValue(self.shufflei[:].inverse_permutation())
|
||||
if debug: lib.print_ln('\tCalculated inverse permutation')
|
||||
if trace:
|
||||
lib.print_ln('\tCalculated inverse permutation')
|
||||
return permutation
|
||||
|
||||
@lib.method_block
|
||||
@@ -266,10 +441,12 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
This must happen after T (period) accesses to the ORAM."""
|
||||
|
||||
if debug: lib.print_ln('Refreshing SqrtORAM')
|
||||
if trace:
|
||||
lib.print_ln('Refreshing SqrtORAM')
|
||||
|
||||
# Shuffle and emtpy the stash, and store elements back into shuffle
|
||||
j = MemValue(cint(0,size=1))
|
||||
j = MemValue(cint(0, size=1))
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
@lib.if_(self.shuffle_used[i])
|
||||
@@ -301,13 +478,14 @@ class SqrtOram(Generic[T, B]):
|
||||
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
# Reset shuffle_used
|
||||
self.shuffle_used.assign_all(0)
|
||||
|
||||
# Note that the self.shuffle is actually a MultiArray
|
||||
# This structure is preserved while overwriting the values using
|
||||
# assign_vector
|
||||
self.shuffle.assign_vector(self.value_type(data, size=self.n * self.entry_length))
|
||||
self.shuffle.assign_vector(self.value_type(
|
||||
data, size=self.n * self.entry_length))
|
||||
permutation = self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*permutation)
|
||||
|
||||
@@ -316,13 +494,13 @@ class PositionMap(Generic[T, B]):
|
||||
PACK_LOG: int = 3
|
||||
PACK: int = 1 << PACK_LOG
|
||||
|
||||
n: int # n in the paper
|
||||
depth: int # k in the paper
|
||||
n: int # n in the paper
|
||||
depth: int # k in the paper
|
||||
value_type: Type[T]
|
||||
|
||||
def __init__(self, n: int, value_type: Type[T] = sint, k:int = -1) -> None:
|
||||
def __init__(self, n: int, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
self.n = n
|
||||
self.depth=MemValue(cint(k))
|
||||
self.depth = MemValue(cint(k))
|
||||
self.value_type = value_type
|
||||
self.bit_type = value_type.bit_type
|
||||
self.index_type = self.value_type.get_type(util.log2(n))
|
||||
@@ -330,8 +508,9 @@ class PositionMap(Generic[T, B]):
|
||||
@abstractmethod
|
||||
def get_position(self, logical_address: _secret, fake: B) -> Any:
|
||||
"""Retrieve the block at the given (secret) logical address."""
|
||||
if debug:
|
||||
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth, self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
if trace:
|
||||
lib.print_ln('\t%s Scanning %s for logical address %s (fake=%s)', self.depth,
|
||||
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
"""Reinitialize this PositionMap.
|
||||
@@ -352,11 +531,13 @@ class PositionMap(Generic[T, B]):
|
||||
|
||||
if n / PositionMap.PACK <= period:
|
||||
if debug:
|
||||
lib.print_ln('Initializing LinearPositionMap at depth %s of size %s', k, n)
|
||||
lib.print_ln(
|
||||
'Initializing LinearPositionMap at depth %s of size %s', k, n)
|
||||
res = LinearPositionMap(permutation, value_type, k=k)
|
||||
else:
|
||||
if debug:
|
||||
lib.print_ln('Initializing RecursivePositionMap at depth %s of size %s', k, n)
|
||||
lib.print_ln(
|
||||
'Initializing RecursivePositionMap at depth %s of size %s', k, n)
|
||||
res = RecursivePositionMap(permutation, period, value_type, k=k)
|
||||
|
||||
return res
|
||||
@@ -364,7 +545,7 @@ class PositionMap(Generic[T, B]):
|
||||
|
||||
class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
|
||||
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k:int=-1) -> None:
|
||||
def __init__(self, permutation: Array, period: int, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
PositionMap.__init__(self, len(permutation), k=k)
|
||||
pack = PositionMap.PACK
|
||||
|
||||
@@ -377,7 +558,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
permutation[i*pack:(i+1)*pack])
|
||||
|
||||
# TODO: Should this be n or packed_size?
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type, period=period, entry_length=pack, k=self.depth)
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type,
|
||||
period=period, entry_length=pack, k=self.depth)
|
||||
|
||||
@lib.method_block
|
||||
def get_position(self, logical_address: T, fake: B) -> _clear:
|
||||
@@ -389,7 +571,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
# The item at logical_address
|
||||
# will be in block with index h (block.<h>)
|
||||
# at position l in block.data (block.data<l>)
|
||||
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(logical_address).right_shift(pack_log, program.bit_length)))
|
||||
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
|
||||
logical_address).right_shift(pack_log, program.bit_length)))
|
||||
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))
|
||||
|
||||
# The resulting physical address
|
||||
@@ -401,32 +584,37 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
|
||||
# First we scan the stash for the block we need
|
||||
condition1 = self.bit_type.Array(self.T)
|
||||
|
||||
@lib.for_range_opt_multithread(8, self.T)
|
||||
def _(i):
|
||||
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
|
||||
found = sum(condition1)
|
||||
# Once a block is found, we use condition2 to pick the correct item from that block
|
||||
condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack))
|
||||
condition2 = Array.create_from(
|
||||
regint.inc(pack) == l.expand_to_vector(pack))
|
||||
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
|
||||
condition3 = self.bit_type.Array(self.T * pack)
|
||||
|
||||
@lib.for_range_opt_multithread(8, [self.T, pack])
|
||||
def _(i, j):
|
||||
condition3[i*pack + j] = condition1[i] & condition2[j]
|
||||
# Finally we use condition3 to conditionally write p
|
||||
|
||||
@lib.for_range(self.t)
|
||||
def _(i):
|
||||
@lib.for_range(pack)
|
||||
def _(j):
|
||||
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
|
||||
|
||||
if debug:
|
||||
if trace:
|
||||
@lib.if_(condition1[i].reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal())
|
||||
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(
|
||||
), self.stash[i].reveal())
|
||||
|
||||
# Then we try and retrieve the item from the shuffle (the actual memory)
|
||||
|
||||
if debug:
|
||||
if trace:
|
||||
@lib.if_(found.reveal() == 0)
|
||||
def _():
|
||||
lib.print_ln('\t%s Position not in stash', self.depth)
|
||||
@@ -438,15 +626,16 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
# The block retrieved from the shuffle
|
||||
block_p_prime: Array = self.shuffle[p_prime]
|
||||
|
||||
if debug:
|
||||
if trace:
|
||||
@lib.if_e(found.reveal() == 0)
|
||||
def _():
|
||||
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
|
||||
# We add the retrieved block from the shuffle to the stash
|
||||
self.stash[self.t].assign(block_p_prime[:])
|
||||
@@ -458,7 +647,9 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
|
||||
# Retrieve l'th item from block
|
||||
# l is secret, so we must use linear scan
|
||||
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack))
|
||||
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(
|
||||
pack)) & condition.expand_to_vector(pack))
|
||||
|
||||
@lib.for_range_opt(pack)
|
||||
def _(i):
|
||||
p.write((hit[i]).if_else(block_p_prime[i], p))
|
||||
@@ -469,67 +660,63 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
def reinitialize(self, *permutation: T):
|
||||
SqrtOram.reinitialize(self, *permutation)
|
||||
|
||||
|
||||
class LinearPositionMap(PositionMap):
|
||||
physical: Array
|
||||
used: Array
|
||||
|
||||
def __init__(self, data: Array, value_type: Type[T] = sint, k:int =-1) -> None:
|
||||
def __init__(self, data: Array, value_type: Type[T] = sint, k: int = -1) -> None:
|
||||
PositionMap.__init__(self, len(data), value_type, k=k)
|
||||
self.physical = data
|
||||
self.used = self.bit_type.Array(self.n)
|
||||
|
||||
# Initialize random temp variables needed during the computation
|
||||
self.physical_demux: Array = self.bit_type.Array(self.n)
|
||||
|
||||
@lib.method_block
|
||||
def get_position(self, logical_address: T, fake: B) -> _clear:
|
||||
"""
|
||||
This method corresponds to GetPosBase in the paper.
|
||||
"""
|
||||
super().get_position(logical_address, fake)
|
||||
|
||||
fake = MemValue(self.bit_type(fake))
|
||||
logical_address = MemValue(logical_address)
|
||||
|
||||
p: MemValue = MemValue(self.index_type(-1))
|
||||
done: B = self.bit_type(False)
|
||||
|
||||
if multithreading:
|
||||
conditions:Array = self.bit_type.Array(self.n)
|
||||
conditions.assign_all(0)
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
self.physical_demux.assign_all(0)
|
||||
@lib.for_range_opt_multithread(8, self.n)
|
||||
def condition_i(i):
|
||||
self.physical_demux.assign((self.bit_type(fake).bit_not()
|
||||
& self.bit_type(logical_address == i)) | (fake
|
||||
& self.used[i].bit_not()), base=i)
|
||||
|
||||
@lib.for_range_opt_multithread(8, self.n)
|
||||
def condition_i(i):
|
||||
conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i)
|
||||
# In the event that fake=True, there are likely multiple entried in physical_demux set to True (i.e. where self.used[i] = False)
|
||||
# We only need once, so we pick the first one we find
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
nonlocal done
|
||||
self.physical_demux[i] &= done.bit_not()
|
||||
done |= self.physical_demux[i]
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
nonlocal done
|
||||
conditions[i] &= done.bit_not()
|
||||
done |= conditions[i]
|
||||
@lib.map_sum_opt(8, self.n, [self.value_type])
|
||||
def calc_p(i):
|
||||
return self.physical[i] * conditions[i]
|
||||
p.write(calc_p())
|
||||
# Retrieve the value from the physical memory obliviously
|
||||
@lib.map_sum_opt(8, self.n, [self.value_type])
|
||||
def calc_p(i):
|
||||
return self.physical[i] * self.physical_demux[i]
|
||||
p.write(calc_p())
|
||||
|
||||
self.used.assign(self.used[:] | conditions[:])
|
||||
else:
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
linear_scan = self.bit_type.Array(self.n)
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
linear_scan[i] = logical_address == i
|
||||
# Update self.used
|
||||
self.used.assign(self.used[:] | self.physical_demux[:])
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def __(j):
|
||||
nonlocal done, fake
|
||||
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
|
||||
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
|
||||
p.write(condition.if_else(self.physical[j], p))
|
||||
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
|
||||
done = self.bit_type(condition.if_else(self.bit_type(True), done))
|
||||
|
||||
if debug:
|
||||
if trace:
|
||||
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))
|
||||
def _():
|
||||
lib.runtime_error('%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
|
||||
lib.runtime_error(
|
||||
'%s Did not find requested logical_address in shuffle, something went wrong.', self.depth)
|
||||
|
||||
return p.reveal()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user