mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Fix emulation of permutations, use single-instruction matrix shuffle, use multithreading for matrix shuffles
This commit is contained in:
@@ -7274,7 +7274,7 @@ class SubMultiArray(_vectorizable):
|
||||
self.secure_permute(perm)
|
||||
delshuffle(perm)
|
||||
|
||||
def secure_permute(self, permutation, reverse=False, n_parallel=None):
|
||||
def secure_permute(self, permutation, reverse=False, n_threads=None):
|
||||
""" Securely permute rows (first index). See
|
||||
:py:func:`secure_shuffle` for references.
|
||||
|
||||
@@ -7282,16 +7282,16 @@ class SubMultiArray(_vectorizable):
|
||||
:param reverse: whether to apply inverse (default: False)
|
||||
|
||||
"""
|
||||
if self.value_type == sint and False:
|
||||
if self.value_type == sint:
|
||||
unit_size = self.get_part_size()
|
||||
n = self.sizes[0] * unit_size
|
||||
res = sint(size=n)
|
||||
applyshuffle(n, res, self[:], unit_size, permutation, reverse)
|
||||
self.assign_vector(res)
|
||||
else:
|
||||
if n_parallel is None:
|
||||
n_parallel = self.get_part_size()
|
||||
@library.for_range_parallel(n_parallel, self.get_part_size())
|
||||
if n_threads is not None:
|
||||
permutation = MemValue(permutation)
|
||||
@library.for_range_multithread(n_threads, 1, self.get_part_size())
|
||||
def iter(i):
|
||||
column = self.get_column(i)
|
||||
column = column.secure_permute(permutation, reverse=reverse)
|
||||
|
||||
@@ -137,6 +137,7 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
print_ln("Pre-create matrix")
|
||||
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
|
||||
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
|
||||
print_ln("post-create matrix")
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
@@ -148,10 +149,14 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
start_timer(timer_base + 2)
|
||||
m2.secure_permute(p2)
|
||||
stop_timer(timer_base + 2)
|
||||
start_timer(timer_base + 3)
|
||||
m3.secure_permute(p2, n_threads=3)
|
||||
stop_timer(timer_base + 3)
|
||||
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
|
||||
|
||||
m1 = m1.reveal()
|
||||
m2 = m2.reveal()
|
||||
m3 = m3.reveal()
|
||||
|
||||
print_ln("Permuted m1:")
|
||||
for row in m1:
|
||||
@@ -163,6 +168,11 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m2, p2)
|
||||
|
||||
print_ln("Permuted m3 (should be equal to m2):")
|
||||
for row in m3:
|
||||
print_ln("%s", row)
|
||||
test_permuted_matrix(m3, p2)
|
||||
|
||||
def test_secure_shuffle_still_works(size: int, timer_base: int):
|
||||
arr = Array.create_from([sint(i) for i in range(size)])
|
||||
start_timer(timer_base)
|
||||
@@ -220,7 +230,7 @@ def test_inverse_permutation_still_works(size, timer_base: int):
|
||||
test_allocator()
|
||||
test_case([10, 15], 10)
|
||||
test_case([10, 15, 20], 20)
|
||||
test_case([8,16], 30)
|
||||
test_case([16,32], 30)
|
||||
test_case([256], 40)
|
||||
|
||||
test_parallel_permutation_equals_sequential_permutation([5,10],50)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
print_ln("%s", cfix(0))
|
||||
@@ -37,8 +37,8 @@ public:
|
||||
return store.add();
|
||||
}
|
||||
|
||||
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
|
||||
size_t input_base, int, bool)
|
||||
void apply(StackedVector<T>& a, size_t n, size_t unit_size, size_t output_base,
|
||||
size_t input_base, size_t, bool)
|
||||
{
|
||||
auto source = a.begin() + input_base;
|
||||
auto dest = a.begin() + output_base;
|
||||
@@ -49,23 +49,28 @@ public:
|
||||
if (n > 1)
|
||||
{
|
||||
// swap first two to pass check
|
||||
for (int i = 0; i < unit_size; i++)
|
||||
for (size_t i = 0; i < unit_size; i++)
|
||||
swap(a[output_base + i], a[output_base + i + unit_size]);
|
||||
}
|
||||
}
|
||||
|
||||
void inverse_permutation(StackedVector<T>&, size_t, size_t, size_t)
|
||||
void applyMultiple(StackedVector<T>& a, vector<int>& sizes, vector<int>& destinations, vector<int>& sources,
|
||||
vector<int>& unit_sizes, vector<int>& handles, vector<bool>& reverse, store_type& store) {
|
||||
void inverse_permutation(StackedVector<T> &, size_t, size_t, size_t) {
|
||||
throw runtime_error("inverse permutation not implemented");
|
||||
};
|
||||
|
||||
void apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
|
||||
vector<size_t> &sources,
|
||||
vector<size_t> &unit_sizes, vector<size_t> &handles, vector<bool> &reverses,
|
||||
store_type&) {
|
||||
const auto n_shuffles = sizes.size();
|
||||
assert(sources.size() == n_shuffles);
|
||||
assert(destinations.size() == n_shuffles);
|
||||
assert(unit_sizes.size() == n_shuffles);
|
||||
assert(handles.size() == n_shuffles);
|
||||
assert(reverse.size() == n_shuffles);
|
||||
assert(reverses.size() == n_shuffles);
|
||||
|
||||
for (size_t i = 0; i < n_shuffles; i++) {
|
||||
this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], store.get(handles[i]), reverse[i]);
|
||||
this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], handles[i], reverses[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user