mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
Add multithreading to LinearPositionMap in SqrtORAM
This commit is contained in:
@@ -22,9 +22,7 @@ from Compiler.types import (
|
||||
program = Program.prog
|
||||
|
||||
debug = False
|
||||
n_parallel = 1024
|
||||
n_threads = 8
|
||||
|
||||
multithreading = True
|
||||
|
||||
def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit):
|
||||
@@ -486,26 +484,47 @@ class LinearPositionMap(PositionMap):
|
||||
This method corresponds to GetPosBase in the paper.
|
||||
"""
|
||||
super().get_position(logical_address, fake)
|
||||
fake = self.bit_type(fake)
|
||||
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
linear_scan = self.bit_type.Array(self.n)
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
linear_scan[i] = logical_address == i
|
||||
fake = MemValue(self.bit_type(fake))
|
||||
logical_address = MemValue(logical_address)
|
||||
|
||||
p: MemValue = MemValue(self.index_type(-1))
|
||||
done: B = self.bit_type(False)
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(j):
|
||||
nonlocal done, fake
|
||||
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
|
||||
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
|
||||
p.write(condition.if_else(self.physical[j], p))
|
||||
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
|
||||
done = self.bit_type(condition.if_else(self.bit_type(True), done))
|
||||
if multithreading:
|
||||
conditions:Array = self.bit_type.Array(self.n)
|
||||
conditions.assign_all(0)
|
||||
|
||||
@lib.for_range_opt_multithread(8, self.n)
|
||||
def condition_i(i):
|
||||
conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i)
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
nonlocal done
|
||||
conditions[i] &= done.bit_not()
|
||||
done |= conditions[i]
|
||||
@lib.map_sum_opt(8, self.n, [self.value_type])
|
||||
def calc_p(i):
|
||||
return self.physical[i] * conditions[i]
|
||||
p.write(calc_p())
|
||||
|
||||
self.used.assign(self.used[:] | conditions[:])
|
||||
else:
|
||||
# In order to get an address at secret logical_address,
|
||||
# we need to perform a linear scan.
|
||||
linear_scan = self.bit_type.Array(self.n)
|
||||
@lib.for_range_opt(self.n)
|
||||
def _(i):
|
||||
linear_scan[i] = logical_address == i
|
||||
|
||||
@lib.for_range_opt(self.n)
|
||||
def __(j):
|
||||
nonlocal done, fake
|
||||
condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \
|
||||
.bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not())
|
||||
p.write(condition.if_else(self.physical[j], p))
|
||||
self.used[j] = condition.if_else(self.bit_type(True), self.used[j])
|
||||
done = self.bit_type(condition.if_else(self.bit_type(True), done))
|
||||
|
||||
if debug:
|
||||
@lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))
|
||||
|
||||
Reference in New Issue
Block a user