mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
Optimize performance of SqrtORAM
This commit is contained in:
@@ -1,17 +1,41 @@
|
||||
from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Generic, Iterable, Literal, Type, Any, TypeVar
|
||||
import math
|
||||
from typing import Any, Generic, Type, TypeVar
|
||||
|
||||
from Compiler.program import Program
|
||||
from Compiler import util
|
||||
from Compiler import library as lib
|
||||
from Compiler.GC.types import cbit, sbit, sbitint, sbits
|
||||
from Compiler.oram import AbstractORAM, get_n_threads
|
||||
from Compiler.types import MultiArray, sgf2n, sint, _secret, MemValue, Array, _clear, sintbit, cint
|
||||
import numpy as np
|
||||
from Compiler.types import (
|
||||
Array,
|
||||
MemValue,
|
||||
MultiArray,
|
||||
_clear,
|
||||
_secret,
|
||||
cint,
|
||||
sint,
|
||||
sintbit,
|
||||
regint
|
||||
)
|
||||
|
||||
debug = True
|
||||
reveal = True
|
||||
program = Program.prog
|
||||
|
||||
debug = False
|
||||
n_parallel = 1024
|
||||
n_threads = 8
|
||||
|
||||
multithreading = True
|
||||
|
||||
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
|
||||
"""Swap two positions in an Array if a condition is met.
|
||||
|
||||
Args:
|
||||
array (Array | MultiArray): The array in which to swap the first and second position
|
||||
pos_a (int | cint): The first position
|
||||
pos_b (int | cint): The second position
|
||||
cond (sintbit | sbit): The condition determining whether to swap
|
||||
"""
|
||||
if isinstance(array, MultiArray):
|
||||
temp = array[pos_b][:]
|
||||
array[pos_b].assign(cond.if_else(array[pos_a][:], array[pos_b][:]))
|
||||
@@ -49,7 +73,7 @@ class SqrtOram(Generic[T, B]):
|
||||
# the stash)
|
||||
t: cint
|
||||
|
||||
def __init__(self, data: MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
|
||||
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None) -> None:
|
||||
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
|
||||
|
||||
Args:
|
||||
@@ -64,55 +88,51 @@ class SqrtOram(Generic[T, B]):
|
||||
if value_type != sint and value_type != sbitint:
|
||||
raise Exception("The value_type must be either sint or sbitint")
|
||||
self.bit_type: Type[B] = value_type.bit_type
|
||||
self.index_type = value_type.get_type(int(np.ceil(np.log2(self.n)) ))
|
||||
self.index_type = value_type.get_type(util.log2(self.n))
|
||||
self.entry_length = entry_length
|
||||
|
||||
if debug:
|
||||
lib.print_ln('Initializing SqrtORAM of size %s at depth %s', self.n, k)
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shuffle = data
|
||||
if isinstance(data, MultiArray):
|
||||
self.shuffle = data
|
||||
elif isinstance(data, sint):
|
||||
self.shuffle = MultiArray((self.n, self.entry_length), value_type=value_type)
|
||||
self.shuffle.assign_vector(data.get_vector())
|
||||
else:
|
||||
raise Exception("Incorrect format.")
|
||||
self.shufflei = Array.create_from([self.index_type(i) for i in range(self.n)])
|
||||
permutation = Array.create_from(self.shuffle_the_shuffle())
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
self.T = int(np.ceil(np.sqrt(self.n * np.log2(self.n) - self.n + 1))
|
||||
) if not period else period
|
||||
self.T = int(math.ceil(math.sqrt(self.n * util.log2(self.n) - self.n + 1))) if not period else period
|
||||
if debug and not period:
|
||||
lib.print_ln('Period set to %s', self.T)
|
||||
# Initialize position map (recursive oram)
|
||||
self.position_map = PositionMap.create(permutation, k + 1, self.T)
|
||||
|
||||
# Initialize stash
|
||||
self.stash = MultiArray((self.T, data.sizes[1]), value_type=value_type)
|
||||
self.stash = MultiArray((self.T, entry_length), value_type=value_type)
|
||||
self.stashi = Array(self.T, value_type=value_type)
|
||||
self.t = MemValue(cint(0))
|
||||
|
||||
|
||||
@lib.method_block
|
||||
def read(self, index: T):
|
||||
data = self.value_type.Array(self.entry_length)
|
||||
return self.access(index, self.bit_type(False), data)
|
||||
value = self.value_type(0, size=self.entry_length)
|
||||
return self.access(index, self.bit_type(False), *value)
|
||||
|
||||
def write(self, index: T, value: Array):
|
||||
self.access(index, self.bit_type(True), value)
|
||||
@lib.method_block
|
||||
def write(self, index: T, value: T):
|
||||
lib.runtime_error_if(value.size != self.entry_length, "A block must be of size entry_length")
|
||||
self.access(index, self.bit_type(True), *value)
|
||||
|
||||
__getitem__ = read
|
||||
__setitem__ = write
|
||||
|
||||
def access(self, index: T, write: B, value: Array):
|
||||
if len(value) != self.entry_length:
|
||||
raise Exception("A block must be of size entry_length={}".format(self.entry_length))
|
||||
# Method Blocks do not accepts arrays as arguments
|
||||
# workaround by temporarily storing it as a class field
|
||||
# arrays are stored in memory so this is fine
|
||||
index = MemValue(index)
|
||||
return Array.create_from(self._access(index, write, value[:]))
|
||||
|
||||
@lib.method_block
|
||||
def _access(self, index: T, write: B, *value: list[T]):
|
||||
item: T = self.value_type(*value)
|
||||
|
||||
def access(self, index: T, write: B, *value: T):
|
||||
if debug:
|
||||
@lib.if_e(write.reveal() == 1)
|
||||
def _():
|
||||
@@ -120,6 +140,7 @@ class SqrtOram(Generic[T, B]):
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('Reading from secret index %s', index.reveal())
|
||||
value = self.value_type(value)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@lib.if_(self.t == self.T)
|
||||
@@ -136,14 +157,24 @@ class SqrtOram(Generic[T, B]):
|
||||
# We ensure that if the item is found in stash, it ends up in the first
|
||||
# position (more importantly, a fixed position) of the stash
|
||||
# This allows us to keep track of it in an oblivious manner
|
||||
@lib.for_range_opt(self.t)
|
||||
def _(i):
|
||||
nonlocal found
|
||||
found_: B = index == self.stashi[i + 1]
|
||||
swap(self.stash, 0, i, found_)
|
||||
swap(self.stashi, 0, i, found_)
|
||||
found |= found_
|
||||
# found = self.bit_type(found.bit_or(found_))
|
||||
if multithreading:
|
||||
found_ = self.bit_type.Array(size=self.T)
|
||||
@lib.multithread(8, self.T)
|
||||
def _(base, size):
|
||||
found_.assign_vector(self.stashi.get_vector(base, size)[:] == index, base=base)
|
||||
@lib.for_range_opt(self.t - 1)
|
||||
def _(i):
|
||||
swap(self.stash, 0, i, found_[i])
|
||||
swap(self.stashi, 0, i, found_[i])
|
||||
found.write(sum(found_))
|
||||
else:
|
||||
@lib.for_range_opt(self.t - 1)
|
||||
def _(i):
|
||||
nonlocal found
|
||||
found_: B = index == self.stashi[i + 1]
|
||||
swap(self.stash, 0, i, found_)
|
||||
swap(self.stashi, 0, i, found_)
|
||||
found |= found_
|
||||
# If the item was not in the stash, we move the unknown and unimportant
|
||||
# stash[0] out of the way (to the end of the stash)
|
||||
swap(self.stash, self.t, 0, sintbit(found.bit_not().bit_and(self.t > 0)))
|
||||
@@ -156,7 +187,7 @@ class SqrtOram(Generic[T, B]):
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln(' Item not in stash')
|
||||
lib.print_ln(' Moved stash[0]=(%s: %s) to stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
|
||||
lib.print_ln(' Moved stash[0]=(%s: %s) to the back of the stash[t=%s]=(%s: %s)', self.stashi[0].reveal(), self.stash[0].reveal(), self.t, self.stashi[self.t].reveal(), self.stash[self.t].reveal())
|
||||
|
||||
# Possible fake lookup of the item in the shuffle,
|
||||
# depending on whether we already found the item in the stash
|
||||
@@ -171,14 +202,15 @@ class SqrtOram(Generic[T, B]):
|
||||
self.stashi[self.t] = found.if_else(
|
||||
self.shufflei[physical_address],
|
||||
self.stashi[self.t])
|
||||
# If the item was not found in the stash,
|
||||
# we place the item retrieved from the shuffle in stash[0]
|
||||
# If the item was not found in the stash, we place the item retrieved
|
||||
# from the shuffle (the item we are actually looking for) in stash[0]
|
||||
self.stash[0].assign(found.bit_not().if_else(
|
||||
self.shuffle[physical_address][:],
|
||||
self.stash[0][:]))
|
||||
self.stashi[0] = found.bit_not().if_else(
|
||||
self.shufflei[physical_address],
|
||||
self.stashi[0])
|
||||
|
||||
if debug:
|
||||
@lib.if_e(found.reveal() == 1)
|
||||
def _():
|
||||
@@ -191,9 +223,9 @@ class SqrtOram(Generic[T, B]):
|
||||
# Increase the "time" (i.e. access count in current period)
|
||||
self.t.iadd(1)
|
||||
|
||||
self.stash[0].assign(write.if_else(item, self.stash[0][:]))
|
||||
item=write.bit_not().if_else(self.stash[0][:], item)
|
||||
return item
|
||||
self.stash[0].assign(write.if_else(value, self.stash[0][:]))
|
||||
value=write.bit_not().if_else(self.stash[0][:], value)
|
||||
return value
|
||||
|
||||
|
||||
@lib.method_block
|
||||
@@ -206,12 +238,12 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
# Random permutation on n elements
|
||||
random_shuffle = sint.get_secure_shuffle(self.n)
|
||||
if debug: lib.print_ln('\tGenerated shuffle')
|
||||
# Apply the random permutation
|
||||
lib.print_ln('\tGenerated shuffle')
|
||||
self.shuffle.secure_permute(random_shuffle)
|
||||
lib.print_ln('\tShuffled shuffle')
|
||||
if debug: lib.print_ln('\tShuffled shuffle')
|
||||
self.shufflei.secure_permute(random_shuffle)
|
||||
lib.print_ln('\tShuffled shuffle indexes')
|
||||
if debug: lib.print_ln('\tShuffled shuffle indexes')
|
||||
# Calculate the permutation that would have produced the newly produced
|
||||
# shuffle order. This can be calculated by regarding the logical
|
||||
# indexes (shufflei) as a permutation and calculating its inverse,
|
||||
@@ -220,7 +252,7 @@ class SqrtOram(Generic[T, B]):
|
||||
# random_shuffle, as the shuffle may already be out of order (e.g. when
|
||||
# refreshing).
|
||||
permutation = MemValue(self.shufflei[:].inverse_permutation())
|
||||
lib.print_ln('\tCalculated inverse permutation')
|
||||
if debug: lib.print_ln('\tCalculated inverse permutation')
|
||||
return permutation
|
||||
|
||||
@lib.method_block
|
||||
@@ -229,7 +261,8 @@ class SqrtOram(Generic[T, B]):
|
||||
reshuffling the shuffle.
|
||||
|
||||
This must happen after T (period) accesses to the ORAM."""
|
||||
lib.print_ln('Refreshing SqrtORAM')
|
||||
|
||||
if debug: lib.print_ln('Refreshing SqrtORAM')
|
||||
|
||||
# Shuffle and emtpy the stash, and store elements back into shuffle
|
||||
j = MemValue(cint(0,size=1))
|
||||
@@ -276,7 +309,7 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
|
||||
class PositionMap(Generic[T, B]):
|
||||
PACK_LOG: int = 2
|
||||
PACK_LOG: int = 3
|
||||
PACK: int = 1 << PACK_LOG
|
||||
|
||||
n: int # n in the paper
|
||||
@@ -288,7 +321,7 @@ class PositionMap(Generic[T, B]):
|
||||
self.depth=MemValue(cint(k))
|
||||
self.value_type = value_type
|
||||
self.bit_type = value_type.bit_type
|
||||
self.index_type = self.value_type.get_type(int(np.ceil(np.log2(n))))
|
||||
self.index_type = self.value_type.get_type(util.log2(n))
|
||||
|
||||
@abstractmethod
|
||||
def get_position(self, logical_address: _secret, fake: B) -> Any:
|
||||
@@ -332,7 +365,7 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
pack = PositionMap.PACK
|
||||
|
||||
# We pack the permutation into a smaller structure, index with a new permutation
|
||||
packed_size = int(np.ceil(self.n / pack))
|
||||
packed_size = int(math.ceil(self.n / pack))
|
||||
packed_structure = MultiArray(
|
||||
(packed_size, pack), value_type=value_type)
|
||||
for i in range(packed_size):
|
||||
@@ -359,28 +392,33 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
p = MemValue(self.index_type(0))
|
||||
found: B = MemValue(self.bit_type(False))
|
||||
|
||||
# First we try and retrieve the item from the stash
|
||||
# First we try and retrieve the item from the stash at position stash[h][l]
|
||||
# Since h and l are secret, we do this by scanning the entire stash
|
||||
|
||||
# We retrieve stash[h]
|
||||
# Since h is secret, we do this by scanning the entire stash
|
||||
# First we scan the stash for the block we need
|
||||
condition1 = self.bit_type.Array(self.T)
|
||||
@lib.for_range_opt_multithread(8, self.T)
|
||||
def _(i):
|
||||
condition1[i] = (self.stashi[i] == h) & self.bit_type(i < self.t)
|
||||
found = sum(condition1)
|
||||
# Once a block is found, we use condition2 to pick the correct item from that block
|
||||
condition2 = Array.create_from(regint.inc(pack) == l.expand_to_vector(pack))
|
||||
# condition3 combines condition1 & condition2, only returning true at stash[h][l]
|
||||
condition3 = self.bit_type.Array(self.T * pack)
|
||||
@lib.for_range_opt_multithread(8, [self.T, pack])
|
||||
def _(i, j):
|
||||
condition3[i*pack + j] = condition1[i] & condition2[j]
|
||||
# Finally we use condition3 to conditionally write p
|
||||
@lib.for_range(self.t)
|
||||
def _(j):
|
||||
nonlocal found
|
||||
condition = self.stashi[j] == h
|
||||
found |= condition
|
||||
# block = stash[h]
|
||||
# block is itself an array (it holds a permutation)
|
||||
# we need to grab block[l]
|
||||
def _(i):
|
||||
@lib.for_range(pack)
|
||||
def _(i):
|
||||
nonlocal condition
|
||||
condition &= l == i
|
||||
p.write(condition.if_else(self.stash[j][i], p))
|
||||
def _(j):
|
||||
p.write(condition3[i*pack + j].if_else(self.stash[i][j], p))
|
||||
|
||||
if debug:
|
||||
@lib.if_(condition.reveal() == 1)
|
||||
@lib.if_(condition1[i].reveal() == 1)
|
||||
def _():
|
||||
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, j, self.stashi[j].reveal(), self.stash[j].reveal())
|
||||
lib.print_ln('\t%s Found position in stash[%s]=(%s: %s)', self.depth, i, self.stashi[i].reveal(), self.stash[i].reveal())
|
||||
|
||||
# Then we try and retrieve the item from the shuffle (the actual memory)
|
||||
|
||||
@@ -389,22 +427,22 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
def _():
|
||||
lib.print_ln('\t%s Position not in stash', self.depth)
|
||||
|
||||
|
||||
# Depending on whether we found the item in the stash, we either retrieve h or a random element from the shuffle
|
||||
p_prime = self.position_map.get_position(h, found)
|
||||
self.shuffle_used[p_prime] = cbit(True)
|
||||
|
||||
# The block retrieved from the shuffle
|
||||
# Depending on whether the block has already been `found`, this block
|
||||
# is either the desired block (found=False) or a random block
|
||||
# (found=True)
|
||||
block_p_prime: Array = self.shuffle[p_prime]
|
||||
|
||||
if debug:
|
||||
@lib.if_e(found.reveal() == 0)
|
||||
def _():
|
||||
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)', self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
lib.print_ln('\t%s Retrieved stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
@lib.else_
|
||||
def __():
|
||||
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
lib.print_ln('\t%s Retrieved dummy stash[%s]=(%s: %s)',
|
||||
self.depth, p_prime.reveal(), self.shufflei[p_prime.reveal()].reveal(), self.shuffle[p_prime.reveal()].reveal())
|
||||
|
||||
# We add the retrieved block from the shuffle to the stash
|
||||
self.stash[self.t].assign(block_p_prime[:])
|
||||
@@ -413,13 +451,13 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
self.t += 1
|
||||
|
||||
# if found or not fake
|
||||
condition = self.bit_type(fake.bit_or(found.bit_not()))
|
||||
condition: B = self.bit_type(fake.bit_or(found.bit_not()))
|
||||
# Retrieve l'th item from block
|
||||
# l is secret, so we must use linear scan
|
||||
hit = Array.create_from((regint.inc(pack) == l.expand_to_vector(pack)) & condition.expand_to_vector(pack))
|
||||
@lib.for_range_opt(pack)
|
||||
def _(i):
|
||||
hit: B = self.bit_type(i == l)
|
||||
p.write((condition & hit).if_else(block_p_prime[i], p))
|
||||
p.write((hit[i]).if_else(block_p_prime[i], p))
|
||||
|
||||
return p.reveal()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user