from Compiler.library import get_number_of_players
from Compiler.sqrt_oram import n_parallel
from Compiler.util import if_else

# NOTE(review): names such as sint, cint, sfix, Array, Matrix, for_range, if_,
# if_e, else_, start_timer, stop_timer, print_ln, crash, break_point and
# program are used below without being imported in this file; presumably the
# compiler that executes this script provides them as globals -- confirm.


def _advance_timer(timer_base: int | None, phase: int) -> None:
    """Stop the timer of the previous phase and start the timer of *phase*.

    Timer ids are ``timer_base + phase - 1`` (stopped) and
    ``timer_base + phase`` (started). Nothing happens when *timer_base* is
    None, which is how callers disable timing.
    """
    if timer_base is None:
        return
    stop_timer(timer_base + phase - 1)
    start_timer(timer_base + phase)


def _crash_if_identity(revealed, message) -> None:
    """Print *message* and call crash() if ``revealed[i] == i`` for every i.

    *revealed* is an already revealed array; its entries are compared with
    their own index and the number of matches is accumulated in a cint.
    """
    n_matched = cint(0)

    @for_range(len(revealed))
    def count_matches(i: cint) -> None:
        n_matched.update(n_matched + (revealed[i] == i))

    @if_(n_matched == len(revealed))
    def didnt_permute():
        print_ln(message)
        crash()


def test_allocator():
    """Permute three arrays of sizes 5, 10 and 20 with independent shuffles.

    The second array receives an element of the first one before being
    permuted. There is no runtime assertion here; the result is meant to be
    inspected in the generated bytecode (see the comment below).
    """
    arr1 = sint.Array(5)
    arr2 = sint.Array(10)
    arr3 = sint.Array(20)

    p1 = sint.get_secure_shuffle(5)
    p2 = sint.get_secure_shuffle(10)
    p3 = sint.get_secure_shuffle(20)

    # Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
    arr1.secure_permute(p1)
    arr2[0] = arr1[0]
    arr2.secure_permute(p2)
    arr3.secure_permute(p3)


def test_case(permutation_sizes, timer_base: int | None = None):
    """Permute arrays ``[0, ..., size - 1]`` and undo the permutation again.

    For every entry of *permutation_sizes* an array and a secure shuffle of
    that size are created. After applying the shuffles the program crashes if
    an array is still the identity; after applying them with ``reverse=True``
    it crashes if an array is not the identity.

    When *timer_base* is given, the five phases (setup, permute, check,
    reverse, check) are timed with timers ``timer_base + 0`` to
    ``timer_base + 4``.
    """
    if timer_base is not None:
        start_timer(timer_base + 0)
    arrays = []
    permutations = []
    # Array and shuffle are created pairwise, in input order.
    for size in permutation_sizes:
        arrays.append(Array.create_from([sint(i) for i in range(size)]))
        permutations.append(sint.get_secure_shuffle(size))
    _advance_timer(timer_base, 1)

    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p)
    _advance_timer(timer_base, 2)

    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)
        _crash_if_identity(
            revealed,
            "Permutation potentially didn't work (permutation might have been identity by chance).",
        )
    _advance_timer(timer_base, 3)

    for arr, p in zip(arrays, permutations):
        arr.secure_permute(p, reverse=True)
    _advance_timer(timer_base, 4)

    for arr in arrays:
        revealed = arr.reveal()
        print_ln("%s", revealed)

        @for_range(len(arr))
        def test_is_original(i: cint) -> None:
            @if_(revealed[i] != i)
            def fail():
                print_ln("Failed to invert permutation!")
                crash()

    if timer_base is not None:
        stop_timer(timer_base + 4)


def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], timer_base: int) -> None:
    """Check that permuting with and without break points gives equal results.

    Two identical sets of arrays are permuted with the same shuffles. The
    first set has a ``break_point()`` after every permutation, the second set
    has none. Both sets are revealed and compared element-wise; any mismatch
    crashes the program. Timers ``timer_base`` to ``timer_base + 4`` cover
    shuffle generation, array creation, the two permutation passes and the
    reveal.
    """
    start_timer(timer_base)
    permutations = [sint.get_secure_shuffle(permutation_size) for permutation_size in sizes]
    _advance_timer(timer_base, 1)

    arrs_to_permute_sequentially = []
    arrs_to_permute_parallely = []
    for permutation_size in sizes:
        arrs_to_permute_sequentially.append(Array.create_from([sint(i) for i in range(permutation_size)]))
        arrs_to_permute_parallely.append(Array.create_from([sint(i) for i in range(permutation_size)]))
    _advance_timer(timer_base, 2)

    for arr, perm in zip(arrs_to_permute_sequentially, permutations):
        arr.secure_permute(perm)
        # NOTE(review): presumably keeps consecutive permutations from being
        # merged into one round -- confirm against the compiler docs.
        break_point()
    _advance_timer(timer_base, 3)

    for arr, perm in zip(arrs_to_permute_parallely, permutations):
        arr.secure_permute(perm)
    _advance_timer(timer_base, 4)

    revealed_sequential = [arr.reveal() for arr in arrs_to_permute_sequentially]
    revealed_parallel = [arr.reveal() for arr in arrs_to_permute_parallely]
    stop_timer(timer_base + 4)

    for arr_seq, arr_par in zip(revealed_sequential, revealed_parallel):
        print_ln("Sequential: %s", arr_seq)
        print_ln("Parallel: %s", arr_par)

        @for_range(len(arr_seq))
        def test_equals(i: cint) -> None:
            @if_(arr_seq[i] != arr_par[i])
            def fail():
                print_ln("Sequentially permuted arrays to not match the parallely permuted arrays.")
                crash()


def test_permute_matrix(timer_base: int, value_type=sint) -> None:
    """Permute five-row matrices with various ``secure_permute`` options.

    Matrices with 5, 6, 6, 6 and 9 columns are filled row-major with
    ``n_columns * i + j`` and permuted; m2, m3 and m4 share one shuffle but
    use different ``n_threads`` / ``n_parallel`` settings. Each permutation is
    timed separately with timers ``timer_base + 1`` to ``timer_base + 5``.
    The revealed matrices are printed and verified.
    """

    def test_permuted_matrix(m, p):
        """Crash unless revealed matrix *m* matches its fill pattern under *p*.

        An index array is permuted with ``reverse=True`` and revealed; the row
        of *m* found at ``permuted_indices[i]`` must hold the values
        ``m.sizes[1] * i + j``.
        """
        permuted_indices = Array.create_from([sint(i) for i in range(m.sizes[0])])
        permuted_indices.secure_permute(p, reverse=True)
        permuted_indices = permuted_indices.reveal()

        @for_range(m.sizes[0])
        def check_row(i):
            @for_range(m.sizes[1])
            def check_entry(j):
                @if_(m[permuted_indices[i]][j] != (m.sizes[1] * i + j))
                def fail():
                    print_ln("Matrix permuted unexpectedly.")
                    crash()

    print_ln("Pre-create matrix")
    m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
    m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
    m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
    m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
    m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
    print_ln("post-create matrix")

    p1 = sint.get_secure_shuffle(5)
    p2 = sint.get_secure_shuffle(5)
    p3 = sint.get_secure_shuffle(5)

    start_timer(timer_base + 1)
    m1.secure_permute(p1)
    stop_timer(timer_base + 1)

    start_timer(timer_base + 2)
    m2.secure_permute(p2)
    stop_timer(timer_base + 2)

    start_timer(timer_base + 3)
    m3.secure_permute(p2, n_threads=3)
    stop_timer(timer_base + 3)

    start_timer(timer_base + 4)
    m4.secure_permute(p2, n_threads=3, n_parallel=2)
    stop_timer(timer_base + 4)

    start_timer(timer_base + 5)
    m5.secure_permute(p3, n_threads=3, n_parallel=3)
    stop_timer(timer_base + 5)

    print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")

    # Reveal everything first, then print and verify in the original order.
    m1 = m1.reveal()
    m2 = m2.reveal()
    m3 = m3.reveal()
    m4 = m4.reveal()
    m5 = m5.reveal()

    checks = (
        ("Permuted m1:", m1, p1),
        ("Permuted m2:", m2, p2),
        ("Permuted m3 (should be equal to m2):", m3, p2),
        ("Permuted m4 (should be equal to m2):", m4, p2),
        ("Permuted m5:", m5, p3),
    )
    for title, matrix, shuffle in checks:
        print_ln(title)
        for row in matrix:
            print_ln("%s", row)
        test_permuted_matrix(matrix, shuffle)


def test_secure_shuffle_still_works(size: int, timer_base: int):
    """Shuffle ``[0, ..., size - 1]`` and crash if the result is the identity.

    Only the ``secure_shuffle()`` call itself is timed, using timer
    *timer_base*.
    """
    arr = Array.create_from([sint(i) for i in range(size)])

    start_timer(timer_base)
    arr.secure_shuffle()
    stop_timer(timer_base)

    _crash_if_identity(
        arr.reveal(),
        "Shuffle potentially didn't work (permutation might have been identity by chance).",
    )


def test_inverse_permutation_still_works(size, timer_base: int):
    """Compare ``inverse_permutation()`` with a reversed ``secure_permute``.

    The comparison only runs when ``get_number_of_players() == 2``; otherwise
    a notice is printed. An identity array is permuted once with
    ``reverse=True`` and once without, the former is inverted with
    ``inverse_permutation()``, and the two revealed results must agree
    element-wise or the program crashes. Timers ``timer_base + 1`` and
    ``timer_base + 2`` cover the two ways of obtaining the inverse.
    """

    @if_e(get_number_of_players() == 2)
    def test():
        program.use_invperm(True)
        arr = Array.create_from([sint(i) for i in range(size)])
        p_as_arr = Array.create_from(arr[:])
        p_inv_as_arr = Array.create_from(arr[:])

        start_timer(timer_base + 1)
        p1 = sint.get_secure_shuffle(size)
        p_as_arr.secure_permute(p1, reverse=True)
        p_inv_as_arr.secure_permute(p1)
        p_inv_as_arr = p_inv_as_arr.reveal()
        stop_timer(timer_base + 1)

        start_timer(timer_base + 2)
        p_inv2 = Array.create_from(p_as_arr[:].inverse_permutation())
        # p_inv2.inverse_permutation()
        p_inv2 = p_inv2.reveal()
        stop_timer(timer_base + 2)

        print_ln("Permutation: %s", p_as_arr.reveal())
        print_ln("Inverse from secure_permute: %s", p_inv_as_arr)
        print_ln("Inverse from inverse_permut: %s", p_inv2)

        @for_range(size)
        def check(i):
            @if_(p_inv_as_arr[i] != p_inv2[i])
            def fail():
                print_ln("Inverse permutation don't match.")
                crash()

    @else_
    def _():
        print_ln("Inverse permutation is only tested in 2-party computation.")


def test_dead_code_elimination():
    """Permute a fixed 7-element vector and print the revealed result.

    NOTE(review): judging by the name, this presumably guards against the
    permutation being dropped by the compiler's dead-code elimination --
    confirm.
    """
    vector = sint([0, 2, 4, 6, 5, 3, 1])
    handle = sint.get_secure_shuffle(7)
    print_ln('%s', vector.secure_permute(handle).reveal())


test_allocator()
test_case([10, 15], 10)
test_case([10, 15, 20], 20)
test_case([16, 32], 30)
test_case([256], 40)
test_parallel_permutation_equals_sequential_permutation([5, 10], 50)

test_permute_matrix(60)
test_permute_matrix(70, value_type=sfix)

test_secure_shuffle_still_works(32, 80)
test_inverse_permutation_still_works(8, 80)

test_dead_code_elimination()