mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Working on parallel matrix shuffle
This commit is contained in:
@@ -7106,7 +7106,7 @@ class SubMultiArray(_vectorizable):
|
||||
"""
|
||||
self.assign_vector(self.get_vector().secure_shuffle(self.part_size()))
|
||||
|
||||
def secure_permute(self, permutation, reverse=False, n_threads=None):
|
||||
def secure_permute(self, permutation, reverse=False, n_parallel=None):
|
||||
""" Securely permute rows (first index). See
|
||||
:py:func:`secure_shuffle` for references.
|
||||
|
||||
@@ -7114,12 +7114,20 @@ class SubMultiArray(_vectorizable):
|
||||
:param reverse: whether to apply inverse (default: False)
|
||||
|
||||
"""
|
||||
if n_threads is not None:
|
||||
permutation = MemValue(permutation)
|
||||
@library.for_range_multithread(n_threads, 1, self.get_part_size())
|
||||
def _(i):
|
||||
self.set_column(i, self.get_column(i).secure_permute(
|
||||
permutation, reverse=reverse))
|
||||
if self.value_type == sint and False:
|
||||
unit_size = self.get_part_size()
|
||||
n = self.sizes[0] * unit_size
|
||||
res = sint(size=n)
|
||||
applyshuffle(n, res, self[:], unit_size, permutation, reverse)
|
||||
self.assign_vector(res)
|
||||
else:
|
||||
if n_parallel is None:
|
||||
n_parallel = self.get_part_size()
|
||||
@library.for_range_parallel(n_parallel, self.get_part_size())
|
||||
def iter(i):
|
||||
column = self.get_column(i)
|
||||
column = column.secure_permute(permutation, reverse=reverse)
|
||||
self.set_column(i, column)
|
||||
|
||||
def sort(self, key_indices=None, n_bits=None):
|
||||
""" Sort sub-arrays (different first index) in place.
|
||||
|
||||
@@ -1,4 +1,21 @@
|
||||
|
||||
|
||||
def test_allocator():
|
||||
arr1 = sint.Array(5)
|
||||
arr2 = sint.Array(10)
|
||||
arr3 = sint.Array(20)
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
p2 = sint.get_secure_shuffle(10)
|
||||
p3 = sint.get_secure_shuffle(20)
|
||||
|
||||
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
|
||||
arr1.secure_permute(p1)
|
||||
arr2[0] = arr1[0]
|
||||
arr2.secure_permute(p2)
|
||||
arr3.secure_permute(p3)
|
||||
|
||||
|
||||
def test_case(permutation_sizes, timer_base: int | None = None):
|
||||
if timer_base is not None:
|
||||
start_timer(timer_base + 0)
|
||||
@@ -100,25 +117,60 @@ def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], ti
|
||||
crash()
|
||||
|
||||
|
||||
def test_allocator():
|
||||
arr1 = sint.Array(5)
|
||||
arr2 = sint.Array(10)
|
||||
arr3 = sint.Array(20)
|
||||
def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
def test_permuted_matrix(m, p):
|
||||
permuted_indices = Array.create_from([sint(i) for i in range(m.sizes[0])])
|
||||
permuted_indices.secure_permute(p, reverse=True)
|
||||
permuted_indices = permuted_indices.reveal()
|
||||
|
||||
@for_range(m.sizes[0])
|
||||
def check_row(i):
|
||||
@for_range(m.sizes[1])
|
||||
def check_entry(j):
|
||||
@if_(m[permuted_indices[i]][j] != (m.sizes[1] * i + j))
|
||||
def fail():
|
||||
print_ln("Matrix permuted unexpectedly.")
|
||||
crash()
|
||||
|
||||
print_ln("Pre-create matrix")
|
||||
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
|
||||
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
print_ln("post-create matrix")
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
p2 = sint.get_secure_shuffle(10)
|
||||
p3 = sint.get_secure_shuffle(20)
|
||||
p2 = sint.get_secure_shuffle(5)
|
||||
|
||||
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
|
||||
arr1.secure_permute(p1)
|
||||
arr2[0] = arr1[0]
|
||||
arr2.secure_permute(p2)
|
||||
arr3.secure_permute(p3)
|
||||
|
||||
start_timer(timer_base + 1)
|
||||
m1.secure_permute(p1, n_parallel=1)
|
||||
stop_timer(timer_base + 1)
|
||||
start_timer(timer_base + 2)
|
||||
m2.secure_permute(p2, n_parallel=1)
|
||||
stop_timer(timer_base + 2)
|
||||
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
|
||||
|
||||
m1 = m1.reveal()
|
||||
m2 = m2.reveal()
|
||||
|
||||
print_ln("Permuted m1:")
|
||||
if value_type == sint:
|
||||
for row in m1:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m1, p1)
|
||||
|
||||
print_ln("Permuted m2:")
|
||||
if value_type == sint:
|
||||
for row in m2:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m2, p2)
|
||||
|
||||
# test_allocator()
|
||||
test_case([5,10], 10)
|
||||
test_case([5, 10, 15, 20], 20)
|
||||
test_case([4,8,16], 30)
|
||||
test_case([5], 40)
|
||||
|
||||
test_parallel_permutation_equals_sequential_permutation([5,10],50)
|
||||
# test_case([5,10], 10)
|
||||
# test_case([5, 10, 15, 20], 20)
|
||||
# test_case([4,8,16], 30)
|
||||
# test_case([5], 40)
|
||||
#
|
||||
# test_parallel_permutation_equals_sequential_permutation([5,10],50)
|
||||
#
|
||||
# test_permute_matrix(60)
|
||||
# test_permute_matrix(70, value_type=sfix)
|
||||
Reference in New Issue
Block a user