diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 93f21266..bc3f1d48 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -2749,9 +2749,10 @@ class gensecshuffle(shuffle_base): def add_usage(self, req_node): self.add_gen_usage(req_node, self.args[1]) -class applyshuffle(base.VectorInstruction, shuffle_base): +class applyshuffle(shuffle_base, base.Mergeable): """ Generate secure shuffle to bit used several times. + :param: vector size (int) :param: destination (sint) :param: source (sint) :param: number of elements to be treated as one (int) @@ -2761,15 +2762,20 @@ class applyshuffle(base.VectorInstruction, shuffle_base): """ __slots__ = [] code = base.opcodes['APPLYSHUFFLE'] - arg_format = ['sw','s','int','ci','int'] + arg_format = itertools.cycle(['int', 'sw','s','int','ci','int']) + is_vec = lambda self: True # Ensures dead-code elimination works. def __init__(self, *args, **kwargs): super(applyshuffle, self).__init__(*args, **kwargs) - assert len(args[0]) == len(args[1]) - assert len(args[0]) > args[2] + assert (len(args) % 6) == 0 + for i in range(0, len(args), 6): + assert args[i] == len(args[i+1]) + assert args[i] == len(args[i + 2]) + assert args[i] > args[i + 3] def add_usage(self, req_node): - self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) + for i in range(0, len(self.args), 6): + self.add_apply_usage(req_node, self.args[i], self.args[i + 3]) class delshuffle(base.Instruction): """ Delete secure shuffle. diff --git a/Compiler/types.py b/Compiler/types.py index 1e111341..cafeccec 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -3122,7 +3122,7 @@ class sint(_secret, _int): @read_mem_value def secure_permute(self, shuffle, unit_size=1, reverse=False): res = sint(size=self.size) - applyshuffle(res, self, unit_size, shuffle, reverse) + applyshuffle(self.size, res, self, unit_size, shuffle, reverse) return res def inverse_permutation(self): @@ -7274,20 +7274,38 @@ class SubMultiArray(_vectorizable): self.secure_permute(perm) delshuffle(perm) - def secure_permute(self, permutation, reverse=False, n_threads=None): + def secure_permute(self, permutation, reverse=False, n_threads=None, n_parallel=None): """ Securely permute rows (first index). See :py:func:`secure_shuffle` for references. :param permutation: output of :py:func:`sint.get_secure_shuffle()` :param reverse: whether to apply inverse (default: False) - + :param n_threads: How many threads should be used. Will not multithread when set to None (default: None) + :param n_parallel: How many columns should be permuted in parallel. Will use the compiler's optimization budget is set to None. (default: None). """ - if n_threads is not None: - permutation = MemValue(permutation) - @library.for_range_multithread(n_threads, 1, self.get_part_size()) - def _(i): - self.set_column(i, self.get_column(i).secure_permute( - permutation, reverse=reverse)) + if (self.value_type == sint) and (n_threads is None): + # Use only a single shuffle instruction if applicable and permutation is single-threaded anyway. + unit_size = self.get_part_size() + n = self.sizes[0] * unit_size + res = sint(size=n) + applyshuffle(n, res, self[:], unit_size, permutation, reverse) + self.assign_vector(res) + else: + if n_threads is not None: + permutation = MemValue(permutation) + + if n_parallel is None: + @library.for_range_opt_multithread(n_threads, self.get_part_size()) + def iter(i): + column = self.get_column(i) + column = column.secure_permute(permutation, reverse=reverse) + self.set_column(i, column) + else: + @library.for_range_multithread(n_threads, n_parallel, self.get_part_size()) + def iter(i): + column = self.get_column(i) + column = column.secure_permute(permutation, reverse=reverse) + self.set_column(i, column) def sort(self, key_indices=None, n_bits=None, batcher=False): """ Sort sub-arrays (different first index) in place. diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 96d42cb3..24a77ec9 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -286,7 +286,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 5 register operands case PRINTFLOATPLAIN: case PRINTFLOATPLAINB: - case APPLYSHUFFLE: get_vector(5, start, s); break; case INCINT: @@ -321,14 +320,11 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case RUN_TAPE: case CONV2DS: case MATMULS: - num_var_args = get_int(s); - get_vector(num_var_args, start, s); - break; + case APPLYSHUFFLE: case MATMULSM: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; - // read from file, input is opcode num_args, // start_file_posn (read), end_file_posn(write) var1, var2, ... case READFILESHARE: @@ -1187,8 +1183,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.shuffle_store)); return; case APPLYSHUFFLE: - Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)), - Proc.machine.shuffle_store); + Proc.Procp.apply_shuffle(*this, Proc.machine.shuffle_store); return; case DELSHUFFLE: Proc.machine.shuffle_store.del(Proc.read_Ci(r[0])); diff --git a/Processor/Processor.h b/Processor/Processor.h index f5ba8a00..d6ab05a6 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -96,8 +96,7 @@ public: void secure_shuffle(const Instruction& instruction); size_t generate_secure_shuffle(const Instruction& instruction, ShuffleStore& shuffle_store); - void apply_shuffle(const Instruction& instruction, int handle, - ShuffleStore& shuffle_store); + void apply_shuffle(const Instruction& instruction, ShuffleStore& shuffle_store); void inverse_permutation(const Instruction& instruction); void input_personal(const vector& args); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d09e6711..c6beb9a6 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -908,13 +908,28 @@ size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction, } template -void SubProcessor::apply_shuffle(const Instruction& instruction, int handle, - ShuffleStore& shuffle_store) +void SubProcessor::apply_shuffle(const Instruction& instruction, + ShuffleStore& shuffle_store) { - shuffler.apply(S, instruction.get_size(), instruction.get_start()[2], - instruction.get_start()[0], instruction.get_start()[1], - shuffle_store.get(handle), - instruction.get_start()[4]); + const auto& args = instruction.get_start(); + + const auto n_shuffles = args.size() / 6; + vector sizes(n_shuffles, 0); + vector destinations(n_shuffles, 0); + vector sources(n_shuffles, 0); + vector unit_sizes(n_shuffles, 0); + vector shuffles(n_shuffles, 0); + vector reverse(n_shuffles, false); + for (size_t i = 0; i < n_shuffles; i++) { + sizes[i] = args[6 * i]; + destinations[i] = args[6 * i + 1]; + sources[i] = args[6 * i + 2]; + unit_sizes[i] = args[6 * i + 3]; + shuffles[i] = Proc->read_Ci(args[6 * i + 4]); + reverse[i] = args[6 * i + 5]; + } + shuffler.apply_multiple(S, sizes, destinations, sources, unit_sizes, shuffles, reverse, shuffle_store); + maybe_check(); } diff --git a/Programs/Source/test_permute.mpc b/Programs/Source/test_permute.mpc new file mode 100644 index 00000000..c7778585 --- /dev/null +++ b/Programs/Source/test_permute.mpc @@ -0,0 +1,271 @@ +from Compiler.library import get_number_of_players +from Compiler.sqrt_oram import n_parallel +from Compiler.util import if_else + + +def test_allocator(): + arr1 = sint.Array(5) + arr2 = sint.Array(10) + arr3 = sint.Array(20) + + p1 = sint.get_secure_shuffle(5) + p2 = sint.get_secure_shuffle(10) + p3 = sint.get_secure_shuffle(20) + + # Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward. + arr1.secure_permute(p1) + arr2[0] = arr1[0] + arr2.secure_permute(p2) + arr3.secure_permute(p3) + + +def test_case(permutation_sizes, timer_base: int | None = None): + if timer_base is not None: + start_timer(timer_base + 0) + arrays = [] + permutations = [] + for size in permutation_sizes: + arrays.append(Array.create_from([sint(i) for i in range(size)])) + permutations.append(sint.get_secure_shuffle(size)) + if timer_base is not None: + stop_timer(timer_base + 0) + start_timer(timer_base + 1) + + for arr, p in zip(arrays, permutations): + arr.secure_permute(p) + + if timer_base is not None: + stop_timer(timer_base + 1) + start_timer(timer_base + 2) + + for i, arr in enumerate(arrays): + revealed = arr.reveal() + print_ln("%s", revealed) + + n_matched = cint(0) + @for_range(len(arr)) + def count_matches(i: cint) -> None: + n_matched.update(n_matched + (revealed[i] == i)) + @if_(n_matched == len(arr)) + def didnt_permute(): + print_ln("Permutation potentially didn't work (permutation might have been identity by chance).") + crash() + + if timer_base is not None: + stop_timer(timer_base + 2) + start_timer(timer_base + 3) + + for arr, p in zip(arrays, permutations): + arr.secure_permute(p, reverse=True) + + if timer_base is not None: + stop_timer(timer_base + 3) + start_timer(timer_base + 4) + + for i, arr in enumerate(arrays): + revealed = arr.reveal() + print_ln("%s", revealed) + + @for_range(len(arr)) + def test_is_original(i: cint) -> None: + @if_(revealed[i] != i) + def fail(): + print_ln("Failed to invert permutation!") + crash() + + if timer_base is not None: + stop_timer(timer_base + 4) + + +def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], timer_base: int) -> None: + start_timer(timer_base) + permutations = [] + for permutation_size in sizes: + permutations.append(sint.get_secure_shuffle(permutation_size)) + stop_timer(timer_base) + + start_timer(timer_base + 1) + arrs_to_permute_sequentially = [] + arrs_to_permute_parallely = [] + for permutation_size in sizes: + arrs_to_permute_sequentially.append(Array.create_from([sint(i) for i in range(permutation_size)])) + arrs_to_permute_parallely.append(Array.create_from([sint(i) for i in range(permutation_size)])) + stop_timer(timer_base + 1) + + start_timer(timer_base + 2) + for arr, perm in zip(arrs_to_permute_sequentially, permutations): + arr.secure_permute(perm) + break_point() + stop_timer(timer_base + 2) + + start_timer(timer_base + 3) + for arr, perm in zip(arrs_to_permute_parallely, permutations): + arr.secure_permute(perm) + stop_timer(timer_base + 3) + + start_timer(timer_base + 4) + arrs_to_permute_sequentially = [arr.reveal() for arr in arrs_to_permute_sequentially] + arrs_to_permute_parallely = [arr.reveal() for arr in arrs_to_permute_parallely] + stop_timer(timer_base + 4) + + for (arr_seq, arr_par) in zip(arrs_to_permute_sequentially, arrs_to_permute_parallely): + print_ln("Sequential: %s", arr_seq) + print_ln("Parallel: %s", arr_par) + + @for_range(len(arr_seq)) + def test_equals(i: cint) -> None: + @if_(arr_seq[i] != arr_par[i]) + def fail(): + print_ln("Sequentially permuted arrays to not match the parallely permuted arrays.") + crash() + + +def test_permute_matrix(timer_base: int, value_type=sint) -> None: + def test_permuted_matrix(m, p): + permuted_indices = Array.create_from([sint(i) for i in range(m.sizes[0])]) + permuted_indices.secure_permute(p, reverse=True) + permuted_indices = permuted_indices.reveal() + + @for_range(m.sizes[0]) + def check_row(i): + @for_range(m.sizes[1]) + def check_entry(j): + @if_(m[permuted_indices[i]][j] != (m.sizes[1] * i + j)) + def fail(): + print_ln("Matrix permuted unexpectedly.") + crash() + + print_ln("Pre-create matrix") + m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)]) + m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) + m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) + m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)]) + m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)]) + print_ln("post-create matrix") + + p1 = sint.get_secure_shuffle(5) + p2 = sint.get_secure_shuffle(5) + p3 = sint.get_secure_shuffle(5) + + start_timer(timer_base + 1) + m1.secure_permute(p1) + stop_timer(timer_base + 1) + start_timer(timer_base + 2) + m2.secure_permute(p2) + stop_timer(timer_base + 2) + start_timer(timer_base + 3) + m3.secure_permute(p2, n_threads=3) + stop_timer(timer_base + 3) + start_timer(timer_base + 4) + m4.secure_permute(p2, n_threads=3, n_parallel=2) + stop_timer(timer_base + 4) + start_timer(timer_base + 5) + m5.secure_permute(p3, n_threads=3, n_parallel=3) + stop_timer(timer_base + 5) + print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.") + + m1 = m1.reveal() + m2 = m2.reveal() + m3 = m3.reveal() + m4 = m4.reveal() + m5 = m5.reveal() + + print_ln("Permuted m1:") + for row in m1: + print_ln("%s", row) + test_permuted_matrix(m1, p1) + + print_ln("Permuted m2:") + for row in m2: + print_ln("%s", row) + test_permuted_matrix(m2, p2) + + print_ln("Permuted m3 (should be equal to m2):") + for row in m3: + print_ln("%s", row) + test_permuted_matrix(m3, p2) + + print_ln("Permuted m4 (should be equal to m2):") + for row in m4: + print_ln("%s", row) + test_permuted_matrix(m4, p2) + + print_ln("Permuted m5:") + for row in m5: + print_ln("%s", row) + test_permuted_matrix(m5, p3) + +def test_secure_shuffle_still_works(size: int, timer_base: int): + arr = Array.create_from([sint(i) for i in range(size)]) + start_timer(timer_base) + arr.secure_shuffle() + stop_timer(timer_base) + arr = arr.reveal() + + n_matched = cint(0) + @for_range(len(arr)) + def count_matches(i: cint) -> None: + n_matched.update(n_matched + (arr[i] == i)) + @if_(n_matched == len(arr)) + def didnt_permute(): + print_ln("Shuffle potentially didn't work (permutation might have been identity by chance).") + crash() + +def test_inverse_permutation_still_works(size, timer_base: int): + @if_e(get_number_of_players() == 2) + def test(): + program.use_invperm(True) + + arr = Array.create_from([sint(i) for i in range(size)]) + p_as_arr = Array.create_from(arr[:]) + p_inv_as_arr = Array.create_from(arr[:]) + start_timer(timer_base + 1) + p1 = sint.get_secure_shuffle(size) + p_as_arr.secure_permute(p1, reverse=True) + p_inv_as_arr.secure_permute(p1) + p_inv_as_arr = p_inv_as_arr.reveal() + stop_timer(timer_base + 1) + + start_timer(timer_base + 2) + p_inv2 = Array.create_from(p_as_arr[:].inverse_permutation()) + # p_inv2.inverse_permutation() + p_inv2 = p_inv2.reveal() + stop_timer(timer_base + 2) + + print_ln("Permutation: %s", p_as_arr.reveal()) + print_ln("Inverse from secure_permute: %s", p_inv_as_arr) + print_ln("Inverse from inverse_permut: %s", p_inv2) + + @for_range(size) + def check(i): + @if_(p_inv_as_arr[i] != p_inv2[i]) + def fail(): + print_ln("Inverse permutation don't match.") + crash() + @else_ + def _(): + print_ln("Inverse permutation is only tested in 2-party computation.") + + + +def test_dead_code_elimination(): + vector = sint([0,2,4,6,5,3,1]) + handle = sint.get_secure_shuffle(7) + print_ln('%s', vector.secure_permute(handle).reveal()) + + +test_allocator() +test_case([10, 15], 10) +test_case([10, 15, 20], 20) +test_case([16,32], 30) +test_case([256], 40) + +test_parallel_permutation_equals_sequential_permutation([5,10],50) + +test_permute_matrix(60) +test_permute_matrix(70, value_type=sfix) + +test_secure_shuffle_still_works(32, 80) +test_inverse_permutation_still_works(8, 80) + +test_dead_code_elimination() \ No newline at end of file diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 246c9cb2..169bdf82 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -37,8 +37,8 @@ public: return store.add(); } - void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, int, bool) + void apply(StackedVector& a, size_t n, size_t unit_size, size_t output_base, + size_t input_base, size_t, bool) { auto source = a.begin() + input_base; auto dest = a.begin() + output_base; @@ -49,13 +49,29 @@ public: if (n > 1) { // swap first two to pass check - for (int i = 0; i < unit_size; i++) + for (size_t i = 0; i < unit_size; i++) swap(a[output_base + i], a[output_base + i + unit_size]); } } - void inverse_permutation(StackedVector&, size_t, size_t, size_t) - { + void inverse_permutation(StackedVector &, size_t, size_t, size_t) { + throw runtime_error("inverse permutation not implemented"); + }; + + void apply_multiple(StackedVector &a, vector &sizes, vector &destinations, + vector &sources, + vector &unit_sizes, vector &handles, vector &reverses, + store_type&) { + const auto n_shuffles = sizes.size(); + assert(sources.size() == n_shuffles); + assert(destinations.size() == n_shuffles); + assert(unit_sizes.size() == n_shuffles); + assert(handles.size() == n_shuffles); + assert(reverses.size() == n_shuffles); + + for (size_t i = 0; i < n_shuffles; i++) { + this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], handles[i], reverses[i]); + } } }; diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h index fecfe7c3..b5f8a677 100644 --- a/Protocols/Rep3Shuffler.h +++ b/Protocols/Rep3Shuffler.h @@ -28,8 +28,10 @@ public: int generate(int n_shuffle, store_type& store); - void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, shuffle_type& shuffle, bool reverse); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& handles, vector& reverse, store_type& store); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& shuffles, vector& reverse); void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp index 52a1e9dc..1fdb04a2 100644 --- a/Protocols/Rep3Shuffler.hpp +++ b/Protocols/Rep3Shuffler.hpp @@ -9,34 +9,33 @@ #include "Rep3Shuffler.h" template -Rep3Shuffler::Rep3Shuffler(StackedVector& a, size_t n, int unit_size, - size_t output_base, size_t input_base, SubProcessor& proc) : - proc(proc) -{ +Rep3Shuffler::Rep3Shuffler(StackedVector &a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor &proc) : proc(proc) { store_type store; int handle = generate(n / unit_size, store); - apply(a, n, unit_size, output_base, input_base, store.get(handle), - false); + + vector sizes{n}; + vector unit_sizes{static_cast(unit_size)}; + vector destinations{output_base}; + vector sources{input_base}; + vector shuffles{store.get(handle)}; + vector reverses{true}; + this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); } template -Rep3Shuffler::Rep3Shuffler(SubProcessor& proc) : - proc(proc) -{ +Rep3Shuffler::Rep3Shuffler(SubProcessor &proc) : proc(proc) { } template -int Rep3Shuffler::generate(int n_shuffle, store_type& store) -{ +int Rep3Shuffler::generate(int n_shuffle, store_type &store) { int res = store.add(); - auto& shuffle = store.get(res); - for (int i = 0; i < 2; i++) - { - auto& perm = shuffle[i]; + auto &shuffle = store.get(res); + for (int i = 0; i < 2; i++) { + auto &perm = shuffle[i]; for (int j = 0; j < n_shuffle; j++) perm.push_back(j); - for (int j = 0; j < n_shuffle; j++) - { + for (int j = 0; j < n_shuffle; j++) { int k = proc.protocol.shared_prngs[i].get_uint(n_shuffle - j); swap(perm[j], perm[k + j]); } @@ -45,85 +44,129 @@ int Rep3Shuffler::generate(int n_shuffle, store_type& store) } template -void Rep3Shuffler::apply(StackedVector& a, size_t n, int unit_size, - size_t output_base, size_t input_base, shuffle_type& shuffle, - bool reverse) -{ +void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, + vector &sources, + vector &unit_sizes, vector &handles, vector &reverses, + store_type &store) { + vector shuffles; + for (size_t &handle: handles) { + shuffle_type &shuffle = store.get(handle); + shuffles.push_back(shuffle); + } + + apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); +} + +template +void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, + vector &sources, vector &unit_sizes, vector &shuffles, + vector &reverses) { + const auto n_shuffles = sizes.size(); + assert(sources.size() == n_shuffles); + assert(destinations.size() == n_shuffles); + assert(unit_sizes.size() == n_shuffles); + assert(shuffles.size() == n_shuffles); + assert(reverses.size() == n_shuffles); + assert(proc.P.num_players() == 3); assert(not T::malicious); assert(not T::dishonest_majority); - assert(n % unit_size == 0); - if (shuffle.empty()) - throw runtime_error("shuffle has been deleted"); + // for (size_t i = 0; i < n_shuffles; i++) { + // this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], store.get(handles[i]), reverses[i]); + // } - stats[n / unit_size] += unit_size; + vector > to_shuffle; + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + assert(sizes[current_shuffle] % unit_sizes[current_shuffle] == 0); + vector x; + for (size_t j = 0; j < sizes[current_shuffle]; j++) + x.push_back(a[sources[current_shuffle] + j]); + to_shuffle.push_back(x); - vector to_shuffle; - for (size_t i = 0; i < n; i++) - to_shuffle.push_back(a[input_base + i]); + const auto &shuffle = shuffles[current_shuffle]; + if (shuffle.empty()) + throw runtime_error("shuffle has been deleted"); + + stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle]; + } typename T::Input input(proc); - vector to_share(n); - - for (int ii = 0; ii < 3; ii++) - { - int i; - if (reverse) - i = 2 - ii; - else - i = ii; - - if (proc.P.get_player(i) == 0) - { - for (size_t j = 0; j < n / unit_size; j++) - for (int k = 0; k < unit_size; k++) - if (reverse) - to_share.at(j * unit_size + k) = to_shuffle.at( - shuffle[0].at(j) * unit_size + k).sum(); - else - to_share.at(shuffle[0].at(j) * unit_size + k) = - to_shuffle.at(j * unit_size + k).sum(); - } - else if (proc.P.get_player(i) == 1) - { - for (size_t j = 0; j < n / unit_size; j++) - for (int k = 0; k < unit_size; k++) - if (reverse) - to_share[j * unit_size + k] = to_shuffle[shuffle[1][j] - * unit_size + k][0]; - else - to_share[shuffle[1][j] * unit_size + k] = to_shuffle[j - * unit_size + k][0]; - } - + for (int pass = 0; pass < 3; pass++) { input.reset_all(proc.P); - if (proc.P.get_player(i) < 2) - for (auto& x : to_share) - input.add_mine(x); + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + const auto n = sizes[current_shuffle]; + const auto unit_size = unit_sizes[current_shuffle]; + const auto &shuffle = shuffles[current_shuffle]; + const auto reverse = reverses[current_shuffle]; + const auto current_to_shuffle = to_shuffle[current_shuffle]; - for (int k = 0; k < 2; k++) - input.add_other((-i + 3 + k) % 3); + vector to_share(n); + int i; + if (reverse) + i = 2 - pass; + else + i = pass; + + if (proc.P.get_player(i) == 0) { + for (size_t j = 0; j < n / unit_size; j++) + for (size_t k = 0; k < unit_size; k++) + if (reverse) + to_share.at(j * unit_size + k) = current_to_shuffle.at( + shuffle[0].at(j) * unit_size + k).sum(); + else + to_share.at(shuffle[0].at(j) * unit_size + k) = + current_to_shuffle.at(j * unit_size + k).sum(); + } else if (proc.P.get_player(i) == 1) { + for (size_t j = 0; j < n / unit_size; j++) + for (size_t k = 0; k < unit_size; k++) + if (reverse) + to_share[j * unit_size + k] = current_to_shuffle[shuffle[1][j] * unit_size + k][0]; + else + to_share[shuffle[1][j] * unit_size + k] = current_to_shuffle[j * unit_size + k][0]; + } + + if (proc.P.get_player(i) < 2) + for (auto &x: to_share) + input.add_mine(x); + for (int k = 0; k < 2; k++) + input.add_other((-i + 3 + k) % 3); + } input.exchange(); to_shuffle.clear(); - for (size_t j = 0; j < n; j++) - { - T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3); - to_shuffle.push_back(x); + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + const auto n = sizes[current_shuffle]; + const auto reverse = reverses[current_shuffle]; + + int i; + if (reverse) + i = 2 - pass; + else + i = pass; + + vector tmp; + for (size_t j = 0; j < n; j++) { + T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3); + tmp.push_back(x); + } + to_shuffle.push_back(tmp); } } - for (size_t i = 0; i < n; i++) - a[output_base + i] = to_shuffle[i]; + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + const auto n = sizes[current_shuffle]; + + for (size_t i = 0; i < n; i++) + a[destinations[current_shuffle] + i] = to_shuffle[current_shuffle][i]; + } } template -void Rep3Shuffler::inverse_permutation(StackedVector&, size_t, size_t, size_t) -{ +void Rep3Shuffler::inverse_permutation(StackedVector &, size_t, size_t, size_t) { throw runtime_error("inverse permutation not implemented"); } diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index a6b2df72..496584e6 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -40,13 +40,6 @@ public: private: SubProcessor& proc; - vector to_shuffle; - vector> config; - vector tmp; - int unit_size; - - size_t n_shuffle; - bool exact; /** * Generates and returns a newly generated random permutation. This permutation is generated locally. @@ -66,19 +59,16 @@ private: * e.g. [2, 4, 0, 3, 1] -> perm(1) = 4 * * @param config_player The player tasked with generating the random permutation from which to configure the waksman network. - * @param n_shuffle The size of the permutation to generate. + * @param n The size of the permutation to generate. */ - void configure(int config_player, vector* perm, int n); - void player_round(int config_player); + vector> configure(int config_player, vector* perm, int n); - void waksman(StackedVector& a, int depth, int start); - void cond_swap(T& x, T& y, const T& b); + int prep_multiple(StackedVector& a, vector &sizes, vector &sources, vector &unit_sizes, vector>& to_shuffle, vector &exact); + void finalize_multiple(StackedVector& a, vector& sizes, vector& unit_sizes, vector& destinations, vector& isExact, vector>& to_shuffle); - void iter_waksman(bool reverse = false); - void waksman_round(int size, bool inwards, bool reverse); - - void pre(StackedVector& a, size_t n, size_t input_base); - void post(StackedVector& a, size_t n, size_t input_base); + void parallel_waksman_round(size_t pass, int depth, bool inwards, vector>& toShuffle, vector& unit_sizes, vector& reverse, vector& shuffles); + vector> waksman_round_init(vector& toShuffle, size_t shuffle_unit_size, int depth, vector>& iter_waksman_config, bool inwards, bool reverse); + void waksman_round_finish(vector& toShuffle, size_t unit_size, vector> indices); public: map stats; @@ -90,21 +80,10 @@ public: int generate(int n_shuffle, store_type& store); - /** - * - * @param a The vector of registers representing the stack // TODO: Is this correct? - * @param n The size of the input vector to shuffle - * @param unit_size Determines how many vector items constitute a single block with regards to permutation: - * i.e. input vector [1,2,3,4] with unit_size=2 under permutation map [1,0] - * would result in [3,4,1,2] - * @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to) - * @param input_base The starting address of the input vector (i.e. the location from which to read the permutation) - * @param shuffle The preconfigured waksman network (shuffle) to use - * @param reverse Boolean indicating whether to apply the inverse of the permutation - * @see SecureShuffle::generate for obtaining a shuffle handle - */ - void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, shuffle_type& shuffle, bool reverse); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& handles, vector& reverse, store_type& store); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& shuffles, vector& reverse); /** * Calculate the secret inverse permutation of stack given secret permutation. diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 186b6fee..d55129a5 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -53,49 +53,68 @@ void ShuffleStore::del(int handle) template SecureShuffle::SecureShuffle(SubProcessor& proc) : - proc(proc), unit_size(0), n_shuffle(0), exact(false) + proc(proc) { } template SecureShuffle::SecureShuffle(StackedVector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : - proc(proc), unit_size(unit_size), n_shuffle(0), exact(false) + proc(proc) { - pre(a, n, input_base); + store_type store; + int handle = generate(n / unit_size, store); - for (auto i : proc.protocol.get_relevant_players()) - player_round(i); - - post(a, n, output_base); + vector sizes{n}; + vector unit_sizes{static_cast(unit_size)}; + vector destinations{output_base}; + vector sources{input_base}; + vector shuffles{store.get(handle)}; + vector reverses{true}; + this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); } template -void SecureShuffle::apply(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, shuffle_type& shuffle, bool reverse) -{ - this->unit_size = unit_size; +void SecureShuffle::apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& handles, vector& reverse, store_type& store) { + vector shuffles; + for (size_t &handle : handles) + shuffles.push_back(store.get(handle)); - stats[n / unit_size] += unit_size; + this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverse); +} - pre(a, n, input_base); +template +void SecureShuffle::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, + vector &sources, vector &unit_sizes, vector &shuffles, vector &reverse) { + const auto n_shuffles = sizes.size(); + assert(sources.size() == n_shuffles); + assert(destinations.size() == n_shuffles); + assert(unit_sizes.size() == n_shuffles); + assert(shuffles.size() == n_shuffles); + assert(reverse.size() == n_shuffles); - assert(shuffle.size() == proc.protocol.get_relevant_players().size()); + // SecureShuffle works by making t players create and "secret-share" a permutation. + // Then each permutation is applied in a pass. As long as one of these permutations was created by an honest party, + // the resulting combined shuffle is hidden. + const auto n_passes = proc.protocol.get_relevant_players().size(); - if (reverse) - for (auto it = shuffle.end(); it > shuffle.begin(); it--) - { - this->config = *(it - 1); - iter_waksman(reverse); - } - else - for (auto& config : shuffle) - { - this->config = config; - iter_waksman(reverse); - } + // Initialize the shuffles. + vector is_exact(n_shuffles, false); + vector> to_shuffle; + int max_depth = prep_multiple(a, sizes, sources, unit_sizes, to_shuffle, is_exact); - post(a, n, output_base); + // Apply the shuffles. + for (size_t pass = 0; pass < n_passes; pass++) + { + for (int depth = 0; depth < max_depth; depth++) + parallel_waksman_round(pass, depth, true, to_shuffle, unit_sizes, reverse, shuffles); + for (int depth = max_depth - 1; depth >= 0; depth--) + parallel_waksman_round(pass, depth, false, to_shuffle, unit_sizes, reverse, shuffles); + } + + // Write the shuffled results into memory. + finalize_multiple(a, sizes, unit_sizes, destinations, is_exact, to_shuffle); } @@ -115,27 +134,38 @@ void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, si if (T::malicious) throw runtime_error("inverse permutation only implemented for semi-honest protocols"); - // We are dealing directly with permutations, so the unit_size will always be 1. - this->unit_size = 1; - // We need to account for sizes which are not a power of 2 - size_t n_pow2 = (1u << int(ceil(log2(n)))); + vector sizes { n }; + vector unit_sizes { 1 }; // We are dealing directly with permutations, so the unit_size will always be 1. + vector destinations { output_base }; + vector sources { input_base }; + vector reverse { true }; + vector> to_shuffle; + vector is_exact(1, false); - // Copy over the input registers - pre(stack, n, input_base); + prep_multiple(stack, sizes, sources, unit_sizes, to_shuffle, is_exact); + + size_t shuffle_size = to_shuffle[0].size() / unit_sizes[0]; // Alice generates stack local permutation and shares the waksman configuration bits secretly to Bob. - vector perm_alice(n_pow2); - if (P.my_num() == alice) + vector perm_alice(shuffle_size); + if (P.my_num() == alice) { perm_alice = generate_random_permutation(n); - configure(alice, &perm_alice, n); + } + auto config = configure(alice, &perm_alice, n); + vector shuffles {{ config, config }}; + // Apply perm_alice to perm_alice to get perm_bob, // stack permutation that we can reveal to Bob without Bob learning anything about perm_alice (since it is masked by perm_a) - iter_waksman(true); + for (int depth = 0; depth < log2(shuffle_size); depth++) + parallel_waksman_round(0, depth, true, to_shuffle, unit_sizes, reverse, shuffles); + for (int depth = log2(shuffle_size); depth >= 0; depth--) + parallel_waksman_round(0, depth, false, to_shuffle, unit_sizes, reverse, shuffles); + // Store perm_bob at stack[output_base] - post(stack, n, output_base); + finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle); // Reveal permutation perm_bob = perm_a * perm_alice // Since this permutation is masked by perm_a, Bob learns nothing about perm - vector perm_bob(n_pow2); + vector perm_bob(shuffle_size); typename T::PrivateOutput output(proc); for (size_t i = 0; i < n; i++) output.prepare_sending(stack[output_base + i], bob); @@ -147,13 +177,13 @@ void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, si perm_bob[i] = (int) val.get_si(); } - vector perm_bob_inv(n_pow2); + vector perm_bob_inv(shuffle_size); if (P.my_num() == bob) { for (int i = 0; i < (int) n; i++) perm_bob_inv[perm_bob[i]] = i; // Pad the permutation to n_pow2 // Required when using waksman networks - for (int i = (int) n; i < (int) n_pow2; i++) + for (int i = (int) n; i < (int) shuffle_size; i++) perm_bob_inv[i] = i; } @@ -169,74 +199,118 @@ void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, si stack[output_base + i] = input.finalize(alice); // The two parties now jointly compute perm_a * perm_bob_inv to obtain perm_inv - pre(stack, n, output_base); - configure(bob, &perm_bob_inv, n); - iter_waksman(true); - // perm_inv is written back to stack[output_base] - post(stack, n, output_base); -} - -template -void SecureShuffle::pre(StackedVector& a, size_t n, size_t input_base) -{ - n_shuffle = n / unit_size; - assert(unit_size * n_shuffle == n); - size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); - exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; to_shuffle.clear(); + prep_multiple(stack, sizes, destinations, unit_sizes, to_shuffle, is_exact); - if (exact) - { - to_shuffle.resize(n_shuffle_pow2 * unit_size); - for (size_t i = 0; i < n; i++) - to_shuffle[i] = a[input_base + i]; - } - else - { - // sorting power of two elements together with indicator bits - to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle)))); - for (size_t i = 0; i < n_shuffle; i++) - { - for (int j = 0; j < unit_size; j++) - to_shuffle[i * (unit_size + 1) + j] = a[input_base - + i * unit_size + j]; - to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1, - proc.P.my_num(), proc.MC.get_alphai()); - } - this->unit_size++; - } + config = configure(bob, &perm_bob_inv, n); + shuffles[0] = { config, config }; + + for (int i = 0; i < log2(shuffle_size); i++) + parallel_waksman_round(0, i, true, to_shuffle, unit_sizes, reverse, shuffles); + for (int i = log2(shuffle_size) - 2; i >= 0; i--) + parallel_waksman_round(0, i, false, to_shuffle, unit_sizes, reverse, shuffles); + + // Store perm_bob at stack[output_base] + finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle); } template -void SecureShuffle::post(StackedVector& a, size_t n, size_t output_base) -{ - if (exact) - for (size_t i = 0; i < n; i++) - a[output_base + i] = to_shuffle[i]; - else - { - auto& MC = proc.MC; - MC.init_open(proc.P); - int shuffle_unit_size = this->unit_size; - int unit_size = shuffle_unit_size - 1; - for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++) - MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1)); - MC.exchange(proc.P); - size_t i_shuffle = 0; - for (size_t i = 0; i < n_shuffle; i++) +int SecureShuffle::prep_multiple(StackedVector &a, vector &sizes, + vector &sources, vector &unit_sizes, vector> &to_shuffle, vector &is_exact) { + int max_depth = 0; + const size_t n_shuffles = sizes.size(); + + for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { + const size_t input_base = sources[currentShuffle]; + const size_t n = sizes[currentShuffle]; + const size_t unit_size = unit_sizes[currentShuffle]; + + assert(n % unit_size == 0); + + const size_t n_shuffle = n / unit_size; + const size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + const bool exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; + + vector tmp; + if (exact) { - auto bit = MC.finalize_open(); - if (bit == 1) - { - // only output real elements - for (int j = 0; j < unit_size; j++) - a.at(output_base + i_shuffle * unit_size + j) = - to_shuffle.at(i * shuffle_unit_size + j); - i_shuffle++; - } + tmp.resize(n_shuffle_pow2 * unit_size); + for (size_t i = 0; i < n; i++) + tmp[i] = a[input_base + i]; + } + else + { + // Pad n_shuffle to n_shuffle_pow2. To reduce this back to n_shuffle after-the-fact, a flag bit is + // added to every real entry. + const size_t shuffle_unit_size = unit_size + 1; + tmp.resize(shuffle_unit_size * n_shuffle_pow2); + for (size_t i = 0; i < n_shuffle; i++) + { + for (size_t j = 0; j < unit_size; j++) + tmp[i * shuffle_unit_size + j] = a[input_base + i * unit_size + j]; + tmp[(i + 1) * shuffle_unit_size - 1] = T::constant(1, proc.P.my_num(), proc.MC.get_alphai()); + } + for (size_t i = n_shuffle * shuffle_unit_size; i < tmp.size(); i++) + tmp[i] = T::constant(0, proc.P.my_num(), proc.MC.get_alphai()); + unit_sizes[currentShuffle] = shuffle_unit_size; + } + + to_shuffle.push_back(tmp); + is_exact[currentShuffle] = exact; + + const int shuffle_depth = tmp.size() / unit_size; + if (shuffle_depth > max_depth) + max_depth = shuffle_depth; + } + + return max_depth; +} + +template +void SecureShuffle::finalize_multiple(StackedVector &a, vector &sizes, vector &unit_sizes, + vector &destinations, vector &isExact, vector> &to_shuffle) { + const size_t n_shuffles = sizes.size(); + for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { + const size_t n = sizes[currentShuffle]; + const size_t shuffled_unit_size = unit_sizes[currentShuffle]; + const size_t output_base = destinations[currentShuffle]; + + const vector& shuffledData = to_shuffle[currentShuffle]; + + if (isExact[currentShuffle]) + for (size_t i = 0; i < n; i++) + a[output_base + i] = shuffledData[i]; + else + { + const size_t original_unit_size = shuffled_unit_size - 1; + const size_t n_shuffle = n / original_unit_size; + const size_t n_shuffle_pow2 = shuffledData.size() / shuffled_unit_size; + + // Reveal the "real element" flags. + auto& MC = proc.MC; + MC.init_open(proc.P); + for (size_t i = 0; i < n_shuffle_pow2; i++) { + MC.prepare_open(shuffledData.at((i + 1) * shuffled_unit_size - 1)); + } + MC.exchange(proc.P); + + // Filter out the real elements. + size_t i_shuffle = 0; + for (size_t i = 0; i < n_shuffle_pow2; i++) + { + auto bit = MC.finalize_open(); + if (bit == 1) + { + // only output real elements + for (size_t j = 0; j < original_unit_size; j++) + a.at(output_base + i_shuffle * original_unit_size + j) = + shuffledData.at(i * shuffled_unit_size + j); + i_shuffle++; + } + } + if (i_shuffle != n_shuffle) + throw runtime_error("incorrect shuffle"); } - if (i_shuffle != n_shuffle) - throw runtime_error("incorrect shuffle"); } } @@ -256,15 +330,6 @@ vector SecureShuffle::generate_random_permutation(int n) { return perm; } -template -void SecureShuffle::player_round(int config_player) { - vector random_perm(n_shuffle); - if (proc.P.my_num() == config_player) - random_perm = generate_random_permutation(n_shuffle); - configure(config_player, &random_perm, n_shuffle); - iter_waksman(); -} - template int SecureShuffle::generate(int n_shuffle, store_type& store) { @@ -275,8 +340,7 @@ int SecureShuffle::generate(int n_shuffle, store_type& store) vector perm; if (proc.P.my_num() == i) perm = generate_random_permutation(n_shuffle); - configure(i, &perm, n_shuffle); - + auto config = configure(i, &perm, n_shuffle); shuffle.push_back(config); } @@ -284,7 +348,7 @@ int SecureShuffle::generate(int n_shuffle, store_type& store) } template -void SecureShuffle::configure(int config_player, vector *perm, int n) { +vector> SecureShuffle::configure(int config_player, vector *perm, int n) { auto &P = proc.P; auto &input = proc.input; input.reset_all(P); @@ -311,7 +375,7 @@ void SecureShuffle::configure(int config_player, vector *perm, int n) { input.add_other(config_player); input.exchange(); - config.clear(); + vector> config; typename T::Protocol checker(P); checker.init(proc.DataF, proc.MC); checker.init_dotprod(); @@ -345,68 +409,61 @@ void SecureShuffle::configure(int config_player, vector *perm, int n) { P)) == 0); checker.check(); } + + return config; } template -void SecureShuffle::waksman(StackedVector& a, int depth, int start) +void SecureShuffle::parallel_waksman_round(size_t pass, int depth, bool inwards, vector> &toShuffle, + vector &unit_sizes, vector &reverse, vector &shuffles) { - int n = a.size(); + const auto n_passes = proc.protocol.get_relevant_players().size(); + const auto n_shuffles = shuffles.size(); - if (n == 2) - { - cond_swap(a[0], a[1], config.at(depth).at(start)); - return; - } - - vector a0(n / 2), a1(n / 2); - for (int i = 0; i < n / 2; i++) - { - a0.at(i) = a.at(2 * i); - a1.at(i) = a.at(2 * i + 1); - - cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2)); - } - - waksman(a0, depth + 1, start); - waksman(a1, depth + 1, start + n / 2); - - for (int i = 0; i < n / 2; i++) - { - a.at(2 * i) = a0.at(i); - a.at(2 * i + 1) = a1.at(i); - cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start)); - } -} - -template -void SecureShuffle::cond_swap(T& x, T& y, const T& b) -{ - auto diff = proc.protocol.mul(x - y, b); - x -= diff; - y += diff; -} - -template -void SecureShuffle::iter_waksman(bool reverse) -{ - int n = to_shuffle.size() / unit_size; - - for (int depth = 0; depth < log2(n); depth++) - waksman_round(depth, true, reverse); - - for (int depth = log2(n) - 2; depth >= 0; depth--) - waksman_round(depth, false, reverse); -} - -template -void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) -{ - int n = to_shuffle.size() / unit_size; - assert((int) config.at(depth).size() == n); - int nblocks = 1 << depth; - int size = n / (2 * nblocks); - bool outwards = !inwards; + vector>> allIndices; proc.protocol.init_mul(); + + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle]; + if (depth >= log2(n) - !inwards) { + allIndices.push_back({}); + continue; + } + + const auto isReverse = reverse[current_shuffle]; + size_t configIdx = pass; + if (isReverse) + configIdx = n_passes - pass - 1; + + auto& config = shuffles[current_shuffle][configIdx]; + + vector> indices = waksman_round_init( + toShuffle[current_shuffle], + unit_sizes[current_shuffle], + depth, + config, + inwards, + isReverse + ); + allIndices.push_back(indices); + } + proc.protocol.exchange(); + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle]; + if (depth >= log2(n) - !inwards) { + continue; + } + waksman_round_finish(toShuffle[current_shuffle], unit_sizes[current_shuffle], allIndices[current_shuffle]); + } +} + +template +vector> SecureShuffle::waksman_round_init(vector &toShuffle, size_t shuffle_unit_size, int depth, vector> &iter_waksman_config, bool inwards, bool reverse) { + int n = toShuffle.size() / shuffle_unit_size; + assert((int) iter_waksman_config.at(depth).size() == n); + int n_blocks = 1 << depth; + int size = n / (2 * n_blocks); + bool outwards = !inwards; vector> indices; indices.reserve(n / 2); Waksman waksman(n); @@ -423,30 +480,37 @@ void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) bool run = waksman.matters(depth, i_bit); if (run) { - for (int l = 0; l < unit_size; l++) - proc.protocol.prepare_mul(config.at(depth).at(i_bit), - to_shuffle.at(in1 * unit_size + l) - - to_shuffle.at(in2 * unit_size + l)); + for (size_t l = 0; l < shuffle_unit_size; l++) + proc.protocol.prepare_mul(iter_waksman_config.at(depth).at(i_bit), + toShuffle.at(in1 * shuffle_unit_size + l) - toShuffle.at(in2 * shuffle_unit_size + l)); } - indices.push_back({{in1, in2, out1, out2, run}}); + indices.push_back({in1, in2, out1, out2, run}); } - proc.protocol.exchange(); - tmp.resize(to_shuffle.size()); + return indices; +} + +template +void SecureShuffle::waksman_round_finish(vector &toShuffle, size_t unit_size, vector> indices) { + int n = toShuffle.size() / unit_size; + + vector tmp(toShuffle.size()); for (int k = 0; k < n / 2; k++) { auto idx = indices.at(k); - for (int l = 0; l < unit_size; l++) + for (size_t l = 0; l < unit_size; l++) { T diff; if (idx[4]) diff = proc.protocol.finalize_mul(); - tmp.at(idx[2] * unit_size + l) = to_shuffle.at( + tmp.at(idx[2] * unit_size + l) = toShuffle.at( idx[0] * unit_size + l) - diff; - tmp.at(idx[3] * unit_size + l) = to_shuffle.at( + tmp.at(idx[3] * unit_size + l) = toShuffle.at( idx[1] * unit_size + l) + diff; } } - swap(tmp, to_shuffle); + + swap(tmp, toShuffle); } + #endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */ diff --git a/Protocols/SpdzWiseRep3Shuffler.h b/Protocols/SpdzWiseRep3Shuffler.h index c1773620..f61d2d3c 100644 --- a/Protocols/SpdzWiseRep3Shuffler.h +++ b/Protocols/SpdzWiseRep3Shuffler.h @@ -30,8 +30,10 @@ public: int generate(int n_shuffle, store_type& store); - void apply(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, shuffle_type& shuffle, bool reverse); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& handles, vector& reverse, store_type& store); + void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& shuffles, vector& reverse); void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); diff --git a/Protocols/SpdzWiseRep3Shuffler.hpp b/Protocols/SpdzWiseRep3Shuffler.hpp index 971e4b1e..250a6971 100644 --- a/Protocols/SpdzWiseRep3Shuffler.hpp +++ b/Protocols/SpdzWiseRep3Shuffler.hpp @@ -13,8 +13,14 @@ SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(StackedVector& a, size_t n, { store_type store; int handle = generate(n / unit_size, store); - apply(a, n, unit_size, output_base, input_base, store.get(handle), - false); + + vector sizes{n}; + vector unit_sizes{static_cast(unit_size)}; + vector destinations{output_base}; + vector sources{input_base}; + vector shuffles{store.get(handle)}; + vector reverses{true}; + this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); } template @@ -30,31 +36,63 @@ int SpdzWiseRep3Shuffler::generate(int n_shuffle, store_type& store) } template -void SpdzWiseRep3Shuffler::apply(StackedVector& a, size_t n, - int unit_size, size_t output_base, size_t input_base, - shuffle_type& shuffle, bool reverse) -{ - stats[n / unit_size] += unit_size; - - StackedVector to_shuffle; - to_shuffle.reserve(2 * n); - - for (size_t i = 0; i < n; i++) - { - auto& x = a[input_base + i]; - to_shuffle.push_back(x.get_share()); - to_shuffle.push_back(x.get_mac()); +void SpdzWiseRep3Shuffler::apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& handles, vector& reverses, store_type& store) { + vector shuffles; + for (size_t &handle : handles) { + shuffle_type& shuffle = store.get(handle); + shuffles.push_back(shuffle); } - internal.apply(to_shuffle, 2 * n, 2 * unit_size, 0, 0, shuffle, reverse); + apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); +} - for (size_t i = 0; i < n; i++) +template +void SpdzWiseRep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, + vector &sources, vector &unit_sizes, vector &shuffles, vector &reverse) { + + const size_t n_shuffles = sizes.size(); + assert(n_shuffles == destinations.size()); + assert(n_shuffles == sources.size()); + assert(n_shuffles == unit_sizes.size()); + assert(n_shuffles == shuffles.size()); + assert(n_shuffles == reverse.size()); + + StackedVector temporary_memory(0); + vector mapped_positions (n_shuffles, 0); + vector mapped_sizes(n_shuffles, 0); + vector mapped_unit_sizes (n_shuffles, 0); + + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + mapped_positions[current_shuffle] = temporary_memory.size(); + + mapped_sizes[current_shuffle] = 2 * sizes[current_shuffle]; + mapped_unit_sizes[current_shuffle] = 2 * unit_sizes[current_shuffle]; + stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle]; + + for (size_t i = 0; i < sizes[current_shuffle]; i++) + { + auto& x = a[sources[current_shuffle] + i]; + temporary_memory.push_back(x.get_share()); + temporary_memory.push_back(x.get_mac()); + } + } + + internal.apply_multiple(temporary_memory, mapped_sizes, mapped_positions, mapped_positions, mapped_unit_sizes, shuffles, reverse); + + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - auto& x = a[output_base + i]; - x.set_share(to_shuffle[2 * i]); - x.set_mac(to_shuffle[2 * i + 1]); - proc.protocol.add_to_check(x); + const size_t n = sizes[current_shuffle]; + const size_t dest = destinations[current_shuffle]; + const size_t pos = mapped_positions[current_shuffle]; + for (size_t i = 0; i < n; i++) + { + auto& x = a[dest + i]; + x.set_share(temporary_memory[pos + 2 * i]); + x.set_mac(temporary_memory[pos + 2 * i + 1]); + proc.protocol.add_to_check(x); + } } proc.protocol.maybe_check();