Fix applyshuffle with dead code elimination and more flexible Matrix secure_permute

This commit is contained in:
Vincent Ehrmanntraut
2024-12-20 08:32:09 +01:00
parent 0db8d7a419
commit 337ba94d06
3 changed files with 47 additions and 8 deletions

View File

@@ -2763,6 +2763,7 @@ class applyshuffle(shuffle_base, base.Mergeable):
__slots__ = []
code = base.opcodes['APPLYSHUFFLE']
arg_format = itertools.cycle(['int', 'sw','s','int','ci','int'])
is_vec = lambda self: True # Ensures dead-code elimination works.
def __init__(self, *args, **kwargs):
super(applyshuffle, self).__init__(*args, **kwargs)

View File

@@ -7274,7 +7274,7 @@ class SubMultiArray(_vectorizable):
self.secure_permute(perm)
delshuffle(perm)
def secure_permute(self, permutation, reverse=False, n_threads=None):
def secure_permute(self, permutation, reverse=False, n_threads=None, n_parallel=None):
""" Securely permute rows (first index). See
:py:func:`secure_shuffle` for references.
@@ -7282,7 +7282,8 @@ class SubMultiArray(_vectorizable):
:param reverse: whether to apply inverse (default: False)
"""
if self.value_type == sint:
if (self.value_type == sint) and (n_threads is None):
# Use only a single shuffle instruction if applicable and permutation is single-threaded anyway.
unit_size = self.get_part_size()
n = self.sizes[0] * unit_size
res = sint(size=n)
@@ -7291,11 +7292,19 @@ class SubMultiArray(_vectorizable):
else:
if n_threads is not None:
permutation = MemValue(permutation)
@library.for_range_multithread(n_threads, 1, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
if n_parallel is None:
@library.for_range_opt_multithread(n_threads, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
else:
@library.for_range_multithread(n_threads, n_parallel, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
def sort(self, key_indices=None, n_bits=None, batcher=False):
""" Sort sub-arrays (different first index) in place.

View File

@@ -1,4 +1,5 @@
from Compiler.library import get_number_of_players
from Compiler.sqrt_oram import n_parallel
from Compiler.util import if_else
@@ -138,10 +139,13 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
print_ln("post-create matrix")
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(5)
p3 = sint.get_secure_shuffle(5)
start_timer(timer_base + 1)
m1.secure_permute(p1)
@@ -152,11 +156,19 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
start_timer(timer_base + 3)
m3.secure_permute(p2, n_threads=3)
stop_timer(timer_base + 3)
start_timer(timer_base + 4)
m4.secure_permute(p2, n_threads=3, n_parallel=2)
stop_timer(timer_base + 4)
start_timer(timer_base + 5)
m5.secure_permute(p3, n_threads=3, n_parallel=3)
stop_timer(timer_base + 5)
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
m1 = m1.reveal()
m2 = m2.reveal()
m3 = m3.reveal()
m4 = m4.reveal()
m5 = m5.reveal()
print_ln("Permuted m1:")
for row in m1:
@@ -173,6 +185,16 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
print_ln("%s", row)
test_permuted_matrix(m3, p2)
print_ln("Permuted m4 (should be equal to m2):")
for row in m4:
print_ln("%s", row)
test_permuted_matrix(m4, p2)
print_ln("Permuted m5:")
for row in m5:
print_ln("%s", row)
test_permuted_matrix(m5, p3)
def test_secure_shuffle_still_works(size: int, timer_base: int):
arr = Array.create_from([sint(i) for i in range(size)])
start_timer(timer_base)
@@ -226,6 +248,11 @@ def test_inverse_permutation_still_works(size, timer_base: int):
def test_dead_code_elimination():
vector = sint([0,2,4,6,5,3,1])
handle = sint.get_secure_shuffle(7)
print_ln('%s', vector.secure_permute(handle).reveal())
test_allocator()
test_case([10, 15], 10)
@@ -239,4 +266,6 @@ test_permute_matrix(60)
test_permute_matrix(70, value_type=sfix)
test_secure_shuffle_still_works(32, 80)
test_inverse_permutation_still_works(8, 80)
test_inverse_permutation_still_works(8, 80)
test_dead_code_elimination()