Add multithreading to LinearPositionMap in SqrtORAM

This commit is contained in:
Kevin Witlox
2022-07-29 17:35:48 +02:00
parent 8af345a713
commit 33299e78a5

View File

@@ -22,9 +22,7 @@ from Compiler.types import (
program = Program.prog
debug = False
n_parallel = 1024
n_threads = 8
multithreading = True
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
@@ -486,26 +484,47 @@ class LinearPositionMap(PositionMap):
This method corresponds to GetPosBase in the paper.
"""
super().get_position(logical_address, fake)
fake = self.bit_type(fake)
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
linear_scan = self.bit_type.Array(self.n)
@lib.for_range_opt(self.n)
def _(i):
linear_scan[i] = logical_address == i
fake = MemValue(self.bit_type(fake))
logical_address = MemValue(logical_address)
p: MemValue = MemValue(self.index_type(-1))
done: B = self.bit_type(False)
@lib.for_range_opt(self.n)
def _(j):
nonlocal done, fake
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
p.write(condition.if_else(self.physical[j], p))
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
done = self.bit_type(condition.if_else(self.bit_type(True), done))
if multithreading:
conditions:Array = self.bit_type.Array(self.n)
conditions.assign_all(0)
@lib.for_range_opt_multithread(8, self.n)
def condition_i(i):
conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i)
@lib.for_range_opt(self.n)
def _(i):
nonlocal done
conditions[i] &= done.bit_not()
done |= conditions[i]
@lib.map_sum_opt(8, self.n, [self.value_type])
def calc_p(i):
return self.physical[i] * conditions[i]
p.write(calc_p())
self.used.assign(self.used[:] | conditions[:])
else:
# In order to get an address at secret logical_address,
# we need to perform a linear scan.
linear_scan = self.bit_type.Array(self.n)
@lib.for_range_opt(self.n)
def _(i):
linear_scan[i] = logical_address == i
@lib.for_range_opt(self.n)
def __(j):
nonlocal done, fake
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
p.write(condition.if_else(self.physical[j], p))
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
done = self.bit_type(condition.if_else(self.bit_type(True), done))
if debug:
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))