diff --git a/Programs/Source/test_permute.mpc b/Programs/Source/test_permute.mpc index 2565adbd..5b24f4d2 100644 --- a/Programs/Source/test_permute.mpc +++ b/Programs/Source/test_permute.mpc @@ -19,7 +19,17 @@ def test_case(permutation_sizes, timer_base: int | None = None): start_timer(timer_base + 2) for i, arr in enumerate(arrays): - print_ln("%s", arr.reveal()) + revealed = arr.reveal() + print_ln("%s", revealed) + + n_matched = cint(0) + @for_range(len(arr)) + def count_matches(i: cint) -> None: + n_matched.update(n_matched + (revealed[i] == i)) + @if_(n_matched == len(arr)) + def didnt_permute(): + print_ln("Permutation likely didn't work.") + crash() if timer_base is not None: stop_timer(timer_base + 2) @@ -33,7 +43,15 @@ def test_case(permutation_sizes, timer_base: int | None = None): start_timer(timer_base + 4) for i, arr in enumerate(arrays): - print_ln("%s", arr.reveal()) + revealed = arr.reveal() + print_ln("%s", revealed) + + @for_range(len(arr)) + def test_is_original(i: cint) -> None: + @if_(revealed[i] != i) + def fail(): + print_ln("FAILED!") + crash() if timer_base is not None: stop_timer(timer_base + 4) @@ -55,5 +73,7 @@ def test_allocator(): arr3.secure_permute(p3) # test_allocator() -test_case([5, 10, 20], 10) -test_case([5, 10, 15, 20], 20) \ No newline at end of file +test_case([5,10], 10) +test_case([5, 10, 15, 20], 20) +test_case([4,8,16], 30) +test_case([5], 40) \ No newline at end of file diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index 1e42310a..93ffe9e6 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -77,6 +77,9 @@ private: void iter_waksman(bool reverse = false); void waksman_round(int size, bool inwards, bool reverse); + vector> waksman_round_init(vector& toShuffle, size_t shuffle_unit_size, int depth, vector>& iter_waksman_config, bool inwards, bool reverse); + void waksman_round_finish(vector& toShuffle, size_t unit_size, vector> indices); + void pre(vector& a, size_t n, size_t input_base); void post(vector& a, size_t n, size_t input_base); @@ -106,6 +109,8 @@ public: void applyMultiple(vector& a, vector& sizes, vector& destinations, vector& sources, vector& unit_sizes, vector& handles, vector& reverse, store_type& store); + void applyMultiple(vector& a, vector& sizes, vector& destinations, vector& sources, + vector& unit_sizes, vector& shuffles, vector& reverse); /** * Calculate the secret inverse permutation of stack given secret permutation. diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 59334d23..182a4728 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -100,15 +100,178 @@ void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t outpu template void SecureShuffle::applyMultiple(vector& a, vector& sizes, vector& destinations, vector& sources, vector& unit_sizes, vector& handles, vector& reverse, store_type& store) { + vector shuffles; + for (size_t &handle : handles) + shuffles.push_back(store.get(handle)); + + this->applyMultiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverse); +} + +template +void SecureShuffle::applyMultiple(vector &a, vector &sizes, vector &destinations, + vector &sources, vector &unit_sizes, vector &shuffles, vector &reverse) { const auto n_shuffles = sizes.size(); assert(sources.size() == n_shuffles); assert(destinations.size() == n_shuffles); assert(unit_sizes.size() == n_shuffles); - assert(handles.size() == n_shuffles); + assert(shuffles.size() == n_shuffles); assert(reverse.size() == n_shuffles); - for (size_t i = 0; i < n_shuffles; i++) { - this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], store.get(handles[i]), reverse[i]); + // SecurePermute works by making t players create and "secret-share" a permutation. + // Then each permutation is applied in a pass. As long as one of these permutations was created by a honest party, + // the resulting combined shuffle is hidden. + const auto n_passes = proc.protocol.get_relevant_players().size(); + + // Initialize the shuffles. + vector isExact(n_shuffles, false); + vector> toShuffle; + size_t max_depth = 0; + + for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { + const size_t input_base = sources[currentShuffle]; + const size_t n = sizes[currentShuffle]; + const size_t unit_size = unit_sizes[currentShuffle]; + + assert(shuffles[currentShuffle].size() == n_passes); + assert(n % unit_size == 0); + + const size_t n_shuffle = n / unit_size; + const size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + const bool exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; + + if (log2(n_shuffle) > max_depth) + max_depth = log2(n_shuffle); + + vector tmp; + if (exact) + { + tmp.resize(n_shuffle_pow2 * unit_size); + for (size_t i = 0; i < n; i++) + tmp[i] = a[input_base + i]; + } + else + { + // Pad n_shuffle to n_shuffle_pow2. To reduce this back to n_shuffle after-the-fact, a flag bit is + // added to every real entry. + const size_t shuffle_unit_size = unit_size + 1; + tmp.resize(shuffle_unit_size * n_shuffle_pow2); + for (size_t i = 0; i < n_shuffle; i++) + { + for (size_t j = 0; j < unit_size; j++) + tmp[i * shuffle_unit_size + j] = a[input_base + i * unit_size + j]; + tmp[(i + 1) * shuffle_unit_size - 1] = T::constant(1, proc.P.my_num(), proc.MC.get_alphai()); + } + for (size_t i = n_shuffle * shuffle_unit_size; i < tmp.size(); i++) + tmp[i] = T::constant(0, proc.P.my_num(), proc.MC.get_alphai()); + unit_sizes[currentShuffle] = shuffle_unit_size; + } + + // auto& MC = proc.MC; + // MC.init_open(proc.P); + // for (size_t i = 0; i < tmp.size(); i++) + // MC.prepare_open(tmp[i]); + // MC.exchange(proc.P); + // for (size_t i = 0; i < tmp.size(); i++) + // cout << "Setup tmp[" << i << "]=" << MC.finalize_open() << endl; + + toShuffle.push_back(tmp); + isExact[currentShuffle] = exact; + } + + // Apply the shuffles. + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { + for (size_t pass = 0; pass < n_passes; pass++) + { + const auto isReverse = reverse[current_shuffle]; + size_t configIdx = pass; + if (isReverse) + configIdx = n_passes - pass - 1; + + auto& config = shuffles[current_shuffle][configIdx]; + + int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle]; + for (int depth = 0; depth < log2(n); depth++) { + proc.protocol.init_mul(); + vector> indices = waksman_round_init( + toShuffle[current_shuffle], + unit_sizes[current_shuffle], + depth, + config, + true, + isReverse + ); + proc.protocol.exchange(); + + waksman_round_finish(toShuffle[current_shuffle], unit_sizes[current_shuffle], indices); + } + + for (int depth = log2(n) - 2; depth >= 0; depth--) { + proc.protocol.init_mul(); + vector> indices = waksman_round_init( + toShuffle[current_shuffle], + unit_sizes[current_shuffle], + depth, + config, + false, + isReverse + ); + proc.protocol.exchange(); + + waksman_round_finish(toShuffle[current_shuffle], unit_sizes[current_shuffle], indices); + } + } + } + + // Write the shuffled results into memory. + for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { + const size_t n = sizes[currentShuffle]; + const size_t shuffled_unit_size = unit_sizes[currentShuffle]; + const size_t output_base = destinations[currentShuffle]; + + const vector& shuffledData = toShuffle[currentShuffle]; + + // auto& MC = proc.MC; + // MC.init_open(proc.P); + // for (size_t i = 0; i < shuffledData.size(); i++) + // MC.prepare_open(shuffledData[i]); + // MC.exchange(proc.P); + // for (size_t i = 0; i < shuffledData.size(); i++) + // cout << "Setup shuffledData[" << i << "]=" << MC.finalize_open() << endl; + + if (isExact[currentShuffle]) + for (size_t i = 0; i < n; i++) + a[output_base + i] = shuffledData[i]; + else + { + const size_t original_unit_size = shuffled_unit_size - 1; + const size_t n_shuffle = n / original_unit_size; + const size_t n_shuffle_pow2 = shuffledData.size() / shuffled_unit_size; + + // Reveal the "real element" flags. + auto& MC = proc.MC; + MC.init_open(proc.P); + for (size_t i = 0; i < n_shuffle_pow2; i++) { + MC.prepare_open(shuffledData.at((i + 1) * shuffled_unit_size - 1)); + } + MC.exchange(proc.P); + + // Filter out the real elements. + size_t i_shuffle = 0; + for (size_t i = 0; i < n_shuffle_pow2; i++) + { + auto bit = MC.finalize_open(); + if (bit == 1) + { + // only output real elements + for (size_t j = 0; j < original_unit_size; j++) + a.at(output_base + i_shuffle * original_unit_size + j) = + shuffledData.at(i * shuffled_unit_size + j); + i_shuffle++; + } + } + if (i_shuffle != n_shuffle) + throw runtime_error("incorrect shuffle"); + } } } @@ -461,4 +624,62 @@ void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) swap(tmp, to_shuffle); } +template +vector> SecureShuffle::waksman_round_init(vector &toShuffle, size_t shuffle_unit_size, int depth, vector> &iter_waksman_config, bool inwards, bool reverse) { + int n = toShuffle.size() / shuffle_unit_size; + assert((int) iter_waksman_config.at(depth).size() == n); + int n_blocks = 1 << depth; + int size = n / (2 * n_blocks); + bool outwards = !inwards; + vector> indices; + indices.reserve(n / 2); + Waksman waksman(n); + for (int k = 0; k < n / 2; k++) + { + int j = k % size; + int i = k / size; + int base = 2 * i * size; + int in1 = base + j + j * inwards; + int in2 = in1 + inwards + size * outwards; + int out1 = base + j + j * outwards; + int out2 = out1 + outwards + size * inwards; + int i_bit = base + j + size * (outwards ^ reverse); + bool run = waksman.matters(depth, i_bit); + if (run) + { + for (size_t l = 0; l < shuffle_unit_size; l++) + proc.protocol.prepare_mul(iter_waksman_config.at(depth).at(i_bit), + toShuffle.at(in1 * shuffle_unit_size + l) - toShuffle.at(in2 * shuffle_unit_size + l)); + } + indices.push_back({in1, in2, out1, out2, run}); + } + return indices; +} + +template +void SecureShuffle::waksman_round_finish(vector &toShuffle, size_t unit_size, vector> indices) { + int n = toShuffle.size() / unit_size; + + vector tmp(toShuffle.size()); + for (int k = 0; k < n / 2; k++) + { + auto idx = indices.at(k); + for (size_t l = 0; l < unit_size; l++) + { + T diff; + if (idx[4]) + diff = proc.protocol.finalize_mul(); + tmp.at(idx[2] * unit_size + l) = toShuffle.at( + idx[0] * unit_size + l) - diff; + tmp.at(idx[3] * unit_size + l) = toShuffle.at( + idx[1] * unit_size + l) + diff; + } + } + + // for (size_t i = 0; i < toShuffle.size(); i++) + // toShuffle[i] = tmp.at(i); + swap(tmp, toShuffle); +} + + #endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */