diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 5a5e24dd..5de1174f 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -22,9 +22,7 @@ from Compiler.types import ( program = Program.prog debug = False -n_parallel = 1024 n_threads = 8 - multithreading = True def swap(array: Array | MultiArray, pos_a: int | cint, pos_b: int | cint, cond: sintbit | sbit): @@ -486,26 +484,47 @@ class LinearPositionMap(PositionMap): This method corresponds to GetPosBase in the paper. """ super().get_position(logical_address, fake) - fake = self.bit_type(fake) - - # In order to get an address at secret logical_address, - # we need to perform a linear scan. - linear_scan = self.bit_type.Array(self.n) - @lib.for_range_opt(self.n) - def _(i): - linear_scan[i] = logical_address == i + fake = MemValue(self.bit_type(fake)) + logical_address = MemValue(logical_address) p: MemValue = MemValue(self.index_type(-1)) done: B = self.bit_type(False) - @lib.for_range_opt(self.n) - def _(j): - nonlocal done, fake - condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ - .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) - p.write(condition.if_else(self.physical[j], p)) - self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) - done = self.bit_type(condition.if_else(self.bit_type(True), done)) + if multithreading: + conditions:Array = self.bit_type.Array(self.n) + conditions.assign_all(0) + + @lib.for_range_opt_multithread(8, self.n) + def condition_i(i): + conditions.assign((self.bit_type(fake).bit_not() & self.bit_type(logical_address == i)) | (fake & self.used[i].bit_not()), base=i) + + @lib.for_range_opt(self.n) + def _(i): + nonlocal done + conditions[i] &= done.bit_not() + done |= conditions[i] + @lib.map_sum_opt(8, self.n, [self.value_type]) + def calc_p(i): + return self.physical[i] * conditions[i] + p.write(calc_p()) + + self.used.assign(self.used[:] | conditions[:]) + else: + # In order to get an address at secret logical_address, + # we need to perform a linear scan. + linear_scan = self.bit_type.Array(self.n) + @lib.for_range_opt(self.n) + def _(i): + linear_scan[i] = logical_address == i + + @lib.for_range_opt(self.n) + def __(j): + nonlocal done, fake + condition: B = (self.bit_type(fake.bit_not()) & linear_scan[j]) \ + .bit_or(fake & self.bit_type((self.used[j]).bit_not()) & done.bit_not()) + p.write(condition.if_else(self.physical[j], p)) + self.used[j] = condition.if_else(self.bit_type(True), self.used[j]) + done = self.bit_type(condition.if_else(self.bit_type(True), done)) if debug: @lib.if_((p.reveal() < 0).bit_or(p.reveal() > len(self.physical)))