Merge pull request #1554 from vincent-ehrmanntraut/parallel_permutations

Parallelize applyshuffle instruction
This commit is contained in:
Marcel Keller
2024-12-23 13:33:41 +11:00
committed by GitHub
13 changed files with 796 additions and 348 deletions

View File

@@ -2749,9 +2749,10 @@ class gensecshuffle(shuffle_base):
def add_usage(self, req_node):
self.add_gen_usage(req_node, self.args[1])
class applyshuffle(base.VectorInstruction, shuffle_base):
class applyshuffle(shuffle_base, base.Mergeable):
""" Generate secure shuffle to bit used several times.
:param: vector size (int)
:param: destination (sint)
:param: source (sint)
:param: number of elements to be treated as one (int)
@@ -2761,15 +2762,20 @@ class applyshuffle(base.VectorInstruction, shuffle_base):
"""
__slots__ = []
code = base.opcodes['APPLYSHUFFLE']
arg_format = ['sw','s','int','ci','int']
arg_format = itertools.cycle(['int', 'sw','s','int','ci','int'])
is_vec = lambda self: True # Ensures dead-code elimination works.
def __init__(self, *args, **kwargs):
super(applyshuffle, self).__init__(*args, **kwargs)
assert len(args[0]) == len(args[1])
assert len(args[0]) > args[2]
assert (len(args) % 6) == 0
for i in range(0, len(args), 6):
assert args[i] == len(args[i+1])
assert args[i] == len(args[i + 2])
assert args[i] > args[i + 3]
def add_usage(self, req_node):
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
for i in range(0, len(self.args), 6):
self.add_apply_usage(req_node, self.args[i], self.args[i + 3])
class delshuffle(base.Instruction):
""" Delete secure shuffle.

View File

@@ -3122,7 +3122,7 @@ class sint(_secret, _int):
@read_mem_value
def secure_permute(self, shuffle, unit_size=1, reverse=False):
res = sint(size=self.size)
applyshuffle(res, self, unit_size, shuffle, reverse)
applyshuffle(self.size, res, self, unit_size, shuffle, reverse)
return res
def inverse_permutation(self):
@@ -7274,20 +7274,38 @@ class SubMultiArray(_vectorizable):
self.secure_permute(perm)
delshuffle(perm)
def secure_permute(self, permutation, reverse=False, n_threads=None):
def secure_permute(self, permutation, reverse=False, n_threads=None, n_parallel=None):
""" Securely permute rows (first index). See
:py:func:`secure_shuffle` for references.
:param permutation: output of :py:func:`sint.get_secure_shuffle()`
:param reverse: whether to apply inverse (default: False)
:param n_threads: How many threads should be used. Will not multithread when set to None (default: None)
:param n_parallel: How many columns should be permuted in parallel. Will use the compiler's optimization budget is set to None. (default: None).
"""
if n_threads is not None:
permutation = MemValue(permutation)
@library.for_range_multithread(n_threads, 1, self.get_part_size())
def _(i):
self.set_column(i, self.get_column(i).secure_permute(
permutation, reverse=reverse))
if (self.value_type == sint) and (n_threads is None):
# Use only a single shuffle instruction if applicable and permutation is single-threaded anyway.
unit_size = self.get_part_size()
n = self.sizes[0] * unit_size
res = sint(size=n)
applyshuffle(n, res, self[:], unit_size, permutation, reverse)
self.assign_vector(res)
else:
if n_threads is not None:
permutation = MemValue(permutation)
if n_parallel is None:
@library.for_range_opt_multithread(n_threads, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
else:
@library.for_range_multithread(n_threads, n_parallel, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)
self.set_column(i, column)
def sort(self, key_indices=None, n_bits=None, batcher=False):
""" Sort sub-arrays (different first index) in place.

View File

@@ -286,7 +286,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
// instructions with 5 register operands
case PRINTFLOATPLAIN:
case PRINTFLOATPLAINB:
case APPLYSHUFFLE:
get_vector(5, start, s);
break;
case INCINT:
@@ -321,14 +320,11 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case RUN_TAPE:
case CONV2DS:
case MATMULS:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
case APPLYSHUFFLE:
case MATMULSM:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
// read from file, input is opcode num_args,
// start_file_posn (read), end_file_posn(write) var1, var2, ...
case READFILESHARE:
@@ -1187,8 +1183,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.machine.shuffle_store));
return;
case APPLYSHUFFLE:
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)),
Proc.machine.shuffle_store);
Proc.Procp.apply_shuffle(*this, Proc.machine.shuffle_store);
return;
case DELSHUFFLE:
Proc.machine.shuffle_store.del(Proc.read_Ci(r[0]));

View File

@@ -96,8 +96,7 @@ public:
void secure_shuffle(const Instruction& instruction);
size_t generate_secure_shuffle(const Instruction& instruction,
ShuffleStore& shuffle_store);
void apply_shuffle(const Instruction& instruction, int handle,
ShuffleStore& shuffle_store);
void apply_shuffle(const Instruction& instruction, ShuffleStore& shuffle_store);
void inverse_permutation(const Instruction& instruction);
void input_personal(const vector<int>& args);

View File

@@ -908,13 +908,28 @@ size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
}
template<class T>
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle,
ShuffleStore& shuffle_store)
void SubProcessor<T>::apply_shuffle(const Instruction& instruction,
ShuffleStore& shuffle_store)
{
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
instruction.get_start()[0], instruction.get_start()[1],
shuffle_store.get(handle),
instruction.get_start()[4]);
const auto& args = instruction.get_start();
const auto n_shuffles = args.size() / 6;
vector<size_t> sizes(n_shuffles, 0);
vector<size_t> destinations(n_shuffles, 0);
vector<size_t> sources(n_shuffles, 0);
vector<size_t> unit_sizes(n_shuffles, 0);
vector<size_t> shuffles(n_shuffles, 0);
vector<bool> reverse(n_shuffles, false);
for (size_t i = 0; i < n_shuffles; i++) {
sizes[i] = args[6 * i];
destinations[i] = args[6 * i + 1];
sources[i] = args[6 * i + 2];
unit_sizes[i] = args[6 * i + 3];
shuffles[i] = Proc->read_Ci(args[6 * i + 4]);
reverse[i] = args[6 * i + 5];
}
shuffler.apply_multiple(S, sizes, destinations, sources, unit_sizes, shuffles, reverse, shuffle_store);
maybe_check();
}

View File

@@ -0,0 +1,271 @@
from Compiler.library import get_number_of_players
from Compiler.sqrt_oram import n_parallel
from Compiler.util import if_else
def test_allocator():
arr1 = sint.Array(5)
arr2 = sint.Array(10)
arr3 = sint.Array(20)
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(10)
p3 = sint.get_secure_shuffle(20)
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
arr1.secure_permute(p1)
arr2[0] = arr1[0]
arr2.secure_permute(p2)
arr3.secure_permute(p3)
def test_case(permutation_sizes, timer_base: int | None = None):
if timer_base is not None:
start_timer(timer_base + 0)
arrays = []
permutations = []
for size in permutation_sizes:
arrays.append(Array.create_from([sint(i) for i in range(size)]))
permutations.append(sint.get_secure_shuffle(size))
if timer_base is not None:
stop_timer(timer_base + 0)
start_timer(timer_base + 1)
for arr, p in zip(arrays, permutations):
arr.secure_permute(p)
if timer_base is not None:
stop_timer(timer_base + 1)
start_timer(timer_base + 2)
for i, arr in enumerate(arrays):
revealed = arr.reveal()
print_ln("%s", revealed)
n_matched = cint(0)
@for_range(len(arr))
def count_matches(i: cint) -> None:
n_matched.update(n_matched + (revealed[i] == i))
@if_(n_matched == len(arr))
def didnt_permute():
print_ln("Permutation potentially didn't work (permutation might have been identity by chance).")
crash()
if timer_base is not None:
stop_timer(timer_base + 2)
start_timer(timer_base + 3)
for arr, p in zip(arrays, permutations):
arr.secure_permute(p, reverse=True)
if timer_base is not None:
stop_timer(timer_base + 3)
start_timer(timer_base + 4)
for i, arr in enumerate(arrays):
revealed = arr.reveal()
print_ln("%s", revealed)
@for_range(len(arr))
def test_is_original(i: cint) -> None:
@if_(revealed[i] != i)
def fail():
print_ln("Failed to invert permutation!")
crash()
if timer_base is not None:
stop_timer(timer_base + 4)
def test_parallel_permutation_equals_sequential_permutation(sizes: list[int], timer_base: int) -> None:
start_timer(timer_base)
permutations = []
for permutation_size in sizes:
permutations.append(sint.get_secure_shuffle(permutation_size))
stop_timer(timer_base)
start_timer(timer_base + 1)
arrs_to_permute_sequentially = []
arrs_to_permute_parallely = []
for permutation_size in sizes:
arrs_to_permute_sequentially.append(Array.create_from([sint(i) for i in range(permutation_size)]))
arrs_to_permute_parallely.append(Array.create_from([sint(i) for i in range(permutation_size)]))
stop_timer(timer_base + 1)
start_timer(timer_base + 2)
for arr, perm in zip(arrs_to_permute_sequentially, permutations):
arr.secure_permute(perm)
break_point()
stop_timer(timer_base + 2)
start_timer(timer_base + 3)
for arr, perm in zip(arrs_to_permute_parallely, permutations):
arr.secure_permute(perm)
stop_timer(timer_base + 3)
start_timer(timer_base + 4)
arrs_to_permute_sequentially = [arr.reveal() for arr in arrs_to_permute_sequentially]
arrs_to_permute_parallely = [arr.reveal() for arr in arrs_to_permute_parallely]
stop_timer(timer_base + 4)
for (arr_seq, arr_par) in zip(arrs_to_permute_sequentially, arrs_to_permute_parallely):
print_ln("Sequential: %s", arr_seq)
print_ln("Parallel: %s", arr_par)
@for_range(len(arr_seq))
def test_equals(i: cint) -> None:
@if_(arr_seq[i] != arr_par[i])
def fail():
print_ln("Sequentially permuted arrays to not match the parallely permuted arrays.")
crash()
def test_permute_matrix(timer_base: int, value_type=sint) -> None:
def test_permuted_matrix(m, p):
permuted_indices = Array.create_from([sint(i) for i in range(m.sizes[0])])
permuted_indices.secure_permute(p, reverse=True)
permuted_indices = permuted_indices.reveal()
@for_range(m.sizes[0])
def check_row(i):
@for_range(m.sizes[1])
def check_entry(j):
@if_(m[permuted_indices[i]][j] != (m.sizes[1] * i + j))
def fail():
print_ln("Matrix permuted unexpectedly.")
crash()
print_ln("Pre-create matrix")
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m4 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m5 = Matrix.create_from([[value_type(9 * i + j) for j in range(9)] for i in range(5)])
print_ln("post-create matrix")
p1 = sint.get_secure_shuffle(5)
p2 = sint.get_secure_shuffle(5)
p3 = sint.get_secure_shuffle(5)
start_timer(timer_base + 1)
m1.secure_permute(p1)
stop_timer(timer_base + 1)
start_timer(timer_base + 2)
m2.secure_permute(p2)
stop_timer(timer_base + 2)
start_timer(timer_base + 3)
m3.secure_permute(p2, n_threads=3)
stop_timer(timer_base + 3)
start_timer(timer_base + 4)
m4.secure_permute(p2, n_threads=3, n_parallel=2)
stop_timer(timer_base + 4)
start_timer(timer_base + 5)
m5.secure_permute(p3, n_threads=3, n_parallel=3)
stop_timer(timer_base + 5)
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
m1 = m1.reveal()
m2 = m2.reveal()
m3 = m3.reveal()
m4 = m4.reveal()
m5 = m5.reveal()
print_ln("Permuted m1:")
for row in m1:
print_ln("%s", row)
test_permuted_matrix(m1, p1)
print_ln("Permuted m2:")
for row in m2:
print_ln("%s", row)
test_permuted_matrix(m2, p2)
print_ln("Permuted m3 (should be equal to m2):")
for row in m3:
print_ln("%s", row)
test_permuted_matrix(m3, p2)
print_ln("Permuted m4 (should be equal to m2):")
for row in m4:
print_ln("%s", row)
test_permuted_matrix(m4, p2)
print_ln("Permuted m5:")
for row in m5:
print_ln("%s", row)
test_permuted_matrix(m5, p3)
def test_secure_shuffle_still_works(size: int, timer_base: int):
arr = Array.create_from([sint(i) for i in range(size)])
start_timer(timer_base)
arr.secure_shuffle()
stop_timer(timer_base)
arr = arr.reveal()
n_matched = cint(0)
@for_range(len(arr))
def count_matches(i: cint) -> None:
n_matched.update(n_matched + (arr[i] == i))
@if_(n_matched == len(arr))
def didnt_permute():
print_ln("Shuffle potentially didn't work (permutation might have been identity by chance).")
crash()
def test_inverse_permutation_still_works(size, timer_base: int):
@if_e(get_number_of_players() == 2)
def test():
program.use_invperm(True)
arr = Array.create_from([sint(i) for i in range(size)])
p_as_arr = Array.create_from(arr[:])
p_inv_as_arr = Array.create_from(arr[:])
start_timer(timer_base + 1)
p1 = sint.get_secure_shuffle(size)
p_as_arr.secure_permute(p1, reverse=True)
p_inv_as_arr.secure_permute(p1)
p_inv_as_arr = p_inv_as_arr.reveal()
stop_timer(timer_base + 1)
start_timer(timer_base + 2)
p_inv2 = Array.create_from(p_as_arr[:].inverse_permutation())
# p_inv2.inverse_permutation()
p_inv2 = p_inv2.reveal()
stop_timer(timer_base + 2)
print_ln("Permutation: %s", p_as_arr.reveal())
print_ln("Inverse from secure_permute: %s", p_inv_as_arr)
print_ln("Inverse from inverse_permut: %s", p_inv2)
@for_range(size)
def check(i):
@if_(p_inv_as_arr[i] != p_inv2[i])
def fail():
print_ln("Inverse permutation don't match.")
crash()
@else_
def _():
print_ln("Inverse permutation is only tested in 2-party computation.")
def test_dead_code_elimination():
vector = sint([0,2,4,6,5,3,1])
handle = sint.get_secure_shuffle(7)
print_ln('%s', vector.secure_permute(handle).reveal())
test_allocator()
test_case([10, 15], 10)
test_case([10, 15, 20], 20)
test_case([16,32], 30)
test_case([256], 40)
test_parallel_permutation_equals_sequential_permutation([5,10],50)
test_permute_matrix(60)
test_permute_matrix(70, value_type=sfix)
test_secure_shuffle_still_works(32, 80)
test_inverse_permutation_still_works(8, 80)
test_dead_code_elimination()

View File

@@ -37,8 +37,8 @@ public:
return store.add();
}
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int, bool)
void apply(StackedVector<T>& a, size_t n, size_t unit_size, size_t output_base,
size_t input_base, size_t, bool)
{
auto source = a.begin() + input_base;
auto dest = a.begin() + output_base;
@@ -49,13 +49,29 @@ public:
if (n > 1)
{
// swap first two to pass check
for (int i = 0; i < unit_size; i++)
for (size_t i = 0; i < unit_size; i++)
swap(a[output_base + i], a[output_base + i + unit_size]);
}
}
void inverse_permutation(StackedVector<T>&, size_t, size_t, size_t)
{
void inverse_permutation(StackedVector<T> &, size_t, size_t, size_t) {
throw runtime_error("inverse permutation not implemented");
};
void apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources,
vector<size_t> &unit_sizes, vector<size_t> &handles, vector<bool> &reverses,
store_type&) {
const auto n_shuffles = sizes.size();
assert(sources.size() == n_shuffles);
assert(destinations.size() == n_shuffles);
assert(unit_sizes.size() == n_shuffles);
assert(handles.size() == n_shuffles);
assert(reverses.size() == n_shuffles);
for (size_t i = 0; i < n_shuffles; i++) {
this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], handles[i], reverses[i]);
}
}
};

View File

@@ -28,8 +28,10 @@ public:
int generate(int n_shuffle, store_type& store);
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, shuffle_type& shuffle, bool reverse);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<size_t>& handles, vector<bool>& reverse, store_type& store);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<shuffle_type>& shuffles, vector<bool>& reverse);
void inverse_permutation(StackedVector<T>& stack, size_t n, size_t output_base,
size_t input_base);

View File

@@ -9,34 +9,33 @@
#include "Rep3Shuffler.h"
template<class T>
Rep3Shuffler<T>::Rep3Shuffler(StackedVector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T>& proc) :
proc(proc)
{
Rep3Shuffler<T>::Rep3Shuffler(StackedVector<T> &a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T> &proc) : proc(proc) {
store_type store;
int handle = generate(n / unit_size, store);
apply(a, n, unit_size, output_base, input_base, store.get(handle),
false);
vector<size_t> sizes{n};
vector<size_t> unit_sizes{static_cast<size_t>(unit_size)};
vector<size_t> destinations{output_base};
vector<size_t> sources{input_base};
vector<shuffle_type> shuffles{store.get(handle)};
vector<bool> reverses{true};
this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses);
}
template<class T>
Rep3Shuffler<T>::Rep3Shuffler(SubProcessor<T>& proc) :
proc(proc)
{
Rep3Shuffler<T>::Rep3Shuffler(SubProcessor<T> &proc) : proc(proc) {
}
template<class T>
int Rep3Shuffler<T>::generate(int n_shuffle, store_type& store)
{
int Rep3Shuffler<T>::generate(int n_shuffle, store_type &store) {
int res = store.add();
auto& shuffle = store.get(res);
for (int i = 0; i < 2; i++)
{
auto& perm = shuffle[i];
auto &shuffle = store.get(res);
for (int i = 0; i < 2; i++) {
auto &perm = shuffle[i];
for (int j = 0; j < n_shuffle; j++)
perm.push_back(j);
for (int j = 0; j < n_shuffle; j++)
{
for (int j = 0; j < n_shuffle; j++) {
int k = proc.protocol.shared_prngs[i].get_uint(n_shuffle - j);
swap(perm[j], perm[k + j]);
}
@@ -45,85 +44,129 @@ int Rep3Shuffler<T>::generate(int n_shuffle, store_type& store)
}
template<class T>
void Rep3Shuffler<T>::apply(StackedVector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, shuffle_type& shuffle,
bool reverse)
{
void Rep3Shuffler<T>::apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources,
vector<size_t> &unit_sizes, vector<size_t> &handles, vector<bool> &reverses,
store_type &store) {
vector<shuffle_type> shuffles;
for (size_t &handle: handles) {
shuffle_type &shuffle = store.get(handle);
shuffles.push_back(shuffle);
}
apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses);
}
template<class T>
void Rep3Shuffler<T>::apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources, vector<size_t> &unit_sizes, vector<shuffle_type> &shuffles,
vector<bool> &reverses) {
const auto n_shuffles = sizes.size();
assert(sources.size() == n_shuffles);
assert(destinations.size() == n_shuffles);
assert(unit_sizes.size() == n_shuffles);
assert(shuffles.size() == n_shuffles);
assert(reverses.size() == n_shuffles);
assert(proc.P.num_players() == 3);
assert(not T::malicious);
assert(not T::dishonest_majority);
assert(n % unit_size == 0);
if (shuffle.empty())
throw runtime_error("shuffle has been deleted");
// for (size_t i = 0; i < n_shuffles; i++) {
// this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], store.get(handles[i]), reverses[i]);
// }
stats[n / unit_size] += unit_size;
vector<vector<T> > to_shuffle;
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
assert(sizes[current_shuffle] % unit_sizes[current_shuffle] == 0);
vector<T> x;
for (size_t j = 0; j < sizes[current_shuffle]; j++)
x.push_back(a[sources[current_shuffle] + j]);
to_shuffle.push_back(x);
vector<T> to_shuffle;
for (size_t i = 0; i < n; i++)
to_shuffle.push_back(a[input_base + i]);
const auto &shuffle = shuffles[current_shuffle];
if (shuffle.empty())
throw runtime_error("shuffle has been deleted");
stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle];
}
typename T::Input input(proc);
vector<typename T::clear> to_share(n);
for (int ii = 0; ii < 3; ii++)
{
int i;
if (reverse)
i = 2 - ii;
else
i = ii;
if (proc.P.get_player(i) == 0)
{
for (size_t j = 0; j < n / unit_size; j++)
for (int k = 0; k < unit_size; k++)
if (reverse)
to_share.at(j * unit_size + k) = to_shuffle.at(
shuffle[0].at(j) * unit_size + k).sum();
else
to_share.at(shuffle[0].at(j) * unit_size + k) =
to_shuffle.at(j * unit_size + k).sum();
}
else if (proc.P.get_player(i) == 1)
{
for (size_t j = 0; j < n / unit_size; j++)
for (int k = 0; k < unit_size; k++)
if (reverse)
to_share[j * unit_size + k] = to_shuffle[shuffle[1][j]
* unit_size + k][0];
else
to_share[shuffle[1][j] * unit_size + k] = to_shuffle[j
* unit_size + k][0];
}
for (int pass = 0; pass < 3; pass++) {
input.reset_all(proc.P);
if (proc.P.get_player(i) < 2)
for (auto& x : to_share)
input.add_mine(x);
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
const auto n = sizes[current_shuffle];
const auto unit_size = unit_sizes[current_shuffle];
const auto &shuffle = shuffles[current_shuffle];
const auto reverse = reverses[current_shuffle];
const auto current_to_shuffle = to_shuffle[current_shuffle];
for (int k = 0; k < 2; k++)
input.add_other((-i + 3 + k) % 3);
vector<typename T::clear> to_share(n);
int i;
if (reverse)
i = 2 - pass;
else
i = pass;
if (proc.P.get_player(i) == 0) {
for (size_t j = 0; j < n / unit_size; j++)
for (size_t k = 0; k < unit_size; k++)
if (reverse)
to_share.at(j * unit_size + k) = current_to_shuffle.at(
shuffle[0].at(j) * unit_size + k).sum();
else
to_share.at(shuffle[0].at(j) * unit_size + k) =
current_to_shuffle.at(j * unit_size + k).sum();
} else if (proc.P.get_player(i) == 1) {
for (size_t j = 0; j < n / unit_size; j++)
for (size_t k = 0; k < unit_size; k++)
if (reverse)
to_share[j * unit_size + k] = current_to_shuffle[shuffle[1][j] * unit_size + k][0];
else
to_share[shuffle[1][j] * unit_size + k] = current_to_shuffle[j * unit_size + k][0];
}
if (proc.P.get_player(i) < 2)
for (auto &x: to_share)
input.add_mine(x);
for (int k = 0; k < 2; k++)
input.add_other((-i + 3 + k) % 3);
}
input.exchange();
to_shuffle.clear();
for (size_t j = 0; j < n; j++)
{
T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3);
to_shuffle.push_back(x);
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
const auto n = sizes[current_shuffle];
const auto reverse = reverses[current_shuffle];
int i;
if (reverse)
i = 2 - pass;
else
i = pass;
vector<T> tmp;
for (size_t j = 0; j < n; j++) {
T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3);
tmp.push_back(x);
}
to_shuffle.push_back(tmp);
}
}
for (size_t i = 0; i < n; i++)
a[output_base + i] = to_shuffle[i];
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
const auto n = sizes[current_shuffle];
for (size_t i = 0; i < n; i++)
a[destinations[current_shuffle] + i] = to_shuffle[current_shuffle][i];
}
}
template<class T>
void Rep3Shuffler<T>::inverse_permutation(StackedVector<T>&, size_t, size_t, size_t)
{
void Rep3Shuffler<T>::inverse_permutation(StackedVector<T> &, size_t, size_t, size_t) {
throw runtime_error("inverse permutation not implemented");
}

View File

@@ -40,13 +40,6 @@ public:
private:
SubProcessor<T>& proc;
vector<T> to_shuffle;
vector<vector<T>> config;
vector<T> tmp;
int unit_size;
size_t n_shuffle;
bool exact;
/**
* Generates and returns a newly generated random permutation. This permutation is generated locally.
@@ -66,19 +59,16 @@ private:
* e.g. [2, 4, 0, 3, 1] -> perm(1) = 4
*
* @param config_player The player tasked with generating the random permutation from which to configure the waksman network.
* @param n_shuffle The size of the permutation to generate.
* @param n The size of the permutation to generate.
*/
void configure(int config_player, vector<int>* perm, int n);
void player_round(int config_player);
vector<vector<T>> configure(int config_player, vector<int>* perm, int n);
void waksman(StackedVector<T>& a, int depth, int start);
void cond_swap(T& x, T& y, const T& b);
int prep_multiple(StackedVector<T>& a, vector<size_t> &sizes, vector<size_t> &sources, vector<size_t> &unit_sizes, vector<vector<T>>& to_shuffle, vector<bool> &exact);
void finalize_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& unit_sizes, vector<size_t>& destinations, vector<bool>& isExact, vector<vector<T>>& to_shuffle);
void iter_waksman(bool reverse = false);
void waksman_round(int size, bool inwards, bool reverse);
void pre(StackedVector<T>& a, size_t n, size_t input_base);
void post(StackedVector<T>& a, size_t n, size_t input_base);
void parallel_waksman_round(size_t pass, int depth, bool inwards, vector<vector<T>>& toShuffle, vector<size_t>& unit_sizes, vector<bool>& reverse, vector<shuffle_type>& shuffles);
vector<array<int, 5>> waksman_round_init(vector<T>& toShuffle, size_t shuffle_unit_size, int depth, vector<vector<T>>& iter_waksman_config, bool inwards, bool reverse);
void waksman_round_finish(vector<T>& toShuffle, size_t unit_size, vector<array<int, 5>> indices);
public:
map<long, long> stats;
@@ -90,21 +80,10 @@ public:
int generate(int n_shuffle, store_type& store);
/**
*
* @param a The vector of registers representing the stack // TODO: Is this correct?
* @param n The size of the input vector to shuffle
* @param unit_size Determines how many vector items constitute a single block with regards to permutation:
* i.e. input vector [1,2,3,4] with <code>unit_size=2</code> under permutation map [1,0]
* would result in [3,4,1,2]
* @param output_base The starting address of the output vector (i.e. the location to write the inverted permutation to)
* @param input_base The starting address of the input vector (i.e. the location from which to read the permutation)
* @param shuffle The preconfigured waksman network (shuffle) to use
* @param reverse Boolean indicating whether to apply the inverse of the permutation
* @see SecureShuffle::generate for obtaining a shuffle handle
*/
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, shuffle_type& shuffle, bool reverse);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<size_t>& handles, vector<bool>& reverse, store_type& store);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<shuffle_type>& shuffles, vector<bool>& reverse);
/**
* Calculate the secret inverse permutation of stack given secret permutation.

View File

@@ -53,49 +53,68 @@ void ShuffleStore<T>::del(int handle)
template<class T>
SecureShuffle<T>::SecureShuffle(SubProcessor<T>& proc) :
proc(proc), unit_size(0), n_shuffle(0), exact(false)
proc(proc)
{
}
template<class T>
SecureShuffle<T>::SecureShuffle(StackedVector<T>& a, size_t n, int unit_size,
size_t output_base, size_t input_base, SubProcessor<T>& proc) :
proc(proc), unit_size(unit_size), n_shuffle(0), exact(false)
proc(proc)
{
pre(a, n, input_base);
store_type store;
int handle = generate(n / unit_size, store);
for (auto i : proc.protocol.get_relevant_players())
player_round(i);
post(a, n, output_base);
vector<size_t> sizes{n};
vector<size_t> unit_sizes{static_cast<size_t>(unit_size)};
vector<size_t> destinations{output_base};
vector<size_t> sources{input_base};
vector<shuffle_type> shuffles{store.get(handle)};
vector<bool> reverses{true};
this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses);
}
template<class T>
void SecureShuffle<T>::apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, shuffle_type& shuffle, bool reverse)
{
this->unit_size = unit_size;
void SecureShuffle<T>::apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<size_t>& handles, vector<bool>& reverse, store_type& store) {
vector<shuffle_type> shuffles;
for (size_t &handle : handles)
shuffles.push_back(store.get(handle));
stats[n / unit_size] += unit_size;
this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverse);
}
pre(a, n, input_base);
template<class T>
void SecureShuffle<T>::apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources, vector<size_t> &unit_sizes, vector<shuffle_type> &shuffles, vector<bool> &reverse) {
const auto n_shuffles = sizes.size();
assert(sources.size() == n_shuffles);
assert(destinations.size() == n_shuffles);
assert(unit_sizes.size() == n_shuffles);
assert(shuffles.size() == n_shuffles);
assert(reverse.size() == n_shuffles);
assert(shuffle.size() == proc.protocol.get_relevant_players().size());
// SecureShuffle works by making t players create and "secret-share" a permutation.
// Then each permutation is applied in a pass. As long as one of these permutations was created by an honest party,
// the resulting combined shuffle is hidden.
const auto n_passes = proc.protocol.get_relevant_players().size();
if (reverse)
for (auto it = shuffle.end(); it > shuffle.begin(); it--)
{
this->config = *(it - 1);
iter_waksman(reverse);
}
else
for (auto& config : shuffle)
{
this->config = config;
iter_waksman(reverse);
}
// Initialize the shuffles.
vector is_exact(n_shuffles, false);
vector<vector<T>> to_shuffle;
int max_depth = prep_multiple(a, sizes, sources, unit_sizes, to_shuffle, is_exact);
post(a, n, output_base);
// Apply the shuffles.
for (size_t pass = 0; pass < n_passes; pass++)
{
for (int depth = 0; depth < max_depth; depth++)
parallel_waksman_round(pass, depth, true, to_shuffle, unit_sizes, reverse, shuffles);
for (int depth = max_depth - 1; depth >= 0; depth--)
parallel_waksman_round(pass, depth, false, to_shuffle, unit_sizes, reverse, shuffles);
}
// Write the shuffled results into memory.
finalize_multiple(a, sizes, unit_sizes, destinations, is_exact, to_shuffle);
}
@@ -115,27 +134,38 @@ void SecureShuffle<T>::inverse_permutation(StackedVector<T> &stack, size_t n, si
if (T::malicious)
throw runtime_error("inverse permutation only implemented for semi-honest protocols");
// We are dealing directly with permutations, so the unit_size will always be 1.
this->unit_size = 1;
// We need to account for sizes which are not a power of 2
size_t n_pow2 = (1u << int(ceil(log2(n))));
vector<size_t> sizes { n };
vector<size_t> unit_sizes { 1 }; // We are dealing directly with permutations, so the unit_size will always be 1.
vector<size_t> destinations { output_base };
vector<size_t> sources { input_base };
vector<bool> reverse { true };
vector<vector<T>> to_shuffle;
vector<bool> is_exact(1, false);
// Copy over the input registers
pre(stack, n, input_base);
prep_multiple(stack, sizes, sources, unit_sizes, to_shuffle, is_exact);
size_t shuffle_size = to_shuffle[0].size() / unit_sizes[0];
// Alice generates stack local permutation and shares the waksman configuration bits secretly to Bob.
vector<int> perm_alice(n_pow2);
if (P.my_num() == alice)
vector<int> perm_alice(shuffle_size);
if (P.my_num() == alice) {
perm_alice = generate_random_permutation(n);
configure(alice, &perm_alice, n);
}
auto config = configure(alice, &perm_alice, n);
vector<shuffle_type> shuffles {{ config, config }};
// Apply perm_alice to perm_alice to get perm_bob,
// stack permutation that we can reveal to Bob without Bob learning anything about perm_alice (since it is masked by perm_a)
iter_waksman(true);
for (int depth = 0; depth < log2(shuffle_size); depth++)
parallel_waksman_round(0, depth, true, to_shuffle, unit_sizes, reverse, shuffles);
for (int depth = log2(shuffle_size); depth >= 0; depth--)
parallel_waksman_round(0, depth, false, to_shuffle, unit_sizes, reverse, shuffles);
// Store perm_bob at stack[output_base]
post(stack, n, output_base);
finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle);
// Reveal permutation perm_bob = perm_a * perm_alice
// Since this permutation is masked by perm_a, Bob learns nothing about perm
vector<int> perm_bob(n_pow2);
vector<int> perm_bob(shuffle_size);
typename T::PrivateOutput output(proc);
for (size_t i = 0; i < n; i++)
output.prepare_sending(stack[output_base + i], bob);
@@ -147,13 +177,13 @@ void SecureShuffle<T>::inverse_permutation(StackedVector<T> &stack, size_t n, si
perm_bob[i] = (int) val.get_si();
}
vector<int> perm_bob_inv(n_pow2);
vector<int> perm_bob_inv(shuffle_size);
if (P.my_num() == bob) {
for (int i = 0; i < (int) n; i++)
perm_bob_inv[perm_bob[i]] = i;
// Pad the permutation to n_pow2
// Required when using waksman networks
for (int i = (int) n; i < (int) n_pow2; i++)
for (int i = (int) n; i < (int) shuffle_size; i++)
perm_bob_inv[i] = i;
}
@@ -169,74 +199,118 @@ void SecureShuffle<T>::inverse_permutation(StackedVector<T> &stack, size_t n, si
stack[output_base + i] = input.finalize(alice);
// The two parties now jointly compute perm_a * perm_bob_inv to obtain perm_inv
pre(stack, n, output_base);
configure(bob, &perm_bob_inv, n);
iter_waksman(true);
// perm_inv is written back to stack[output_base]
post(stack, n, output_base);
}
template<class T>
void SecureShuffle<T>::pre(StackedVector<T>& a, size_t n, size_t input_base)
{
n_shuffle = n / unit_size;
assert(unit_size * n_shuffle == n);
size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle))));
exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious;
to_shuffle.clear();
prep_multiple(stack, sizes, destinations, unit_sizes, to_shuffle, is_exact);
if (exact)
{
to_shuffle.resize(n_shuffle_pow2 * unit_size);
for (size_t i = 0; i < n; i++)
to_shuffle[i] = a[input_base + i];
}
else
{
// sorting power of two elements together with indicator bits
to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle))));
for (size_t i = 0; i < n_shuffle; i++)
{
for (int j = 0; j < unit_size; j++)
to_shuffle[i * (unit_size + 1) + j] = a[input_base
+ i * unit_size + j];
to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1,
proc.P.my_num(), proc.MC.get_alphai());
}
this->unit_size++;
}
config = configure(bob, &perm_bob_inv, n);
shuffles[0] = { config, config };
for (int i = 0; i < log2(shuffle_size); i++)
parallel_waksman_round(0, i, true, to_shuffle, unit_sizes, reverse, shuffles);
for (int i = log2(shuffle_size) - 2; i >= 0; i--)
parallel_waksman_round(0, i, false, to_shuffle, unit_sizes, reverse, shuffles);
// Store perm_bob at stack[output_base]
finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle);
}
template<class T>
void SecureShuffle<T>::post(StackedVector<T>& a, size_t n, size_t output_base)
{
if (exact)
for (size_t i = 0; i < n; i++)
a[output_base + i] = to_shuffle[i];
else
{
auto& MC = proc.MC;
MC.init_open(proc.P);
int shuffle_unit_size = this->unit_size;
int unit_size = shuffle_unit_size - 1;
for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++)
MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1));
MC.exchange(proc.P);
size_t i_shuffle = 0;
for (size_t i = 0; i < n_shuffle; i++)
int SecureShuffle<T>::prep_multiple(StackedVector<T> &a, vector<size_t> &sizes,
vector<size_t> &sources, vector<size_t> &unit_sizes, vector<vector<T>> &to_shuffle, vector<bool> &is_exact) {
int max_depth = 0;
const size_t n_shuffles = sizes.size();
for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) {
const size_t input_base = sources[currentShuffle];
const size_t n = sizes[currentShuffle];
const size_t unit_size = unit_sizes[currentShuffle];
assert(n % unit_size == 0);
const size_t n_shuffle = n / unit_size;
const size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle))));
const bool exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious;
vector<T> tmp;
if (exact)
{
auto bit = MC.finalize_open();
if (bit == 1)
{
// only output real elements
for (int j = 0; j < unit_size; j++)
a.at(output_base + i_shuffle * unit_size + j) =
to_shuffle.at(i * shuffle_unit_size + j);
i_shuffle++;
}
tmp.resize(n_shuffle_pow2 * unit_size);
for (size_t i = 0; i < n; i++)
tmp[i] = a[input_base + i];
}
else
{
// Pad n_shuffle to n_shuffle_pow2. To reduce this back to n_shuffle after-the-fact, a flag bit is
// added to every real entry.
const size_t shuffle_unit_size = unit_size + 1;
tmp.resize(shuffle_unit_size * n_shuffle_pow2);
for (size_t i = 0; i < n_shuffle; i++)
{
for (size_t j = 0; j < unit_size; j++)
tmp[i * shuffle_unit_size + j] = a[input_base + i * unit_size + j];
tmp[(i + 1) * shuffle_unit_size - 1] = T::constant(1, proc.P.my_num(), proc.MC.get_alphai());
}
for (size_t i = n_shuffle * shuffle_unit_size; i < tmp.size(); i++)
tmp[i] = T::constant(0, proc.P.my_num(), proc.MC.get_alphai());
unit_sizes[currentShuffle] = shuffle_unit_size;
}
to_shuffle.push_back(tmp);
is_exact[currentShuffle] = exact;
const int shuffle_depth = tmp.size() / unit_size;
if (shuffle_depth > max_depth)
max_depth = shuffle_depth;
}
return max_depth;
}
template<class T>
void SecureShuffle<T>::finalize_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &unit_sizes,
vector<size_t> &destinations, vector<bool> &isExact, vector<vector<T>> &to_shuffle) {
const size_t n_shuffles = sizes.size();
for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) {
const size_t n = sizes[currentShuffle];
const size_t shuffled_unit_size = unit_sizes[currentShuffle];
const size_t output_base = destinations[currentShuffle];
const vector<T>& shuffledData = to_shuffle[currentShuffle];
if (isExact[currentShuffle])
for (size_t i = 0; i < n; i++)
a[output_base + i] = shuffledData[i];
else
{
const size_t original_unit_size = shuffled_unit_size - 1;
const size_t n_shuffle = n / original_unit_size;
const size_t n_shuffle_pow2 = shuffledData.size() / shuffled_unit_size;
// Reveal the "real element" flags.
auto& MC = proc.MC;
MC.init_open(proc.P);
for (size_t i = 0; i < n_shuffle_pow2; i++) {
MC.prepare_open(shuffledData.at((i + 1) * shuffled_unit_size - 1));
}
MC.exchange(proc.P);
// Filter out the real elements.
size_t i_shuffle = 0;
for (size_t i = 0; i < n_shuffle_pow2; i++)
{
auto bit = MC.finalize_open();
if (bit == 1)
{
// only output real elements
for (size_t j = 0; j < original_unit_size; j++)
a.at(output_base + i_shuffle * original_unit_size + j) =
shuffledData.at(i * shuffled_unit_size + j);
i_shuffle++;
}
}
if (i_shuffle != n_shuffle)
throw runtime_error("incorrect shuffle");
}
if (i_shuffle != n_shuffle)
throw runtime_error("incorrect shuffle");
}
}
@@ -256,15 +330,6 @@ vector<int> SecureShuffle<T>::generate_random_permutation(int n) {
return perm;
}
template<class T>
void SecureShuffle<T>::player_round(int config_player) {
vector<int> random_perm(n_shuffle);
if (proc.P.my_num() == config_player)
random_perm = generate_random_permutation(n_shuffle);
configure(config_player, &random_perm, n_shuffle);
iter_waksman();
}
template<class T>
int SecureShuffle<T>::generate(int n_shuffle, store_type& store)
{
@@ -275,8 +340,7 @@ int SecureShuffle<T>::generate(int n_shuffle, store_type& store)
vector<int> perm;
if (proc.P.my_num() == i)
perm = generate_random_permutation(n_shuffle);
configure(i, &perm, n_shuffle);
auto config = configure(i, &perm, n_shuffle);
shuffle.push_back(config);
}
@@ -284,7 +348,7 @@ int SecureShuffle<T>::generate(int n_shuffle, store_type& store)
}
template<class T>
void SecureShuffle<T>::configure(int config_player, vector<int> *perm, int n) {
vector<vector<T>> SecureShuffle<T>::configure(int config_player, vector<int> *perm, int n) {
auto &P = proc.P;
auto &input = proc.input;
input.reset_all(P);
@@ -311,7 +375,7 @@ void SecureShuffle<T>::configure(int config_player, vector<int> *perm, int n) {
input.add_other(config_player);
input.exchange();
config.clear();
vector<vector<T>> config;
typename T::Protocol checker(P);
checker.init(proc.DataF, proc.MC);
checker.init_dotprod();
@@ -345,68 +409,61 @@ void SecureShuffle<T>::configure(int config_player, vector<int> *perm, int n) {
P)) == 0);
checker.check();
}
return config;
}
template<class T>
void SecureShuffle<T>::waksman(StackedVector<T>& a, int depth, int start)
void SecureShuffle<T>::parallel_waksman_round(size_t pass, int depth, bool inwards, vector<vector<T>> &toShuffle,
vector<size_t> &unit_sizes, vector<bool> &reverse, vector<shuffle_type> &shuffles)
{
int n = a.size();
const auto n_passes = proc.protocol.get_relevant_players().size();
const auto n_shuffles = shuffles.size();
if (n == 2)
{
cond_swap(a[0], a[1], config.at(depth).at(start));
return;
}
vector<T> a0(n / 2), a1(n / 2);
for (int i = 0; i < n / 2; i++)
{
a0.at(i) = a.at(2 * i);
a1.at(i) = a.at(2 * i + 1);
cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2));
}
waksman(a0, depth + 1, start);
waksman(a1, depth + 1, start + n / 2);
for (int i = 0; i < n / 2; i++)
{
a.at(2 * i) = a0.at(i);
a.at(2 * i + 1) = a1.at(i);
cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start));
}
}
template<class T>
void SecureShuffle<T>::cond_swap(T& x, T& y, const T& b)
{
auto diff = proc.protocol.mul(x - y, b);
x -= diff;
y += diff;
}
template<class T>
void SecureShuffle<T>::iter_waksman(bool reverse)
{
int n = to_shuffle.size() / unit_size;
for (int depth = 0; depth < log2(n); depth++)
waksman_round(depth, true, reverse);
for (int depth = log2(n) - 2; depth >= 0; depth--)
waksman_round(depth, false, reverse);
}
template<class T>
void SecureShuffle<T>::waksman_round(int depth, bool inwards, bool reverse)
{
int n = to_shuffle.size() / unit_size;
assert((int) config.at(depth).size() == n);
int nblocks = 1 << depth;
int size = n / (2 * nblocks);
bool outwards = !inwards;
vector<vector<array<int, 5>>> allIndices;
proc.protocol.init_mul();
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle];
if (depth >= log2(n) - !inwards) {
allIndices.push_back({});
continue;
}
const auto isReverse = reverse[current_shuffle];
size_t configIdx = pass;
if (isReverse)
configIdx = n_passes - pass - 1;
auto& config = shuffles[current_shuffle][configIdx];
vector<array<int, 5>> indices = waksman_round_init(
toShuffle[current_shuffle],
unit_sizes[current_shuffle],
depth,
config,
inwards,
isReverse
);
allIndices.push_back(indices);
}
proc.protocol.exchange();
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle];
if (depth >= log2(n) - !inwards) {
continue;
}
waksman_round_finish(toShuffle[current_shuffle], unit_sizes[current_shuffle], allIndices[current_shuffle]);
}
}
template<class T>
vector<array<int, 5>> SecureShuffle<T>::waksman_round_init(vector<T> &toShuffle, size_t shuffle_unit_size, int depth, vector<vector<T>> &iter_waksman_config, bool inwards, bool reverse) {
int n = toShuffle.size() / shuffle_unit_size;
assert((int) iter_waksman_config.at(depth).size() == n);
int n_blocks = 1 << depth;
int size = n / (2 * n_blocks);
bool outwards = !inwards;
vector<array<int, 5>> indices;
indices.reserve(n / 2);
Waksman waksman(n);
@@ -423,30 +480,37 @@ void SecureShuffle<T>::waksman_round(int depth, bool inwards, bool reverse)
bool run = waksman.matters(depth, i_bit);
if (run)
{
for (int l = 0; l < unit_size; l++)
proc.protocol.prepare_mul(config.at(depth).at(i_bit),
to_shuffle.at(in1 * unit_size + l)
- to_shuffle.at(in2 * unit_size + l));
for (size_t l = 0; l < shuffle_unit_size; l++)
proc.protocol.prepare_mul(iter_waksman_config.at(depth).at(i_bit),
toShuffle.at(in1 * shuffle_unit_size + l) - toShuffle.at(in2 * shuffle_unit_size + l));
}
indices.push_back({{in1, in2, out1, out2, run}});
indices.push_back({in1, in2, out1, out2, run});
}
proc.protocol.exchange();
tmp.resize(to_shuffle.size());
return indices;
}
template<class T>
void SecureShuffle<T>::waksman_round_finish(vector<T> &toShuffle, size_t unit_size, vector<array<int, 5>> indices) {
int n = toShuffle.size() / unit_size;
vector<T> tmp(toShuffle.size());
for (int k = 0; k < n / 2; k++)
{
auto idx = indices.at(k);
for (int l = 0; l < unit_size; l++)
for (size_t l = 0; l < unit_size; l++)
{
T diff;
if (idx[4])
diff = proc.protocol.finalize_mul();
tmp.at(idx[2] * unit_size + l) = to_shuffle.at(
tmp.at(idx[2] * unit_size + l) = toShuffle.at(
idx[0] * unit_size + l) - diff;
tmp.at(idx[3] * unit_size + l) = to_shuffle.at(
tmp.at(idx[3] * unit_size + l) = toShuffle.at(
idx[1] * unit_size + l) + diff;
}
}
swap(tmp, to_shuffle);
swap(tmp, toShuffle);
}
#endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */

View File

@@ -30,8 +30,10 @@ public:
int generate(int n_shuffle, store_type& store);
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, shuffle_type& shuffle, bool reverse);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<size_t>& handles, vector<bool>& reverse, store_type& store);
void apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<shuffle_type>& shuffles, vector<bool>& reverse);
void inverse_permutation(StackedVector<T>& stack, size_t n, size_t output_base,
size_t input_base);

View File

@@ -13,8 +13,14 @@ SpdzWiseRep3Shuffler<T>::SpdzWiseRep3Shuffler(StackedVector<T>& a, size_t n,
{
store_type store;
int handle = generate(n / unit_size, store);
apply(a, n, unit_size, output_base, input_base, store.get(handle),
false);
vector<size_t> sizes{n};
vector<size_t> unit_sizes{static_cast<size_t>(unit_size)};
vector<size_t> destinations{output_base};
vector<size_t> sources{input_base};
vector<shuffle_type> shuffles{store.get(handle)};
vector<bool> reverses{true};
this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses);
}
template<class T>
@@ -30,31 +36,63 @@ int SpdzWiseRep3Shuffler<T>::generate(int n_shuffle, store_type& store)
}
template<class T>
void SpdzWiseRep3Shuffler<T>::apply(StackedVector<T>& a, size_t n,
int unit_size, size_t output_base, size_t input_base,
shuffle_type& shuffle, bool reverse)
{
stats[n / unit_size] += unit_size;
StackedVector<typename T::part_type::Honest> to_shuffle;
to_shuffle.reserve(2 * n);
for (size_t i = 0; i < n; i++)
{
auto& x = a[input_base + i];
to_shuffle.push_back(x.get_share());
to_shuffle.push_back(x.get_mac());
void SpdzWiseRep3Shuffler<T>::apply_multiple(StackedVector<T>& a, vector<size_t>& sizes, vector<size_t>& destinations, vector<size_t>& sources,
vector<size_t>& unit_sizes, vector<size_t>& handles, vector<bool>& reverses, store_type& store) {
vector<shuffle_type> shuffles;
for (size_t &handle : handles) {
shuffle_type& shuffle = store.get(handle);
shuffles.push_back(shuffle);
}
internal.apply(to_shuffle, 2 * n, 2 * unit_size, 0, 0, shuffle, reverse);
apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses);
}
for (size_t i = 0; i < n; i++)
template<class T>
void SpdzWiseRep3Shuffler<T>::apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources, vector<size_t> &unit_sizes, vector<shuffle_type> &shuffles, vector<bool> &reverse) {
const size_t n_shuffles = sizes.size();
assert(n_shuffles == destinations.size());
assert(n_shuffles == sources.size());
assert(n_shuffles == unit_sizes.size());
assert(n_shuffles == shuffles.size());
assert(n_shuffles == reverse.size());
StackedVector<typename T::part_type::Honest> temporary_memory(0);
vector<size_t> mapped_positions (n_shuffles, 0);
vector<size_t> mapped_sizes(n_shuffles, 0);
vector<size_t> mapped_unit_sizes (n_shuffles, 0);
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) {
mapped_positions[current_shuffle] = temporary_memory.size();
mapped_sizes[current_shuffle] = 2 * sizes[current_shuffle];
mapped_unit_sizes[current_shuffle] = 2 * unit_sizes[current_shuffle];
stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle];
for (size_t i = 0; i < sizes[current_shuffle]; i++)
{
auto& x = a[sources[current_shuffle] + i];
temporary_memory.push_back(x.get_share());
temporary_memory.push_back(x.get_mac());
}
}
internal.apply_multiple(temporary_memory, mapped_sizes, mapped_positions, mapped_positions, mapped_unit_sizes, shuffles, reverse);
for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++)
{
auto& x = a[output_base + i];
x.set_share(to_shuffle[2 * i]);
x.set_mac(to_shuffle[2 * i + 1]);
proc.protocol.add_to_check(x);
const size_t n = sizes[current_shuffle];
const size_t dest = destinations[current_shuffle];
const size_t pos = mapped_positions[current_shuffle];
for (size_t i = 0; i < n; i++)
{
auto& x = a[dest + i];
x.set_share(temporary_memory[pos + 2 * i]);
x.set_mac(temporary_memory[pos + 2 * i + 1]);
proc.protocol.add_to_check(x);
}
}
proc.protocol.maybe_check();