Optimize performance of SqrtORAM

This commit is contained in:
Kevin Witlox
2022-07-29 11:24:07 +02:00
parent 06520ea7a1
commit b070c23a26

View File

@@ -1,17 +1,41 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar
import math
from typing import Any, Generic, Type, TypeVar
from Compiler.program import Program
from Compiler import util
from Compiler import library as lib
from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.oram import AbstractORAM, get_n_threads
from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint
import numpy as np
from Compiler.types import (
Array,
MemValue,
MultiArray,
_clear,
_secret,
cint,
sint,
sintbit,
regint
)
debug = True
reveal = True
program = Program.prog
debug = False
n_parallel = 1024
n_threads = 8
multithreading = True
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
"""Swap two positions in an Array if a condition is met.
Args:
array (Array | MultiArray): The array in which to swap the first and second position
pos_a (int | cint): The first position
pos_b (int | cint): The second position
cond (sintbit | sbit): The condition determining whether to swap
"""
if isinstance(array, MultiArray):
temp = array[pos_b][:]
array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:]))
@@ -49,7 +73,7 @@ class SqrtOram(Generic[T, B]):
# the stash)
t: cint
def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
Args:
@@ -64,55 +88,51 @@ class SqrtOram(Generic[T, B]):
if value_type != sint and value_type != sbitint:
raise Exception("The value_type must be either sint or sbitint")
self.bit_type: Type[B] = value_type.bit_type
self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) ))
self.index_type = value_type.get_type(util.log2(self.n))
self.entry_length = entry_length
if debug:
lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k)
self.shuffle_used = cint.Array(self.n)
# Random permutation on the data
self.shuffle = data
if isinstance(data, MultiArray):
self.shuffle = data
elif isinstance(data, sint):
self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type)
self.shuffle.assign_vector(data.get_vector())
else:
raise Exception("Incorrect format.")
self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)])
permutation = Array.create_from(self.shuffle_the_shuffle())
# Calculate the period if not given
# upon recursion, the period should stay the same ("in sync"),
# therefore it can be passed as a constructor parameter
self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1))
) if not period else period
self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
if debug and not period:
lib.print_ln('Period set to %s', self.T)
# Initialize position map (recursive oram)
self.position_map = PositionMap.create(permutation, k + 1, self.T)
# Initialize stash
self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type)
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
self.stashi = Array(self.T, value_type=value_type)
self.t = MemValue(cint(0))
@lib.method_block
def read(self, index: T):
data = self.value_type.Array(self.entry_length)
return self.access(index, self.bit_type(False), data)
value = self.value_type(0, size=self.entry_length)
return self.access(index, self.bit_type(False), *value)
def write(self, index: T, value: Array):
self.access(index, self.bit_type(True), value)
@lib.method_block
def write(self, index: T, value: T):
lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length")
self.access(index, self.bit_type(True), *value)
__getitem__ = read
__setitem__ = write
def access(self, index: T, write: B, value: Array):
if len(value) != self.entry_length:
raise Exception("A block must be of size entry_length={}".format(self.entry_length))
# Method Blocks do not accepts arrays as arguments
# workaround by temporarily storing it as a class field
# arrays are stored in memory so this is fine
index = MemValue(index)
return Array.create_from(self._access(index, write, value[:]))
@lib.method_block
def _access(self, index: T, write: B, *value: list[T]):
item: T = self.value_type(*value)
def access(self, index: T, write: B, *value: T):
if debug:
@lib.if_e(write.reveal() == 1)
def _():
@@ -120,6 +140,7 @@ class SqrtOram(Generic[T, B]):
@lib.else_
def __():
lib.print_ln('Reading from secret index %s', index.reveal())
value = self.value_type(value)
# Refresh if we have performed T (period) accesses
@lib.if_(self.t == self.T)
@@ -136,14 +157,24 @@ class SqrtOram(Generic[T, B]):
# We ensure that if the item is found in stash, it ends up in the first
# position (more importantly, a fixed position) of the stash
# This allows us to keep track of it in an oblivious manner
@lib.for_range_opt(self.t)
def _(i):
nonlocal found
found_: B = index == self.stashi[i + 1]
swap(self.stash, 0, i, found_)
swap(self.stashi, 0, i, found_)
found |= found_
# found = self.bit_type(found.bit_or(found_))
if multithreading:
found_ = self.bit_type.Array(size=self.T)
@lib.multithread(8, self.T)
def _(base, size):
found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base)
@lib.for_range_opt(self.t - 1)
def _(i):
swap(self.stash, 0, i, found_[i])
swap(self.stashi, 0, i, found_[i])
found.write(sum(found_))
else:
@lib.for_range_opt(self.t - 1)
def _(i):
nonlocal found
found_: B = index == self.stashi[i + 1]
swap(self.stash, 0, i, found_)
swap(self.stashi, 0, i, found_)
found |= found_
# If the item was not in the stash, we move the unknown and unimportant
# stash[0] out of the way (to the end of the stash)
swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
@@ -156,7 +187,7 @@ class SqrtOram(Generic[T, B]):
@lib.else_
def __():
lib.print_ln(' Item not in stash')
lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
# Possible fake lookup of the item in the shuffle,
# depending on whether we already found the item in the stash
@@ -171,14 +202,15 @@ class SqrtOram(Generic[T, B]):
self.stashi[self.t] = found.if_else(
self.shufflei[physical_address],
self.stashi[self.t])
# If the item was not found in the stash,
# we place the item retrieved from the shuffle in stash[0]
# If the item was not found in the stash, we place the item retrieved
# from the shuffle (the item we are actually looking for) in stash[0]
self.stash[0].assign(found.bit_not().if_else(
self.shuffle[physical_address][:],
self.stash[0][:]))
self.stashi[0] = found.bit_not().if_else(
self.shufflei[physical_address],
self.stashi[0])
if debug:
@lib.if_e(found.reveal() == 1)
def _():
@@ -191,9 +223,9 @@ class SqrtOram(Generic[T, B]):
# Increase the "time" (i.e. access count in current period)
self.t.iadd(1)
self.stash[0].assign(write.if_else(item, self.stash[0][:]))
item=write.bit_not().if_else(self.stash[0][:], item)
return item
self.stash[0].assign(write.if_else(value, self.stash[0][:]))
value=write.bit_not().if_else(self.stash[0][:], value)
return value
@lib.method_block
@@ -206,12 +238,12 @@ class SqrtOram(Generic[T, B]):
# Random permutation on n elements
random_shuffle = sint.get_secure_shuffle(self.n)
if debug: lib.print_ln('\tGenerated shuffle')
# Apply the random permutation
lib.print_ln('\tGenerated shuffle')
self.shuffle.secure_permute(random_shuffle)
lib.print_ln('\tShuffled shuffle')
if debug: lib.print_ln('\tShuffled shuffle')
self.shufflei.secure_permute(random_shuffle)
lib.print_ln('\tShuffled shuffle indexes')
if debug: lib.print_ln('\tShuffled shuffle indexes')
# Calculate the permutation that would have produced the newly produced
# shuffle order. This can be calculated by regarding the logical
# indexes (shufflei) as a permutation and calculating its inverse,
@@ -220,7 +252,7 @@ class SqrtOram(Generic[T, B]):
# random_shuffle, as the shuffle may already be out of order (e.g. when
# refreshing).
permutation = MemValue(self.shufflei[:].inverse_permutation())
lib.print_ln('\tCalculated inverse permutation')
if debug: lib.print_ln('\tCalculated inverse permutation')
return permutation
@lib.method_block
@@ -229,7 +261,8 @@ class SqrtOram(Generic[T, B]):
reshuffling the shuffle.
This must happen after T (period) accesses to the ORAM."""
lib.print_ln('Refreshing SqrtORAM')
if debug: lib.print_ln('Refreshing SqrtORAM')
# Shuffle and emtpy the stash, and store elements back into shuffle
j = MemValue(cint(0,size=1))
@@ -276,7 +309,7 @@ class SqrtOram(Generic[T, B]):
class PositionMap(Generic[T, B]):
PACK_LOG: int = 2
PACK_LOG: int = 3
PACK: int = 1 << PACK_LOG
n: int # n in the paper
@@ -288,7 +321,7 @@ class PositionMap(Generic[T, B]):
self.depth=MemValue(cint(k))
self.value_type = value_type
self.bit_type = value_type.bit_type
self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n))))
self.index_type = self.value_type.get_type(util.log2(n))
@abstractmethod
def get_position(self, logical_address: _secret, fake: B) -> Any:
@@ -332,7 +365,7 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
pack = PositionMap.PACK
# We pack the permutation into a smaller structure, index with a new permutation
packed_size = int(np.ceil(self.n / pack))
packed_size = int(math.ceil(self.n / pack))
packed_structure = MultiArray(
(packed_size, pack), value_type=value_type)
for i in range(packed_size):
@@ -359,28 +392,33 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
p = MemValue(self.index_type(0))
found: B = MemValue(self.bit_type(False))
# First we try and retrieve the item from the stash
# First we try and retrieve the item from the stash at position stash[h][l]
# Since h and l are secret, we do this by scanning the entire stash
# We retrieve stash[h]
# Since h is secret, we do this by scanning the entire stash
# First we scan the stash for the block we need
condition1 = self.bit_type.Array(self.T)
@lib.for_range_opt_multithread(8, self.T)
def _(i):
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
found = sum(condition1)
# Once a block is found, we use condition2 to pick the correct item from that block
condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack))
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
condition3 = self.bit_type.Array(self.T * pack)
@lib.for_range_opt_multithread(8, [self.T, pack])
def _(i, j):
condition3[i*pack + j] = condition1[i] & condition2[j]
# Finally we use condition3 to conditionally write p
@lib.for_range(self.t)
def _(j):
nonlocal found
condition = self.stashi[j] == h
found |= condition
# block = stash[h]
# block is itself an array (it holds a permutation)
# we need to grab block[l]
def _(i):
@lib.for_range(pack)
def _(i):
nonlocal condition
condition &= l == i
p.write(condition.if_else(self.stash[j][i], p))
def _(j):
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
if debug:
@lib.if_(condition.reveal() == 1)
@lib.if_(condition1[i].reveal() == 1)
def _():
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal())
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal())
# Then we try and retrieve the item from the shuffle (the actual memory)
@@ -389,22 +427,22 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
def _():
lib.print_ln('\t%s Position not in stash', self.depth)
# Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle
p_prime = self.position_map.get_position(h, found)
self.shuffle_used[p_prime] = cbit(True)
# The block retrieved from the shuffle
# Depending on whether the block has already been `found`, this block
# is either the desired block (found=False) or a random block
# (found=True)
block_p_prime: Array = self.shuffle[p_prime]
if debug:
@lib.if_e(found.reveal() == 0)
def _():
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
@lib.else_
def __():
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
# We add the retrieved block from the shuffle to the stash
self.stash[self.t].assign(block_p_prime[:])
@@ -413,13 +451,13 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
self.t += 1
# if found or not fake
condition = self.bit_type(fake.bit_or(found.bit_not()))
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
# Retrieve l'th item from block
# l is secret, so we must use linear scan
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack))
@lib.for_range_opt(pack)
def _(i):
hit: B = self.bit_type(i == l)
p.write((condition & hit).if_else(block_p_prime[i], p))
p.write((hit[i]).if_else(block_p_prime[i], p))
return p.reveal()