Fix emulation of permutations, use single-instruction matrix shuffle, use multithreading for matrix shuffles

This commit is contained in:
Vincent Ehrmanntraut
2024-12-19 09:46:58 +01:00
parent f27f15f864
commit 0db8d7a419
4 changed files with 29 additions and 15 deletions

View File

@@ -7274,7 +7274,7 @@ class SubMultiArray(_vectorizable):
self.secure_permute(perm)
delshuffle(perm)
def secure_permute(self, permutation, reverse=False, n_parallel=None):
def secure_permute(self, permutation, reverse=False, n_threads=None):
""" Securely permute rows (first index). See
:py:func:`secure_shuffle` for references.
@@ -7282,16 +7282,16 @@ class SubMultiArray(_vectorizable):
:param reverse: whether to apply inverse (default: False)
"""
if self.value_type == sint and False:
if self.value_type == sint:
unit_size = self.get_part_size()
n = self.sizes[0] * unit_size
res = sint(size=n)
applyshuffle(n, res, self[:], unit_size, permutation, reverse)
self.assign_vector(res)
else:
if n_parallel is None:
n_parallel = self.get_part_size()
@library.for_range_parallel(n_parallel, self.get_part_size())
if n_threads is not None:
permutation = MemValue(permutation)
@library.for_range_multithread(n_threads, 1, self.get_part_size())
def iter(i):
column = self.get_column(i)
column = column.secure_permute(permutation, reverse=reverse)

View File

@@ -137,6 +137,7 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
print_ln("Pre-create matrix")
m1 = Matrix.create_from([[value_type(5 * i + j) for j in range(5)] for i in range(5)])
m2 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
m3 = Matrix.create_from([[value_type(6 * i + j) for j in range(6)] for i in range(5)])
print_ln("post-create matrix")
p1 = sint.get_secure_shuffle(5)
@@ -148,10 +149,14 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
start_timer(timer_base + 2)
m2.secure_permute(p2)
stop_timer(timer_base + 2)
start_timer(timer_base + 3)
m3.secure_permute(p2, n_threads=3)
stop_timer(timer_base + 3)
print_ln(f"Timer {timer_base + 1} and {timer_base + 2} should require equal amount of rounds.")
m1 = m1.reveal()
m2 = m2.reveal()
m3 = m3.reveal()
print_ln("Permuted m1:")
for row in m1:
@@ -163,6 +168,11 @@ def test_permute_matrix(timer_base: int, value_type=sint) -> None:
print_ln("%s", row)
test_permuted_matrix(m2, p2)
print_ln("Permuted m3 (should be equal to m2):")
for row in m3:
print_ln("%s", row)
test_permuted_matrix(m3, p2)
def test_secure_shuffle_still_works(size: int, timer_base: int):
arr = Array.create_from([sint(i) for i in range(size)])
start_timer(timer_base)
@@ -220,7 +230,7 @@ def test_inverse_permutation_still_works(size, timer_base: int):
test_allocator()
test_case([10, 15], 10)
test_case([10, 15, 20], 20)
test_case([8,16], 30)
test_case([16,32], 30)
test_case([256], 40)
test_parallel_permutation_equals_sequential_permutation([5,10],50)

View File

@@ -1 +0,0 @@
print_ln("%s", cfix(0))

View File

@@ -37,8 +37,8 @@ public:
return store.add();
}
void apply(StackedVector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int, bool)
void apply(StackedVector<T>& a, size_t n, size_t unit_size, size_t output_base,
size_t input_base, size_t, bool)
{
auto source = a.begin() + input_base;
auto dest = a.begin() + output_base;
@@ -49,23 +49,28 @@ public:
if (n > 1)
{
// swap first two to pass check
for (int i = 0; i < unit_size; i++)
for (size_t i = 0; i < unit_size; i++)
swap(a[output_base + i], a[output_base + i + unit_size]);
}
}
void inverse_permutation(StackedVector<T>&, size_t, size_t, size_t)
void applyMultiple(StackedVector<T>& a, vector<int>& sizes, vector<int>& destinations, vector<int>& sources,
vector<int>& unit_sizes, vector<int>& handles, vector<bool>& reverse, store_type& store) {
void inverse_permutation(StackedVector<T> &, size_t, size_t, size_t) {
throw runtime_error("inverse permutation not implemented");
};
void apply_multiple(StackedVector<T> &a, vector<size_t> &sizes, vector<size_t> &destinations,
vector<size_t> &sources,
vector<size_t> &unit_sizes, vector<size_t> &handles, vector<bool> &reverses,
store_type&) {
const auto n_shuffles = sizes.size();
assert(sources.size() == n_shuffles);
assert(destinations.size() == n_shuffles);
assert(unit_sizes.size() == n_shuffles);
assert(handles.size() == n_shuffles);
assert(reverse.size() == n_shuffles);
assert(reverses.size() == n_shuffles);
for (size_t i = 0; i < n_shuffles; i++) {
this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], store.get(handles[i]), reverse[i]);
this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], handles[i], reverses[i]);
}
}
};