mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
271 lines
8.7 KiB
Plaintext
271 lines
8.7 KiB
Plaintext
from Compiler.library import get_number_of_players
|
|
from Compiler.sqrt_oram import n_parallel
|
|
from Compiler.util import if_else
|
|
|
|
|
|
def test_allocator():
    """Emit three secure permutations whose scheduling is inspected by hand.

    Nothing is revealed or asserted. Writing element 0 of the 5-element
    array into the 10-element array (after the former was permuted) gives
    the second permutation a data dependency on the first one.
    """
    small = sint.Array(5)
    medium = sint.Array(10)
    large = sint.Array(20)

    shuffle_small = sint.get_secure_shuffle(5)
    shuffle_medium = sint.get_secure_shuffle(10)
    shuffle_large = sint.get_secure_shuffle(20)

    # Look at the bytecode: `small` and `large` should be shuffled in
    # parallel, `medium` afterward.
    small.secure_permute(shuffle_small)
    medium[0] = small[0]
    medium.secure_permute(shuffle_medium)
    large.secure_permute(shuffle_large)
|
|
|
|
|
|
def test_case(permutation_sizes, timer_base: int | None = None):
    """Permute arrays, check they changed, undo the permutation, check they are restored.

    For every entry of ``permutation_sizes`` an array holding ``0 .. size-1``
    and a secure shuffle handle of that size are created. Each array is
    permuted with its handle, revealed and printed; the program crashes if
    every element is still at its own index. The permutation is then applied
    with ``reverse=True`` and the program crashes if any element is not back
    at its own index.

    Args:
        permutation_sizes: iterable of array lengths, one array per entry.
        timer_base: when not ``None``, timers ``timer_base`` through
            ``timer_base + 4`` time the five phases in order (setup, permute,
            first check, reverse permute, second check). When ``None`` no
            timers are touched.
    """
    def _start(phase: int) -> None:
        # Start the timer for the given phase, if timing was requested.
        if timer_base is not None:
            start_timer(timer_base + phase)

    def _stop(phase: int) -> None:
        # Stop the timer for the given phase, if timing was requested.
        if timer_base is not None:
            stop_timer(timer_base + phase)

    _start(0)
    arrays = []
    permutations = []
    for size in permutation_sizes:
        arrays.append(Array.create_from([sint(i) for i in range(size)]))
        permutations.append(sint.get_secure_shuffle(size))
    _stop(0)

    _start(1)
    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p)
    _stop(1)

    _start(2)
    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)

        # Count the positions that still hold their own index.
        n_matched = cint(0)

        @for_range(len(arr))
        def count_matches(i: cint) -> None:
            n_matched.update(n_matched + (revealed[i] == i))

        # All positions matching means the array looks untouched.
        @if_(n_matched == len(arr))
        def didnt_permute():
            print_ln("Permutation potentially didn't work (permutation might have been identity by chance).")
            crash()
    _stop(2)

    _start(3)
    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p, reverse=True)
    _stop(3)

    _start(4)
    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)

        @for_range(len(arr))
        def test_is_original(i: cint) -> None:
            @if_(revealed[i] != i)
            def fail():
                print_ln("Failed to invert permutation!")
                crash()
    _stop(4)
|
|
|
|
|
|
def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], timer_base: int) -> None:
    """Check that break-point-separated and back-to-back permutations agree.

    One shuffle handle is generated per entry of ``sizes``. Two identical
    sets of arrays holding ``0 .. size-1`` are built; the first set is
    permuted with a ``break_point()`` after every permutation, the second
    set without any. Both sets are revealed, printed and compared element by
    element; the program crashes on the first mismatch.

    Args:
        sizes: array lengths, one array pair per entry.
        timer_base: timers ``timer_base`` through ``timer_base + 4`` cover,
            in order: handle generation, array creation, the
            break-point-separated permutations, the back-to-back
            permutations, and the reveals.
    """
    def identity_array(length):
        # Secret-shared array whose element k holds the value k.
        return Array.create_from([sint(k) for k in range(length)])

    start_timer(timer_base)
    shuffles = [sint.get_secure_shuffle(length) for length in sizes]
    stop_timer(timer_base)

    start_timer(timer_base + 1)
    sequential = []
    parallel = []
    for length in sizes:
        sequential.append(identity_array(length))
        parallel.append(identity_array(length))
    stop_timer(timer_base + 1)

    start_timer(timer_base + 2)
    for target, shuffle in zip(sequential, shuffles):
        target.secure_permute(shuffle)
        # NOTE(review): presumably keeps the compiler from merging this
        # permutation with the next one -- confirm against the MP-SPDZ docs.
        break_point()
    stop_timer(timer_base + 2)

    start_timer(timer_base + 3)
    for target, shuffle in zip(parallel, shuffles):
        target.secure_permute(shuffle)
    stop_timer(timer_base + 3)

    start_timer(timer_base + 4)
    sequential_plain = [target.reveal() for target in sequential]
    parallel_plain = [target.reveal() for target in parallel]
    stop_timer(timer_base + 4)

    for arr_seq, arr_par in zip(sequential_plain, parallel_plain):
        print_ln("Sequential: %s", arr_seq)
        print_ln("Parallel: %s", arr_par)

        @for_range(len(arr_seq))
        def test_equals(i: cint) -> None:
            @if_(arr_seq[i] != arr_par[i])
            def fail():
                print_ln("Sequentially permuted arrays to not match the parallely permuted arrays.")
                crash()
|
|
|
|
|
|
def test_permute_matrix(timer_base: int, value_type=sint) -> None:
    """Permute the rows of five matrices and verify each result.

    Every matrix has 5 rows and entry ``(i, j)`` equal to ``cols * i + j``;
    the column counts are 5, 6, 6, 6 and 9. The second, third and fourth
    matrices share one shuffle handle and differ only in the ``n_threads`` /
    ``n_parallel`` arguments passed to ``secure_permute``.

    Args:
        timer_base: timers ``timer_base + 1`` through ``timer_base + 5`` time
            the five ``secure_permute`` calls, in order.
        value_type: type used to build the matrix entries.
    """
    n_rows = 5

    def test_permuted_matrix(m, p):
        # Apply `p` with reverse=True to the index array 0..rows-1 and
        # reveal it; row `row_of[i]` of the revealed matrix `m` must then
        # hold the values originally written to row i.
        row_of = Array.create_from([sint(i) for i in range(m.sizes[0])])
        row_of.secure_permute(p, reverse=True)
        row_of = row_of.reveal()

        @for_range(m.sizes[0])
        def check_row(i):
            @for_range(m.sizes[1])
            def check_entry(j):
                @if_(m[row_of[i]][j] != (m.sizes[1] * i + j))
                def fail():
                    print_ln("Matrix permuted unexpectedly.")
                    crash()

    def counting_matrix(n_cols):
        # Matrix with n_rows rows whose entry (i, j) is n_cols * i + j.
        return Matrix.create_from(
            [[value_type(n_cols * i + j) for j in range(n_cols)] for i in range(n_rows)]
        )

    print_ln("Pre-create matrix")
    m1 = counting_matrix(5)
    m2 = counting_matrix(6)
    m3 = counting_matrix(6)
    m4 = counting_matrix(6)
    m5 = counting_matrix(9)
    print_ln("post-create matrix")

    p1 = sint.get_secure_shuffle(5)
    p2 = sint.get_secure_shuffle(5)
    p3 = sint.get_secure_shuffle(5)

    # (timer offset, matrix, shuffle handle, extra secure_permute arguments)
    timed_runs = (
        (1, m1, p1, {}),
        (2, m2, p2, {}),
        (3, m3, p2, {"n_threads": 3}),
        (4, m4, p2, {"n_threads": 3, "n_parallel": 2}),
        (5, m5, p3, {"n_threads": 3, "n_parallel": 3}),
    )
    for offset, matrix, handle, options in timed_runs:
        start_timer(timer_base + offset)
        matrix.secure_permute(handle, **options)
        stop_timer(timer_base + offset)
    print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")

    # Reveal everything first, then print and verify one matrix at a time.
    opened = [matrix.reveal() for matrix in (m1, m2, m3, m4, m5)]
    headings = (
        "Permuted m1:",
        "Permuted m2:",
        "Permuted m3 (should be equal to m2):",
        "Permuted m4 (should be equal to m2):",
        "Permuted m5:",
    )
    handles = (p1, p2, p2, p2, p3)
    for heading, plain, handle in zip(headings, opened, handles):
        print_ln(heading)
        for row in plain:
            print_ln("%s", row)
        test_permuted_matrix(plain, handle)
|
|
|
|
def test_secure_shuffle_still_works(size: int, timer_base: int):
    """Shuffle an array holding ``0 .. size-1`` and crash if nothing moved.

    Args:
        size: length of the array to shuffle.
        timer_base: timer wrapped around the ``secure_shuffle`` call.
    """
    secret = Array.create_from([sint(k) for k in range(size)])

    start_timer(timer_base)
    secret.secure_shuffle()
    stop_timer(timer_base)

    opened = secret.reveal()

    # Count the positions that still hold their own index.
    fixed_points = cint(0)

    @for_range(len(opened))
    def count_matches(i: cint) -> None:
        fixed_points.update(fixed_points + (opened[i] == i))

    # Every position being a fixed point means the array looks untouched.
    @if_(fixed_points == len(opened))
    def didnt_permute():
        print_ln("Shuffle potentially didn't work (permutation might have been identity by chance).")
        crash()
|
|
|
|
def test_inverse_permutation_still_works(size, timer_base: int):
    """Compare two ways of inverting a permutation (two-player runs only).

    An identity array is permuted with ``reverse=True`` to obtain a
    permutation in array form, and a copy is permuted with the same handle
    in the forward direction. ``inverse_permutation()`` of the former must
    equal the latter element by element, otherwise the program crashes.
    With any other number of players only a notice is printed.

    Args:
        size: length of the permutation.
        timer_base: ``timer_base + 1`` times the ``secure_permute`` route
            (including handle generation and the reveal), ``timer_base + 2``
            times the ``inverse_permutation`` route.
    """
    @if_e(get_number_of_players() == 2)
    def test():
        # NOTE(review): presumably enables the inverse-permutation
        # instruction in the compiler -- confirm against the MP-SPDZ docs.
        program.use_invperm(True)

        identity = Array.create_from([sint(k) for k in range(size)])
        perm_secret = Array.create_from(identity[:])
        inverse_secret = Array.create_from(identity[:])

        start_timer(timer_base + 1)
        handle = sint.get_secure_shuffle(size)
        perm_secret.secure_permute(handle, reverse=True)
        inverse_secret.secure_permute(handle)
        inverse_via_permute = inverse_secret.reveal()
        stop_timer(timer_base + 1)

        start_timer(timer_base + 2)
        inverted = Array.create_from(perm_secret[:].inverse_permutation())
        inverse_via_method = inverted.reveal()
        stop_timer(timer_base + 2)

        print_ln("Permutation: %s", perm_secret.reveal())
        print_ln("Inverse from secure_permute: %s", inverse_via_permute)
        print_ln("Inverse from inverse_permut: %s", inverse_via_method)

        @for_range(size)
        def check(i):
            @if_(inverse_via_permute[i] != inverse_via_method[i])
            def fail():
                print_ln("Inverse permutation don't match.")
                crash()

    @else_
    def _():
        print_ln("Inverse permutation is only tested in 2-party computation.")
|
|
|
|
|
|
|
|
def test_dead_code_elimination():
    """Permute a 7-element secret vector and print the revealed result.

    NOTE(review): judging by the name, this guards against the compiler
    dropping the permutation as dead code -- confirm with the authors.
    """
    values = sint([0,2,4,6,5,3,1])
    shuffle = sint.get_secure_shuffle(7)
    permuted = values.secure_permute(shuffle)
    print_ln('%s', permuted.reveal())
|
|
|
|
|
|
# Test driver. Each test gets its own block of timer ids; the shuffle test
# uses timer 80 while the inverse-permutation test uses 81 and 82.
test_allocator()

for _sizes, _base in (([10, 15], 10), ([10, 15, 20], 20), ([16,32], 30), ([256], 40)):
    test_case(_sizes, _base)

test_parallel_permutation_equals_sequential_permutation([5,10],50)

test_permute_matrix(60)
test_permute_matrix(70, value_type=sfix)

test_secure_shuffle_still_works(32, 80)
test_inverse_permutation_still_works(8, 80)

test_dead_code_elimination()