mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Fix applyshuffle with dead code elimination and more flexible Matrix secure_permute
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from Compiler.library import get_number_of_players
|
||||
from Compiler.sqrt_oram import n_parallel
|
||||
from Compiler.util import if_else
|
||||
|
||||
|
||||
@@ -138,10 +139,13 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
|
||||
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
|
||||
print_ln("post-create matrix")
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
p2 = sint.get_secure_shuffle(5)
|
||||
p3 = sint.get_secure_shuffle(5)
|
||||
|
||||
start_timer(timer_base + 1)
|
||||
m1.secure_permute(p1)
|
||||
@@ -152,11 +156,19 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
start_timer(timer_base + 3)
|
||||
m3.secure_permute(p2, n_threads=3)
|
||||
stop_timer(timer_base + 3)
|
||||
start_timer(timer_base + 4)
|
||||
m4.secure_permute(p2, n_threads=3, n_parallel=2)
|
||||
stop_timer(timer_base + 4)
|
||||
start_timer(timer_base + 5)
|
||||
m5.secure_permute(p3, n_threads=3, n_parallel=3)
|
||||
stop_timer(timer_base + 5)
|
||||
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
|
||||
|
||||
m1 = m1.reveal()
|
||||
m2 = m2.reveal()
|
||||
m3 = m3.reveal()
|
||||
m4 = m4.reveal()
|
||||
m5 = m5.reveal()
|
||||
|
||||
print_ln("Permuted m1:")
|
||||
for row in m1:
|
||||
@@ -173,6 +185,16 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m3, p2)
|
||||
|
||||
print_ln("Permuted m4 (should be equal to m2):")
|
||||
for row in m4:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m4, p2)
|
||||
|
||||
print_ln("Permuted m5:")
|
||||
for row in m5:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m5, p3)
|
||||
|
||||
def test_secure_shuffle_still_works(size: int, timer_base: int):
|
||||
arr = Array.create_from([sint(i) for i in range(size)])
|
||||
start_timer(timer_base)
|
||||
@@ -226,6 +248,11 @@ def test_inverse_permutation_still_works(size, timer_base: int):
|
||||
|
||||
|
||||
|
||||
def test_dead_code_elimination():
|
||||
vector = sint([0,2,4,6,5,3,1])
|
||||
handle = sint.get_secure_shuffle(7)
|
||||
print_ln('%s', vector.secure_permute(handle).reveal())
|
||||
|
||||
|
||||
test_allocator()
|
||||
test_case([10, 15], 10)
|
||||
@@ -239,4 +266,6 @@ test_permute_matrix(60)
|
||||
test_permute_matrix(70, value_type=sfix)
|
||||
|
||||
test_secure_shuffle_still_works(32, 80)
|
||||
test_inverse_permutation_still_works(8, 80)
|
||||
test_inverse_permutation_still_works(8, 80)
|
||||
|
||||
test_dead_code_elimination()
|
||||
Reference in New Issue
Block a user