Files
MP-SPDZ/Programs/Source/test_permute.mpc
Vincent Ehrmanntraut 70fcbc872c Fix printing of floats
2024-12-12 09:52:59 +01:00

173 lines
5.5 KiB
Plaintext

def test_allocator():
    """Permute three secret arrays, one of which depends on another.

    Arrays of 5, 10 and 20 entries are allocated together with one secure
    shuffle per array. Before the second array is permuted, its first entry
    is overwritten with the first entry of the already permuted first array,
    so the second permutation depends on the first one. The third array is
    independent of both.

    Nothing is asserted at run time; the result is meant to be checked by
    inspecting the generated bytecode.
    """
    lengths = (5, 10, 20)
    # Generators are consumed left to right, so allocation order is 5, 10, 20
    # for the arrays first and then for the shuffles.
    first, second, third = (sint.Array(n) for n in lengths)
    shuffle_first, shuffle_second, shuffle_third = (
        sint.get_secure_shuffle(n) for n in lengths
    )

    # Look at the bytecode: the first and third array should be shuffled in
    # parallel, the second one afterward.
    first.secure_permute(shuffle_first)
    second[0] = first[0]
    second.secure_permute(shuffle_second)
    third.secure_permute(shuffle_third)
def test_case(permutation_sizes, timer_base: int | None = None):
    """Shuffle arrays, check they changed, undo the shuffle, check they match.

    For every entry of ``permutation_sizes`` an array holding ``0..size-1``
    and a secure shuffle of the same size are created. All arrays are
    permuted, revealed and printed; the program crashes if an array still
    equals the identity. The permutations are then applied with
    ``reverse=True`` and the program crashes if any array is not back to
    ``0..size-1``.

    Args:
        permutation_sizes: Iterable of array lengths, one array/shuffle pair
            per entry.
        timer_base: First of five consecutive timer ids. ``timer_base + k``
            times phase ``k``: 0 setup, 1 permutation, 2 first check,
            3 inverse permutation, 4 final check. ``None`` disables timing.

    NOTE(review): the "did not permute" check also fires when the shuffle
    happens to be the identity (as its message says); presumably this is more
    likely for very small sizes -- confirm that this is acceptable.
    """
    if timer_base is not None:
        start_timer(timer_base + 0)
    arrays = []
    permutations = []
    # Array and shuffle are created alternately, one pair per size.
    for size in permutation_sizes:
        arrays.append(Array.create_from([sint(i) for i in range(size)]))
        permutations.append(sint.get_secure_shuffle(size))
    if timer_base is not None:
        stop_timer(timer_base + 0)
        start_timer(timer_base + 1)

    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p)

    if timer_base is not None:
        stop_timer(timer_base + 1)
        start_timer(timer_base + 2)

    # The loop index of the arrays is not needed; the only index in use is
    # the run-time index handed to the @for_range bodies below.
    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)

        # Run-time counter of positions that still hold their own index.
        n_matched = cint(0)

        # The decorated bodies are invoked right away by the decorators, so
        # they see the current iteration's ``revealed`` and ``n_matched``.
        @for_range(len(arr))
        def count_matches(i: cint) -> None:
            n_matched.update(n_matched + (revealed[i] == i))

        @if_(n_matched == len(arr))
        def didnt_permute():
            print_ln("Permutation potentially didn't work (permutation might have been identity by chance).")
            crash()

    if timer_base is not None:
        stop_timer(timer_base + 2)
        start_timer(timer_base + 3)

    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p, reverse=True)

    if timer_base is not None:
        stop_timer(timer_base + 3)
        start_timer(timer_base + 4)

    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)

        @for_range(len(arr))
        def test_is_original(i: cint) -> None:
            @if_(revealed[i] != i)
            def fail():
                print_ln("Failed to invert permutation!")
                crash()

    if timer_base is not None:
        stop_timer(timer_base + 4)
def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], timer_base: int) -> None:
    """Check that one shuffle gives the same result in both scheduling modes.

    For each size one secure shuffle and two identical arrays holding
    ``0..size-1`` are created. The first batch of arrays is permuted with a
    ``break_point()`` after every permutation; the second batch is permuted
    back to back without break points. Both batches are revealed, printed and
    compared entry by entry; the program crashes on the first mismatch.

    Args:
        sizes: Array lengths, one shuffle and one pair of arrays per entry.
        timer_base: First of five consecutive timer ids. ``timer_base + k``
            times phase ``k``: 0 shuffle generation, 1 array creation,
            2 permutation with break points, 3 permutation without break
            points, 4 revealing.
    """
    start_timer(timer_base)
    permutations = []
    for permutation_size in sizes:
        permutations.append(sint.get_secure_shuffle(permutation_size))
    stop_timer(timer_base)

    start_timer(timer_base + 1)
    arrs_to_permute_sequentially = []
    arrs_to_permute_parallely = []
    for permutation_size in sizes:
        arrs_to_permute_sequentially.append(Array.create_from([sint(i) for i in range(permutation_size)]))
        arrs_to_permute_parallely.append(Array.create_from([sint(i) for i in range(permutation_size)]))
    stop_timer(timer_base + 1)

    start_timer(timer_base + 2)
    for arr, perm in zip(arrs_to_permute_sequentially, permutations):
        arr.secure_permute(perm)
        # Separate each permutation from the next one so that this batch is
        # not merged into a single parallel step.
        break_point()
    stop_timer(timer_base + 2)

    start_timer(timer_base + 3)
    # No break points here: the compiler is free to combine these.
    for arr, perm in zip(arrs_to_permute_parallely, permutations):
        arr.secure_permute(perm)
    stop_timer(timer_base + 3)

    start_timer(timer_base + 4)
    # Keep the revealed arrays under their own names instead of rebinding the
    # lists of secret arrays.
    revealed_sequentially = [arr.reveal() for arr in arrs_to_permute_sequentially]
    revealed_parallely = [arr.reveal() for arr in arrs_to_permute_parallely]
    stop_timer(timer_base + 4)

    for arr_seq, arr_par in zip(revealed_sequentially, revealed_parallely):
        print_ln("Sequential: %s", arr_seq)
        print_ln("Parallel: %s", arr_par)

        # The decorated bodies run immediately, so they see this iteration's
        # ``arr_seq`` and ``arr_par``.
        @for_range(len(arr_seq))
        def test_equals(i: cint) -> None:
            @if_(arr_seq[i] != arr_par[i])
            def fail():
                print_ln("Sequentially permuted arrays do not match the parallely permuted arrays.")
                crash()
def test_permute_matrix(timer_base: int, value_type=sint) -> None:
    """Permute the rows of a 5x5 and a 5x6 matrix and verify the result.

    Both matrices are filled in row-major order (entry ``(i, j)`` holds
    ``n_cols * i + j``) with values of ``value_type``, then permuted with two
    separate secure shuffles of size 5. The permuted matrices are revealed,
    printed row by row and checked; the program crashes on a mismatch.

    Args:
        timer_base: ``timer_base + 1`` times the permutation of the first
            matrix and ``timer_base + 2`` that of the second one.
        value_type: Type used to build the matrix entries.
    """

    def test_permuted_matrix(m, p):
        # Verify a revealed, permuted matrix ``m`` against the shuffle ``p``.
        n_rows, n_cols = m.sizes[0], m.sizes[1]

        # Push the row indices 0..n_rows-1 through the reverse of ``p`` and
        # reveal them; entry i is the row of ``m`` that is compared against
        # the contents of original row i.
        secret_rows = Array.create_from([sint(r) for r in range(n_rows)])
        secret_rows.secure_permute(p, reverse=True)
        row_lookup = secret_rows.reveal()

        @for_range(n_rows)
        def check_row(i):
            @for_range(n_cols)
            def check_entry(j):
                @if_(m[row_lookup[i]][j] != (n_cols * i + j))
                def fail():
                    print_ln("Matrix permuted unexpectedly.")
                    crash()

    def row_major_matrix(n_cols):
        # Five rows, entries numbered consecutively along each row.
        return Matrix.create_from(
            [[value_type(n_cols * r + c) for c in range(n_cols)] for r in range(5)]
        )

    print_ln("Pre-create matrix")
    square = row_major_matrix(5)
    wide = row_major_matrix(6)
    print_ln("post-create matrix")

    shuffle_square = sint.get_secure_shuffle(5)
    shuffle_wide = sint.get_secure_shuffle(5)

    start_timer(timer_base + 1)
    square.secure_permute(shuffle_square, n_parallel=1)
    stop_timer(timer_base + 1)

    start_timer(timer_base + 2)
    wide.secure_permute(shuffle_wide, n_parallel=1)
    stop_timer(timer_base + 2)
    print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")

    # Reveal both matrices before any of them is printed or checked.
    revealed_square = square.reveal()
    revealed_wide = wide.reveal()

    reports = (
        ("Permuted m1:", revealed_square, shuffle_square),
        ("Permuted m2:", revealed_wide, shuffle_wide),
    )
    for title, revealed, shuffle in reports:
        print_ln(title)
        for row in revealed:
            print_ln("%s", row)
        test_permuted_matrix(revealed, shuffle)
# Run every test. The timer bases are spaced by 10 so that the timer ids used
# by one call (base .. base + 4) never overlap with those of another call.
test_allocator()
test_case([5, 10], timer_base=10)
test_case([5, 10, 15, 20], timer_base=20)
test_case([4, 8, 16], timer_base=30)
test_case([5], timer_base=40)
test_parallel_permutation_equals_sequential_permutation([5, 10], timer_base=50)
test_permute_matrix(timer_base=60)
test_permute_matrix(timer_base=70, value_type=sfix)