mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
Fix bugs in batch initialization.
This commit is contained in:
@@ -1253,8 +1253,6 @@ class TreeORAM(AbstractORAM):
|
||||
""" Batch initalization. Obliviously shuffles and adds N entries to
|
||||
random leaf buckets. """
|
||||
m = len(values)
|
||||
if not (m & (m-1)) == 0:
|
||||
raise CompilerError('Batch size must a power of 2.')
|
||||
if m != self.size:
|
||||
raise CompilerError('Batch initialization must have N values.')
|
||||
if self.value_type != sint:
|
||||
@@ -1289,17 +1287,22 @@ class TreeORAM(AbstractORAM):
|
||||
self.value_type.hard_conv(False), value_type=self.value_type)
|
||||
|
||||
# save unsorted leaves for position map
|
||||
unsorted_leaves = Array.create_from(leaves)
|
||||
unsorted_leaves = leaves
|
||||
|
||||
# add all possible leaves to ensure appearance in B
|
||||
leaves = self.value_type.Array(m + 2 ** self.D)
|
||||
leaves[:] = unsorted_leaves
|
||||
leaves.assign(regint.inc(2 ** self.D), base=m)
|
||||
leaves.sort()
|
||||
|
||||
bucket_sz = 0
|
||||
# B[i] = (pos, leaf, "last in bucket" flag) for i-th entry
|
||||
B = sint.Matrix(m, 3)
|
||||
B = sint.Matrix(len(leaves), 3)
|
||||
B[0] = [0, leaves[0], 0]
|
||||
B[-1] = [None, None, sint(1)]
|
||||
B[-1] = [0, 0, sint(1)]
|
||||
s = MemValue(sint(0))
|
||||
|
||||
@for_range_opt(m - 1)
|
||||
@for_range_opt(len(B) - 1)
|
||||
def _(j):
|
||||
i = j + 1
|
||||
eq = leaves[i].equal(leaves[i-1])
|
||||
@@ -1310,6 +1313,8 @@ class TreeORAM(AbstractORAM):
|
||||
#pos[i] = [s, leaves[i]]
|
||||
#last_in_bucket[i-1] = 1 - eq
|
||||
|
||||
# delete to avoid further usage
|
||||
del leaves
|
||||
# shuffle
|
||||
B.secure_shuffle()
|
||||
#cint(0).print_reg('shuf')
|
||||
@@ -1319,7 +1324,7 @@ class TreeORAM(AbstractORAM):
|
||||
empty_positions = Array(nleaves, self.value_type)
|
||||
empty_leaves = Array(nleaves, self.value_type)
|
||||
|
||||
@for_range(m)
|
||||
@for_range(len(B))
|
||||
def _(i):
|
||||
if_then(reveal(B[i][2]))
|
||||
#if B[i][2] == 1:
|
||||
@@ -1329,7 +1334,8 @@ class TreeORAM(AbstractORAM):
|
||||
else:
|
||||
szval = sz.read()
|
||||
#szval.print_reg('sz')
|
||||
empty_positions[szval] = B[i][0] #pos[i][0]
|
||||
# subtract one to undo adding above
|
||||
empty_positions[szval] = B[i][0] - 1 #pos[i][0]
|
||||
#empty_positions[szval].reveal().print_reg('ps0')
|
||||
empty_leaves[szval] = B[i][1] #pos[i][1]
|
||||
sz.iadd(1)
|
||||
@@ -1634,13 +1640,13 @@ class PackedIndexStructure(object):
|
||||
def batch_init(self, values):
|
||||
""" Initialize m values with indices 0, ..., m-1 """
|
||||
m = len(values)
|
||||
n_entries = max(1, m//self.entries_per_block)
|
||||
n_entries = int(math.ceil(m / self.entries_per_block))
|
||||
new_values = sint.Matrix(n_entries, self.elements_per_block)
|
||||
values = Array.create_from(values)
|
||||
|
||||
@for_range(n_entries)
|
||||
def _(i):
|
||||
block = [0] * self.elements_per_block
|
||||
block = Array.create_from([sint(0)] * self.elements_per_block)
|
||||
for j in range(self.elements_per_block):
|
||||
base = i * self.entries_per_block + j * self.entries_per_element
|
||||
for k in range(self.entries_per_element):
|
||||
|
||||
Reference in New Issue
Block a user