mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Fix applyshuffle with dead code elimination and more flexible Matrix secure_permute
This commit is contained in:
@@ -2763,6 +2763,7 @@ class applyshuffle(shuffle_base, base.Mergeable):
|
||||
__slots__ = []
|
||||
code = base.opcodes['APPLYSHUFFLE']
|
||||
arg_format = itertools.cycle(['int', 'sw','s','int','ci','int'])
|
||||
is_vec = lambda self: True # Ensures dead-code elimination works.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(applyshuffle, self).__init__(*args, **kwargs)
|
||||
|
||||
@@ -7274,7 +7274,7 @@ class SubMultiArray(_vectorizable):
|
||||
self.secure_permute(perm)
|
||||
delshuffle(perm)
|
||||
|
||||
def secure_permute(self, permutation, reverse=False, n_threads=None):
|
||||
def secure_permute(self, permutation, reverse=False, n_threads=None, n_parallel=None):
|
||||
""" Securely permute rows (first index). See
|
||||
:py:func:`secure_shuffle` for references.
|
||||
|
||||
@@ -7282,7 +7282,8 @@ class SubMultiArray(_vectorizable):
|
||||
:param reverse: whether to apply inverse (default: False)
|
||||
|
||||
"""
|
||||
if self.value_type == sint:
|
||||
if (self.value_type == sint) and (n_threads is None):
|
||||
# Use only a single shuffle instruction if applicable and permutation is single-threaded anyway.
|
||||
unit_size = self.get_part_size()
|
||||
n = self.sizes[0] * unit_size
|
||||
res = sint(size=n)
|
||||
@@ -7291,11 +7292,19 @@ class SubMultiArray(_vectorizable):
|
||||
else:
|
||||
if n_threads is not None:
|
||||
permutation = MemValue(permutation)
|
||||
@library.for_range_multithread(n_threads, 1, self.get_part_size())
|
||||
def iter(i):
|
||||
column = self.get_column(i)
|
||||
column = column.secure_permute(permutation, reverse=reverse)
|
||||
self.set_column(i, column)
|
||||
|
||||
if n_parallel is None:
|
||||
@library.for_range_opt_multithread(n_threads, self.get_part_size())
|
||||
def iter(i):
|
||||
column = self.get_column(i)
|
||||
column = column.secure_permute(permutation, reverse=reverse)
|
||||
self.set_column(i, column)
|
||||
else:
|
||||
@library.for_range_multithread(n_threads, n_parallel, self.get_part_size())
|
||||
def iter(i):
|
||||
column = self.get_column(i)
|
||||
column = column.secure_permute(permutation, reverse=reverse)
|
||||
self.set_column(i, column)
|
||||
|
||||
def sort(self, key_indices=None, n_bits=None, batcher=False):
|
||||
""" Sort sub-arrays (different first index) in place.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from Compiler.library import get_number_of_players
|
||||
from Compiler.sqrt_oram import n_parallel
|
||||
from Compiler.util import if_else
|
||||
|
||||
|
||||
@@ -138,10 +139,13 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
|
||||
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
|
||||
print_ln("post-create matrix")
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
p2 = sint.get_secure_shuffle(5)
|
||||
p3 = sint.get_secure_shuffle(5)
|
||||
|
||||
start_timer(timer_base + 1)
|
||||
m1.secure_permute(p1)
|
||||
@@ -152,11 +156,19 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
start_timer(timer_base + 3)
|
||||
m3.secure_permute(p2, n_threads=3)
|
||||
stop_timer(timer_base + 3)
|
||||
start_timer(timer_base + 4)
|
||||
m4.secure_permute(p2, n_threads=3, n_parallel=2)
|
||||
stop_timer(timer_base + 4)
|
||||
start_timer(timer_base + 5)
|
||||
m5.secure_permute(p3, n_threads=3, n_parallel=3)
|
||||
stop_timer(timer_base + 5)
|
||||
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
|
||||
|
||||
m1 = m1.reveal()
|
||||
m2 = m2.reveal()
|
||||
m3 = m3.reveal()
|
||||
m4 = m4.reveal()
|
||||
m5 = m5.reveal()
|
||||
|
||||
print_ln("Permuted m1:")
|
||||
for row in m1:
|
||||
@@ -173,6 +185,16 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m3, p2)
|
||||
|
||||
print_ln("Permuted m4 (should be equal to m2):")
|
||||
for row in m4:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m4, p2)
|
||||
|
||||
print_ln("Permuted m5:")
|
||||
for row in m5:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m5, p3)
|
||||
|
||||
def test_secure_shuffle_still_works(size: int, timer_base: int):
|
||||
arr = Array.create_from([sint(i) for i in range(size)])
|
||||
start_timer(timer_base)
|
||||
@@ -226,6 +248,11 @@ def test_inverse_permutation_still_works(size, timer_base: int):
|
||||
|
||||
|
||||
|
||||
def test_dead_code_elimination():
|
||||
vector = sint([0,2,4,6,5,3,1])
|
||||
handle = sint.get_secure_shuffle(7)
|
||||
print_ln('%s', vector.secure_permute(handle).reveal())
|
||||
|
||||
|
||||
test_allocator()
|
||||
test_case([10, 15], 10)
|
||||
@@ -239,4 +266,6 @@ test_permute_matrix(60)
|
||||
test_permute_matrix(70, value_type=sfix)
|
||||
|
||||
test_secure_shuffle_still_works(32, 80)
|
||||
test_inverse_permutation_still_works(8, 80)
|
||||
test_inverse_permutation_still_works(8, 80)
|
||||
|
||||
test_dead_code_elimination()
|
||||
Reference in New Issue
Block a user