Files
MP-SPDZ/Programs/Source/multinode_example_main.py
2023-12-14 12:17:54 +11:00

35 lines
991 B
Python

import random
# Workload dimensions, read from the positional program arguments 1..3.
n_nodes_per_party = int(program.args[1])
n_threads_per_node = int(program.args[2])
n_ops_per_thread = int(program.args[3])

# Derived sizes: operations handled by a single node, and by all nodes
# together.
n_ops_per_node = n_ops_per_thread * n_threads_per_node
n_ops = n_ops_per_node * n_nodes_per_party

# Secret array with n_ops entries built from regint.inc(n_ops)
# (presumably the values 0, 1, ..., n_ops - 1 -- confirm against the
# regint.inc documentation).
data = Array.create_from(sint(regint.inc(n_ops)))
# Open port 15000 for incoming node connections.
listen_for_clients(15000)

# One flag per expected node; a slot is set to 1 when a connection is
# accepted for it.
ready = regint.Array(n_nodes_per_party)


# Accept n_nodes_per_party connections. The value returned by
# accept_client_connection() is used as the slot index into `ready`.
@for_range(n_nodes_per_party)
def _(i):
    ready[accept_client_connection(15000)] = 1


# Abort at run time unless the flags add up to the expected node count.
runtime_error_if(sum(ready) != n_nodes_per_party, 'connection problems')
# Send each node its contiguous slice of `data`: node i gets the
# n_ops_per_node elements starting at i * n_ops_per_node, written to
# socket i.
@for_range(n_nodes_per_party)
def _(i):
    data.get_vector(base=i * n_ops_per_node,
                    size=n_ops_per_node).write_fully_to_socket(i)
# Read n_ops_per_node secret values back from socket i and store them
# into the same slice of `data` that was sent to node i.
@for_range(n_nodes_per_party)
def _(i):
    data.assign_vector(sint.read_from_socket(i, size=n_ops_per_node),
                       base=i * n_ops_per_node)
# Spot-check ten positions: reveal data[index] and raise a run-time error
# unless it equals index ** 2 (presumably the nodes square each element
# they receive -- confirm against the node-side program).
# This is a plain Python loop, so each index is an ordinary Python int
# drawn by random.randrange() when this script is executed.
for i in range(10):
    index = random.randrange(n_ops)
    expected = index ** 2
    value = data[index].reveal()
    runtime_error_if(value != expected, '%s != %s', value, expected)