From 85f2d094a9d7490fd190337f807e99a9f4331be5 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 10 Apr 2025 16:39:06 +1000 Subject: [PATCH] Fix bugs in batch initialization. --- Compiler/oram.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/Compiler/oram.py b/Compiler/oram.py index 68842aa5..3e491225 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1253,8 +1253,6 @@ class TreeORAM(AbstractORAM): """ Batch initalization. Obliviously shuffles and adds N entries to random leaf buckets. """ m = len(values) - if not (m & (m-1)) == 0: - raise CompilerError('Batch size must a power of 2.') if m != self.size: raise CompilerError('Batch initialization must have N values.') if self.value_type != sint: @@ -1289,17 +1287,22 @@ class TreeORAM(AbstractORAM): self.value_type.hard_conv(False), value_type=self.value_type) # save unsorted leaves for position map - unsorted_leaves = Array.create_from(leaves) + unsorted_leaves = leaves + + # add all possible leaves to ensure appearance in B + leaves = self.value_type.Array(m + 2 ** self.D) + leaves[:] = unsorted_leaves + leaves.assign(regint.inc(2 ** self.D), base=m) leaves.sort() bucket_sz = 0 # B[i] = (pos, leaf, "last in bucket" flag) for i-th entry - B = sint.Matrix(m, 3) + B = sint.Matrix(len(leaves), 3) B[0] = [0, leaves[0], 0] - B[-1] = [None, None, sint(1)] + B[-1] = [0, 0, sint(1)] s = MemValue(sint(0)) - @for_range_opt(m - 1) + @for_range_opt(len(B) - 1) def _(j): i = j + 1 eq = leaves[i].equal(leaves[i-1]) @@ -1310,6 +1313,8 @@ class TreeORAM(AbstractORAM): #pos[i] = [s, leaves[i]] #last_in_bucket[i-1] = 1 - eq + # delete to avoid further usage + del leaves # shuffle B.secure_shuffle() #cint(0).print_reg('shuf') @@ -1319,7 +1324,7 @@ class TreeORAM(AbstractORAM): empty_positions = Array(nleaves, self.value_type) empty_leaves = Array(nleaves, self.value_type) - @for_range(m) + @for_range(len(B)) def _(i): if_then(reveal(B[i][2])) #if B[i][2] == 1: @@ -1329,7 +1334,8 @@ class TreeORAM(AbstractORAM): else: szval = sz.read() #szval.print_reg('sz') - empty_positions[szval] = B[i][0] #pos[i][0] + # subtract one to undo adding above + empty_positions[szval] = B[i][0] - 1 #pos[i][0] #empty_positions[szval].reveal().print_reg('ps0') empty_leaves[szval] = B[i][1] #pos[i][1] sz.iadd(1) @@ -1634,13 +1640,13 @@ class PackedIndexStructure(object): def batch_init(self, values): """ Initialize m values with indices 0, ..., m-1 """ m = len(values) - n_entries = max(1, m//self.entries_per_block) + n_entries = int(math.ceil(m / self.entries_per_block)) new_values = sint.Matrix(n_entries, self.elements_per_block) values = Array.create_from(values) @for_range(n_entries) def _(i): - block = [0] * self.elements_per_block + block = Array.create_from([sint(0)] * self.elements_per_block) for j in range(self.elements_per_block): base = i * self.entries_per_block + j * self.entries_per_element for k in range(self.entries_per_element):