Fix bugs in batch initialization.

This commit is contained in:
Marcel Keller
2025-04-10 16:39:06 +10:00
parent 4675cba4e7
commit 85f2d094a9

View File

@@ -1253,8 +1253,6 @@ class TreeORAM(AbstractORAM):
""" Batch initalization. Obliviously shuffles and adds N entries to
random leaf buckets. """
m = len(values)
if not (m & (m-1)) == 0:
raise CompilerError('Batch size must a power of 2.')
if m != self.size:
raise CompilerError('Batch initialization must have N values.')
if self.value_type != sint:
@@ -1289,17 +1287,22 @@ class TreeORAM(AbstractORAM):
self.value_type.hard_conv(False), value_type=self.value_type)
# save unsorted leaves for position map
unsorted_leaves = Array.create_from(leaves)
unsorted_leaves = leaves
# add all possible leaves to ensure appearance in B
leaves = self.value_type.Array(m + 2 ** self.D)
leaves[:] = unsorted_leaves
leaves.assign(regint.inc(2 ** self.D), base=m)
leaves.sort()
bucket_sz = 0
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
B = sint.Matrix(m, 3)
B = sint.Matrix(len(leaves), 3)
B[0] = [0, leaves[0], 0]
B[-1] = [None, None, sint(1)]
B[-1] = [0, 0, sint(1)]
s = MemValue(sint(0))
@for_range_opt(m - 1)
@for_range_opt(len(B) - 1)
def _(j):
i = j + 1
eq = leaves[i].equal(leaves[i-1])
@@ -1310,6 +1313,8 @@ class TreeORAM(AbstractORAM):
#pos[i] = [s, leaves[i]]
#last_in_bucket[i-1] = 1 - eq
# delete to avoid further usage
del leaves
# shuffle
B.secure_shuffle()
#cint(0).print_reg('shuf')
@@ -1319,7 +1324,7 @@ class TreeORAM(AbstractORAM):
empty_positions = Array(nleaves, self.value_type)
empty_leaves = Array(nleaves, self.value_type)
@for_range(m)
@for_range(len(B))
def _(i):
if_then(reveal(B[i][2]))
#if B[i][2] == 1:
@@ -1329,7 +1334,8 @@ class TreeORAM(AbstractORAM):
else:
szval = sz.read()
#szval.print_reg('sz')
empty_positions[szval] = B[i][0] #pos[i][0]
# subtract one to undo adding above
empty_positions[szval] = B[i][0] - 1 #pos[i][0]
#empty_positions[szval].reveal().print_reg('ps0')
empty_leaves[szval] = B[i][1] #pos[i][1]
sz.iadd(1)
@@ -1634,13 +1640,13 @@ class PackedIndexStructure(object):
def batch_init(self, values):
""" Initialize m values with indices 0, ..., m-1 """
m = len(values)
n_entries = max(1, m//self.entries_per_block)
n_entries = int(math.ceil(m / self.entries_per_block))
new_values = sint.Matrix(n_entries, self.elements_per_block)
values = Array.create_from(values)
@for_range(n_entries)
def _(i):
block = [0] * self.elements_per_block
block = Array.create_from([sint(0)] * self.elements_per_block)
for j in range(self.elements_per_block):
base = i * self.entries_per_block + j * self.entries_per_element
for k in range(self.entries_per_element):