Fix applyshuffle with dead code elimination and more flexible Matrix secure_permute

This commit is contained in:
Vincent Ehrmanntraut
2024-12-20 08:32:09 +01:00
parent 0db8d7a419
commit 337ba94d06
3 changed files with 47 additions and 8 deletions

View File

@@ -1,4 +1,5 @@
from Compiler.library import get_number_of_players
from Compiler.sqrt_oram import n_parallel
from Compiler.util import if_else
@@ -138,10 +139,13 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
print_ln("post-create matrix")
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(5)
p3 = sint.get_secure_shuffle(5)
start_timer(timer_base + 1)
m1.secure_permute(p1)
@@ -152,11 +156,19 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
start_timer(timer_base + 3)
m3.secure_permute(p2, n_threads=3)
stop_timer(timer_base + 3)
start_timer(timer_base + 4)
m4.secure_permute(p2, n_threads=3, n_parallel=2)
stop_timer(timer_base + 4)
start_timer(timer_base + 5)
m5.secure_permute(p3, n_threads=3, n_parallel=3)
stop_timer(timer_base + 5)
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
m1 = m1.reveal()
m2 = m2.reveal()
m3 = m3.reveal()
m4 = m4.reveal()
m5 = m5.reveal()
print_ln("Permuted m1:")
for row in m1:
@@ -173,6 +185,16 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
print_ln("%s", row)
test_permuted_matrix(m3, p2)
print_ln("Permuted m4 (should be equal to m2):")
for row in m4:
print_ln("%s", row)
test_permuted_matrix(m4, p2)
print_ln("Permuted m5:")
for row in m5:
print_ln("%s", row)
test_permuted_matrix(m5, p3)
def test_secure_shuffle_still_works(size: int, timer_base: int):
arr = Array.create_from([sint(i) for i in range(size)])
start_timer(timer_base)
@@ -226,6 +248,11 @@ def test_inverse_permutation_still_works(size, timer_base: int):
def test_dead_code_elimination():
vector = sint([0,2,4,6,5,3,1])
handle = sint.get_secure_shuffle(7)
print_ln('%s', vector.secure_permute(handle).reveal())
test_allocator()
test_case([10, 15], 10)
@@ -239,4 +266,6 @@ test_permute_matrix(60)
test_permute_matrix(70, value_type=sfix)
test_secure_shuffle_still_works(32, 80)
test_inverse_permutation_still_works(8, 80)
test_inverse_permutation_still_works(8, 80)
test_dead_code_elimination()