From 337ba94d069437e138e0f70f6261cc9948354f18 Mon Sep 17 00:00:00 2001 From: Vincent Ehrmanntraut Date: Fri, 20 Dec 2024 08:32:09 +0100 Subject: [PATCH] Fix applyshuffle with dead code elimination and more flexible Matrix secure_permute --- Compiler/instructions.py | 1 + Compiler/types.py | 23 ++++++++++++++++------- Programs/Source/test_permute.mpc | 31 ++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index fb4ca710..bc3f1d48 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2763,6 +2763,7 @@ class applyshuffle(shuffle_base, base.Mergeable): __slots__ = [] code = base.opcodes['APPLYSHUFFLE'] arg_format = itertools.cycle(['int', 'sw','s','int','ci','int']) + is_vec = lambda self: True # Ensures dead-code elimination works. def __init__(self, *args, **kwargs): super(applyshuffle, self).__init__(*args, **kwargs) diff --git a/Compiler/types.py b/Compiler/types.py index 23d75951..b7ce0a50 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -7274,7 +7274,7 @@ class SubMultiArray(_vectorizable): self.secure_permute(perm) delshuffle(perm) - def secure_permute(self, permutation, reverse=False, n_threads=None): + def secure_permute(self, permutation, reverse=False, n_threads=None, n_parallel=None): """ Securely permute rows (first index). See :py:func:`secure_shuffle` for references. @@ -7282,7 +7282,8 @@ class SubMultiArray(_vectorizable): :param reverse: whether to apply inverse (default: False) """ - if self.value_type == sint: + if (self.value_type == sint) and (n_threads is None): + # Use only a single shuffle instruction if applicable and permutation is single-threaded anyway. unit_size = self.get_part_size() n = self.sizes[0] * unit_size res = sint(size=n) @@ -7291,11 +7292,19 @@ class SubMultiArray(_vectorizable): else: if n_threads is not None: permutation = MemValue(permutation) - @library.for_range_multithread(n_threads, 1, self.get_part_size()) - def iter(i): - column = self.get_column(i) - column = column.secure_permute(permutation, reverse=reverse) - self.set_column(i, column) + + if n_parallel is None: + @library.for_range_opt_multithread(n_threads, self.get_part_size()) + def iter(i): + column = self.get_column(i) + column = column.secure_permute(permutation, reverse=reverse) + self.set_column(i, column) + else: + @library.for_range_multithread(n_threads, n_parallel, self.get_part_size()) + def iter(i): + column = self.get_column(i) + column = column.secure_permute(permutation, reverse=reverse) + self.set_column(i, column) def sort(self, key_indices=None, n_bits=None, batcher=False): """ Sort sub-arrays (different first index) in place. diff --git a/Programs/Source/test_permute.mpc b/Programs/Source/test_permute.mpc index d3334ba5..c7778585 100644 --- a/Programs/Source/test_permute.mpc +++ b/Programs/Source/test_permute.mpc @@ -1,4 +1,5 @@ from Compiler.library import get_number_of_players +from Compiler.sqrt_oram import n_parallel from Compiler.util import if_else @@ -138,10 +139,13 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None: m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)]) m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) + m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) + m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)]) print_ln("post-create matrix") p1 = sint.get_secure_shuffle(5) p2 = sint.get_secure_shuffle(5) + p3 = sint.get_secure_shuffle(5) start_timer(timer_base + 1) m1.secure_permute(p1) @@ -152,11 +156,19 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None: start_timer(timer_base + 3) m3.secure_permute(p2, n_threads=3) stop_timer(timer_base + 3) + start_timer(timer_base + 4) + m4.secure_permute(p2, n_threads=3, n_parallel=2) + stop_timer(timer_base + 4) + start_timer(timer_base + 5) + m5.secure_permute(p3, n_threads=3, n_parallel=3) + stop_timer(timer_base + 5) print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.") m1 = m1.reveal() m2 = m2.reveal() m3 = m3.reveal() + m4 = m4.reveal() + m5 = m5.reveal() print_ln("Permuted m1:") for row in m1: @@ -173,6 +185,16 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None: print_ln("%s", row) test_permuted_matrix(m3, p2) + print_ln("Permuted m4 (should be equal to m2):") + for row in m4: + print_ln("%s", row) + test_permuted_matrix(m4, p2) + + print_ln("Permuted m5:") + for row in m5: + print_ln("%s", row) + test_permuted_matrix(m5, p3) + def test_secure_shuffle_still_works(size: int, timer_base: int): arr = Array.create_from([sint(i) for i in range(size)]) start_timer(timer_base) @@ -226,6 +248,11 @@ def test_inverse_permutation_still_works(size, timer_base: int): +def test_dead_code_elimination(): + vector = sint([0,2,4,6,5,3,1]) + handle = sint.get_secure_shuffle(7) + print_ln('%s', vector.secure_permute(handle).reveal()) + test_allocator() test_case([10, 15], 10) @@ -239,4 +266,6 @@ test_permute_matrix(60) test_permute_matrix(70, value_type=sfix) test_secure_shuffle_still_works(32, 80) -test_inverse_permutation_still_works(8, 80) \ No newline at end of file +test_inverse_permutation_still_works(8, 80) + +test_dead_code_elimination() \ No newline at end of file