Working on parallel matrix shuffle

This commit is contained in:
Vincent Ehrmanntraut
2024-12-12 08:56:46 +01:00
parent 3d3bfcf6a6
commit 0f4c825366
2 changed files with 84 additions and 24 deletions

View File

@@ -7106,7 +7106,7 @@ class SubMultiArray(_vectorizable):
"""
self.assign_vector(self.get_vector().secure_shuffle(self.part_size()))
def secure_permute(self, permutation, reverse=False, n_threads=None):
def secure_permute(self, permutation, reverse=False, n_parallel=None):
""" Securely permute rows (first index). See
:py:func:`secure_shuffle` for references.
@@ -7114,12 +7114,20 @@ class SubMultiArray(_vectorizable):
:param reverse: whether to apply inverse (default: False)
"""
if n_threads is not None:
permutation = MemValue(permutation)
@library.for_range_multithread(n_threads, 1, self.get_part_size())
def _(i):
self.set_column(i, self.get_column(i).secure_permute(
permutation, reverse=reverse))
if self.value_type == sint and False:
unit_size = self.get_part_size()
n = self.sizes[0] * unit_size
res = sint(size=n)
applyshuffle(n, res, self[:], unit_size, permutation, reverse)
self.assign_vector(res)
else:
if n_parallel is None:
n_parallel = self.get_part_size()
@library.for_range_parallel(n_parallel, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
def sort(self, key_indices=None, n_bits=None):
""" Sort sub-arrays (different first index) in place.

View File

@@ -1,4 +1,21 @@
def test_allocator():
arr1 = sint.Array(5)
arr2 = sint.Array(10)
arr3 = sint.Array(20)
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(10)
p3 = sint.get_secure_shuffle(20)
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
arr1.secure_permute(p1)
arr2[0] = arr1[0]
arr2.secure_permute(p2)
arr3.secure_permute(p3)
def test_case(permutation_sizes, timer_base: int | None = None):
if timer_base is not None:
start_timer(timer_base + 0)
@@ -100,25 +117,60 @@ def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], ti
crash()
def test_allocator():
arr1 = sint.Array(5)
arr2 = sint.Array(10)
arr3 = sint.Array(20)
def test_permute_matrix(timer_base: int, value_type=sint) -> None:
def test_permuted_matrix(m, p):
permuted_indices = Array.create_from([sint(i) for i in range(m.sizes[0])])
permuted_indices.secure_permute(p, reverse=True)
permuted_indices = permuted_indices.reveal()
@for_range(m.sizes[0])
def check_row(i):
@for_range(m.sizes[1])
def check_entry(j):
@if_(m[permuted_indices[i]][j] != (m.sizes[1] * i + j))
def fail():
print_ln("Matrix permuted unexpectedly.")
crash()
print_ln("Pre-create matrix")
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
print_ln("post-create matrix")
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(10)
p3 = sint.get_secure_shuffle(20)
p2 = sint.get_secure_shuffle(5)
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
arr1.secure_permute(p1)
arr2[0] = arr1[0]
arr2.secure_permute(p2)
arr3.secure_permute(p3)
start_timer(timer_base + 1)
m1.secure_permute(p1, n_parallel=1)
stop_timer(timer_base + 1)
start_timer(timer_base + 2)
m2.secure_permute(p2, n_parallel=1)
stop_timer(timer_base + 2)
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
m1 = m1.reveal()
m2 = m2.reveal()
print_ln("Permuted m1:")
if value_type == sint:
for row in m1:
print_ln("%s", row)
test_permuted_matrix(m1, p1)
print_ln("Permuted m2:")
if value_type == sint:
for row in m2:
print_ln("%s", row)
test_permuted_matrix(m2, p2)
# test_allocator()
test_case([5,10], 10)
test_case([5, 10, 15, 20], 20)
test_case([4,8,16], 30)
test_case([5], 40)
test_parallel_permutation_equals_sequential_permutation([5,10],50)
# test_case([5,10], 10)
# test_case([5, 10, 15, 20], 20)
# test_case([4,8,16], 30)
# test_case([5], 40)
#
# test_parallel_permutation_equals_sequential_permutation([5,10],50)
#
# test_permute_matrix(60)
# test_permute_matrix(70, value_type=sfix)