mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 21:48:11 -05:00
Make applyshuffle instruction mergeable, execution is still sequential
This commit is contained in:
@@ -2662,9 +2662,10 @@ class gensecshuffle(shuffle_base):
|
||||
def add_usage(self, req_node):
|
||||
self.add_gen_usage(req_node, self.args[1])
|
||||
|
||||
class applyshuffle(base.VectorInstruction, shuffle_base):
|
||||
class applyshuffle(shuffle_base, base.Mergeable):
|
||||
""" Generate secure shuffle to bit used several times.
|
||||
|
||||
:param: vector size (int)
|
||||
:param: destination (sint)
|
||||
:param: source (sint)
|
||||
:param: number of elements to be treated as one (int)
|
||||
@@ -2674,15 +2675,19 @@ class applyshuffle(base.VectorInstruction, shuffle_base):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['APPLYSHUFFLE']
|
||||
arg_format = ['sw','s','int','ci','int']
|
||||
arg_format = itertools.cycle(['int', 'sw','s','int','ci','int'])
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(applyshuffle, self).__init__(*args, **kwargs)
|
||||
assert len(args[0]) == len(args[1])
|
||||
assert len(args[0]) > args[2]
|
||||
assert (len(args) % 6) == 0
|
||||
for i in range(0, len(args), 6):
|
||||
assert args[i] == len(args[i+1])
|
||||
assert args[i] == len(args[i + 2])
|
||||
assert args[i] > args[i + 3]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
self.add_apply_usage(req_node, len(self.args[0]), self.args[2])
|
||||
for i in range(0, len(self.args), 6):
|
||||
self.add_apply_usage(req_node, self.args[i], self.args[i + 3])
|
||||
|
||||
class delshuffle(base.Instruction):
|
||||
""" Delete secure shuffle.
|
||||
|
||||
@@ -3055,7 +3055,7 @@ class sint(_secret, _int):
|
||||
@read_mem_value
|
||||
def secure_permute(self, shuffle, unit_size=1, reverse=False):
|
||||
res = sint(size=self.size)
|
||||
applyshuffle(res, self, unit_size, shuffle, reverse)
|
||||
applyshuffle(self.size, res, self, unit_size, shuffle, reverse)
|
||||
return res
|
||||
|
||||
def inverse_permutation(self):
|
||||
|
||||
@@ -285,7 +285,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRINTFLOATPLAIN:
|
||||
case PRINTFLOATPLAINB:
|
||||
case APPLYSHUFFLE:
|
||||
get_vector(5, start, s);
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case INCINT:
|
||||
r[0]=get_int(s);
|
||||
@@ -1136,8 +1137,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.machine.shuffle_store));
|
||||
return;
|
||||
case APPLYSHUFFLE:
|
||||
Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3)),
|
||||
Proc.machine.shuffle_store);
|
||||
Proc.Procp.apply_shuffle(*this, Proc.machine.shuffle_store);
|
||||
return;
|
||||
case DELSHUFFLE:
|
||||
Proc.machine.shuffle_store.del(Proc.read_Ci(r[0]));
|
||||
|
||||
@@ -88,8 +88,7 @@ public:
|
||||
void secure_shuffle(const Instruction& instruction);
|
||||
size_t generate_secure_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store);
|
||||
void apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store);
|
||||
void apply_shuffle(const Instruction& instruction, ShuffleStore& shuffle_store);
|
||||
void inverse_permutation(const Instruction& instruction);
|
||||
|
||||
void input_personal(const vector<int>& args);
|
||||
|
||||
@@ -890,13 +890,23 @@ size_t SubProcessor<T>::generate_secure_shuffle(const Instruction& instruction,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction, int handle,
|
||||
ShuffleStore& shuffle_store)
|
||||
void SubProcessor<T>::apply_shuffle(const Instruction& instruction,
|
||||
ShuffleStore& shuffle_store)
|
||||
{
|
||||
shuffler.apply(S, instruction.get_size(), instruction.get_start()[2],
|
||||
instruction.get_start()[0], instruction.get_start()[1],
|
||||
shuffle_store.get(handle),
|
||||
instruction.get_start()[4]);
|
||||
auto& start = instruction.get_start();
|
||||
|
||||
for (auto shuffleArgs = start.begin(); shuffleArgs < start.end(); shuffleArgs += 6) {
|
||||
// shuffleArgs[0] size
|
||||
// shuffleArgs[1] dest
|
||||
// shuffleArgs[2] source
|
||||
// shuffleArgs[3] unit size
|
||||
// shuffleArgs[4] handle
|
||||
// shuffleArgs[5] reverse
|
||||
shuffler.apply(S, shuffleArgs[0], shuffleArgs[3],
|
||||
shuffleArgs[1], shuffleArgs[2],
|
||||
shuffle_store.get(Proc->read_Ci(shuffleArgs[4])),
|
||||
shuffleArgs[5]);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
58
Programs/Source/test_permute.mpc
Normal file
58
Programs/Source/test_permute.mpc
Normal file
@@ -0,0 +1,58 @@
|
||||
|
||||
def test_case(permutation_sizes, timer_base: int | None = None):
|
||||
if timer_base is not None:
|
||||
start_timer(timer_base + 0)
|
||||
arrays = []
|
||||
permutations = []
|
||||
for size in permutation_sizes:
|
||||
arrays.append(Array.create_from([sint(i) for i in range(size)]))
|
||||
permutations.append(sint.get_secure_shuffle(size))
|
||||
if timer_base is not None:
|
||||
stop_timer(timer_base + 0)
|
||||
start_timer(timer_base + 1)
|
||||
|
||||
for arr, p in zip(arrays, permutations):
|
||||
arr.secure_permute(p)
|
||||
|
||||
if timer_base is not None:
|
||||
stop_timer(timer_base + 1)
|
||||
start_timer(timer_base + 2)
|
||||
|
||||
for i, arr in enumerate(arrays):
|
||||
print_ln("%s", arr.reveal())
|
||||
|
||||
if timer_base is not None:
|
||||
stop_timer(timer_base + 2)
|
||||
start_timer(timer_base + 3)
|
||||
|
||||
for arr, p in zip(arrays, permutations):
|
||||
arr.secure_permute(p, reverse=True)
|
||||
|
||||
if timer_base is not None:
|
||||
stop_timer(timer_base + 3)
|
||||
start_timer(timer_base + 4)
|
||||
|
||||
for i, arr in enumerate(arrays):
|
||||
print_ln("%s", arr.reveal())
|
||||
|
||||
if timer_base is not None:
|
||||
stop_timer(timer_base + 4)
|
||||
|
||||
|
||||
def test_allocator():
|
||||
arr1 = sint.Array(5)
|
||||
arr2 = sint.Array(10)
|
||||
arr3 = sint.Array(20)
|
||||
|
||||
p1 = sint.get_secure_shuffle(5)
|
||||
p2 = sint.get_secure_shuffle(10)
|
||||
p3 = sint.get_secure_shuffle(20)
|
||||
|
||||
# Look at the bytecode, arr1 and arr3 should be shuffled in parallel, arr2 afterward.
|
||||
arr1.secure_permute(p1)
|
||||
arr2[0] = arr1[0]
|
||||
arr2.secure_permute(p2)
|
||||
arr3.secure_permute(p3)
|
||||
|
||||
test_allocator()
|
||||
# test_case([5, 10, 20], 10)
|
||||
Reference in New Issue
Block a user