Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-11 07:38:08 -05:00

Compare commits: 51 commits, al/remove_... → mz/mpi
Commit SHAs: f4cf43da36, cdce78d3b8, 46309a3da1, 6de6ac0fc3, 746cca0135, 425436dee2, 28416693ca, 04d9314c07, 9f9162c42f, 99bbd0ed7d, 07b2796904, b0c44aba2f, 28798813c0, add6bf8e5d, 367783c7dd, 3c5732ae3a, 1ceceb2e6f, abc05e141a, 0c95bb9024, fdf2451a5e, 3c240a8709, 76ce4bd477, 6f4cbfb108, 160cd437a8, b5cb6b3d74, 87f2bacaa0, 9e60a39b44, d8e06acbf6, 9af66d9d60, 1c3cb60b56, 8fea689097, 54f454b9ea, d3cb8aa111, 6654bdbe0c, ae3b33f644, 1c7b1e7fd9, c8414772c0, 1ea776976e, 1c310f28db, 04cfb1e009, fe6015f4b1, bd172c342d, 0bcab98438, 0bdd133be7, ec05c66ea2, 9340135e31, 83e461873f, a2ad55fedd, 105466a14c, c4e9bd836a, d676639200
Cargo.toml (workspace manifest) — @@ -1,6 +1,13 @@

```diff
 [workspace]
 resolver = "2"
-members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng", "backends/tfhe-cuda-backend"]
+members = [
+    "tfhe",
+    "tasks",
+    "apps/trivium",
+    "concrete-csprng",
+    "backends/tfhe-cuda-backend",
+    "mpi_test",
+]
 
 [profile.bench]
 lto = "fat"
```
mpi_test/Cargo.toml (new file, 33 lines)

```toml
[package]
name = "mpi_test"
version = "0.4.0"
edition = "2021"
license = "BSD-3-Clause-Clear"
description = "MPI test crate for distributed execution of TFHE-rs PBS task graphs."
homepage = "https://zama.ai/"
documentation = "https://docs.zama.ai/tfhe-rs"
repository = "https://github.com/zama-ai/tfhe-rs"
readme = "README.md"
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
rust-version = "1.72"

[dependencies]
mpi = { path = "../../rsmpi" }
tfhe = { path = "../tfhe", features = [
    "shortint",
    "integer",
    "x86_64",
    "internal-keycache",
] }
bincode = "1.3"
threadpool = "1.8"
crossbeam-channel = "0.5"
serde = { version = "1.0", features = ["derive", "rc"] }
petgraph = "0.6.4"
itertools = "*"
thread-priority = "0.15.1"
async-priority-channel = "0.2.0"
futures = "0.3.30"
logging_timer = "1.1.0"
simple_logger = "4.3.3"
smallvec = { version = "1.13.1", features = ["serde"] }
```
mpi_test/run.sh (new executable file, 11 lines)

```bash
#!/bin/bash
set -e

export RUSTFLAGS="-Ctarget-cpu=native"
cargo build --profile=release
source /etc/profile.d/modules.sh
module load mpi/openmpi-x86_64
# export LD_LIBRARY_PATH="/usr/lib64/mpich/lib/"
export LD_LIBRARY_PATH="/usr/lib64/openmpi/lib/"

mpirun -n 6 ../target/release/mpi_test
```
mpi_test/src/async_pbs_graph.rs (new file, 445 lines)

```rust
use crate::async_task_graph::{Priority, TaskGraph};
use crate::context::Context;
use crate::examples::async_mul::{prefix_sum_carry_propagation, OutputCarry};
use logging_timer::time;
use mpi::traits::*;
use petgraph::algo::is_cyclic_directed;
use petgraph::stable_graph::NodeIndex;
use petgraph::visit::EdgeRef;
use petgraph::Direction::{Incoming, Outgoing};
use petgraph::Graph;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tfhe::shortint::server_key::LookupTableOwned;
use tfhe::shortint::{Ciphertext, ServerKey};

#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCt {
    index: usize,
    ct: Ciphertext,
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum Lut {
    ExtractMessage,
    ExtractCarry,
    BivarMulLow,
    BivarMulHigh,
    PrefixSumCarryPropagation,
    DoesBlockGenerateCarry,
    DoesBlockGenerateOrPropagate,
}

#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCtsAndLut {
    index: usize,
    cts_and_weights: SmallVec<[(u64, Arc<Ciphertext>); 5]>,
    lut: Lut,
}

impl IndexedCtsAndLut {
    fn multisum(self, sks: &ServerKey) -> IndexedCtAndLut {
        let IndexedCtsAndLut {
            index,
            cts_and_weights,
            lut,
        } = self;

        let mut cts_and_weights = cts_and_weights.into_iter();

        let (first_scalar, first_ct) = cts_and_weights.next().unwrap();

        let mut multisum_result = sks.unchecked_scalar_mul(&first_ct, first_scalar as u8);

        for (scalar, ct) in cts_and_weights {
            sks.unchecked_add_scalar_mul_assign(&mut multisum_result, &ct, scalar as u8);
        }

        IndexedCtAndLut {
            index,
            ct: multisum_result,
            lut,
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCtAndLut {
    index: usize,
    ct: Ciphertext,
    lut: Lut,
}

impl IndexedCtAndLut {
    fn pbs(self, sks: &ServerKey, luts: &Luts) -> IndexedCt {
        IndexedCt {
            ct: sks.apply_lookup_table(&self.ct, luts.get(self.lut)),
            index: self.index,
        }
    }
}

#[derive(Clone)]
pub enum Node {
    Computed(Arc<Ciphertext>),
    BootstrapQueued,
    ToCompute { lookup_table: Lut },
}

impl Node {
    fn ct(&self) -> Option<&Arc<Ciphertext>> {
        match self {
            Node::Computed(ct) => Some(ct),
            _ => None,
        }
    }
}

impl std::fmt::Debug for Node {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Computed(_) => f.debug_tuple("Computed").finish(),
            Self::BootstrapQueued => write!(f, "BootstrapQueued"),
            Self::ToCompute { .. } => f.debug_struct("ToCompute").finish(),
        }
    }
}

pub struct FheGraph {
    graph: Graph<(Priority, Node), u64>,
    not_computed_nodes_count: usize,
}

fn insert_predecessors_recursively(
    graph: &Graph<Node, u64>,
    successors_max_depths: &mut HashMap<usize, i32>,
    node_index: NodeIndex,
) {
    if successors_max_depths.contains_key(&node_index.index()) {
        return;
    }

    if graph
        .neighbors_directed(node_index, Outgoing)
        .all(|successor| successors_max_depths.contains_key(&successor.index()))
    {
        let max_successors_depth = graph
            .neighbors_directed(node_index, Outgoing)
            .map(|successor| successors_max_depths[&successor.index()])
            .max();

        successors_max_depths.insert(node_index.index(), max_successors_depth.unwrap_or(0) + 1);

        for predecessor in graph.neighbors_directed(node_index, Incoming) {
            insert_predecessors_recursively(graph, successors_max_depths, predecessor);
        }
    }
}

impl FheGraph {
    pub fn new(graph: Graph<Node, u64>) -> Self {
        let not_computed_nodes_count = graph
            .node_weights()
            .filter(|node| !matches!(&node, Node::Computed(_)))
            .count();

        let mut successors_max_depth = HashMap::new();

        for node_index in graph.node_indices() {
            if graph.edges_directed(node_index, Outgoing).next().is_none() {
                insert_predecessors_recursively(&graph, &mut successors_max_depth, node_index);
            }
        }

        dbg!(&successors_max_depth.values().max());

        let graph = graph.map(
            |node_index, node| {
                (
                    Priority(successors_max_depth[&node_index.index()]),
                    node.clone(),
                )
            },
            |_, edge| *edge,
        );

        Self {
            graph,
            not_computed_nodes_count,
        }
    }
    fn test_graph_init(&self) {
        assert!(!is_cyclic_directed(&self.graph));

        for i in self.graph.node_indices() {
            if self.graph.neighbors_directed(i, Incoming).next().is_none() {
                assert!(matches!(
                    &self.graph.node_weight(i),
                    Some((_, Node::Computed(_)))
                ))
            } else {
                assert!(matches!(
                    &self.graph.node_weight(i),
                    Some((_, Node::ToCompute { .. }))
                ))
            }
        }
    }

    fn assert_finishable(&self) {
        assert!(!is_cyclic_directed(&self.graph));

        for i in self.graph.node_indices() {
            if self.graph.neighbors_directed(i, Incoming).next().is_none() {
                assert!(matches!(
                    &self.graph.node_weight(i),
                    Some((_, Node::Computed(_)))
                ))
            }
        }
    }

    #[time]
    fn predecessors_list(&self, index: NodeIndex) -> SmallVec<[(u64, Arc<Ciphertext>); 5]> {
        self.graph
            .edges_directed(index, Incoming)
            .map(|edge| {
                (
                    *edge.weight(),
                    self.graph[edge.source()].1.ct().unwrap().clone(),
                )
            })
            .collect()
    }

    #[time]
    fn build_task(&mut self, index: NodeIndex) -> (Priority, IndexedCtsAndLut) {
        let cts_and_weights = self.predecessors_list(index);

        let lut = match self.graph.node_weight(index) {
            Some((_, Node::ToCompute { lookup_table })) => lookup_table.to_owned(),
            _ => unreachable!(),
        };

        self.graph.node_weight_mut(index).unwrap().1 = Node::BootstrapQueued;

        (
            Priority(0),
            IndexedCtsAndLut {
                index: index.index(),
                cts_and_weights,
                lut,
            },
        )
    }
}

impl TaskGraph for FheGraph {
    type Task = IndexedCtsAndLut;

    type Result = IndexedCt;

    #[time]
    fn init(&mut self) -> Vec<(Priority, IndexedCtsAndLut)> {
        self.test_graph_init();

        let nodes_to_compute: Vec<_> = self
            .graph
            .node_indices()
            .filter(|&i| {
                let to_compute =
                    matches!(self.graph.node_weight(i).unwrap().1, Node::ToCompute { .. });

                let all_predecessors_computed =
                    self.graph
                        .neighbors_directed(i, Incoming)
                        .all(|predecessor| {
                            matches!(
                                self.graph.node_weight(predecessor).unwrap().1,
                                Node::Computed(_)
                            )
                        });

                to_compute && all_predecessors_computed
            })
            .collect();

        nodes_to_compute
            .into_iter()
            .map(|index| self.build_task(index))
            .collect()
    }

    #[time]
    fn commit_result(&mut self, result: IndexedCt) -> Vec<(Priority, IndexedCtsAndLut)> {
        self.not_computed_nodes_count -= 1;

        // dbg!(self.not_computed_nodes_count);

        let IndexedCt { index, ct } = result;

        let index = NodeIndex::new(index);

        let node_mut = self.graph.node_weight_mut(index).unwrap();

        assert!(matches!(node_mut.1, Node::BootstrapQueued));
        node_mut.1 = Node::Computed(Arc::new(ct));

        let nodes_to_compute: Vec<_> = self
            .graph
            .neighbors_directed(index, Outgoing)
            .filter(|&i| {
                assert!(matches!(
                    self.graph.node_weight(i).unwrap().1,
                    Node::ToCompute { .. }
                ));

                let all_predecessors_computed =
                    self.graph
                        .neighbors_directed(i, Incoming)
                        .all(|predecessor| {
                            matches!(
                                self.graph.node_weight(predecessor).unwrap().1,
                                Node::Computed(_)
                            )
                        });

                all_predecessors_computed
            })
            .collect();

        nodes_to_compute
            .into_iter()
            .map(|index| self.build_task(index))
            .collect()
    }

    fn is_finished(&self) -> bool {
        self.not_computed_nodes_count == 0
    }
}

impl Context {
    pub fn async_pbs_graph_queue_master1(
        &self,
        sks: Arc<ServerKey>,
        graph: Graph<Node, u64>,
    ) -> (Graph<Node, u64>, Duration) {
        let luts = Luts::new(&sks);

        let root_process = self.world.process_at_rank(self.root_rank);

        let mut sks_serialized = bincode::serialize(sks.as_ref()).unwrap();
        let mut sks_serialized_len = sks_serialized.len();

        let mut graph = FheGraph::new(graph);

        graph.assert_finishable();

        root_process.broadcast_into(&mut sks_serialized_len);

        root_process.broadcast_into(sks_serialized.as_mut_slice());

        let start = Instant::now();

        self.async_task_graph_queue_master::<_, _, IndexedCtsAndLut, IndexedCtAndLut, IndexedCt>(
            &mut graph,
            (sks, luts),
            move |(sks, luts), input| input.multisum(sks).pbs(sks, luts),
            move |(sks, _), task| task.multisum(sks),
        );

        let duration = start.elapsed();

        (
            graph.graph.map(|_, node| node.1.clone(), |_, edge| *edge),
            duration,
        )
    }
    pub fn async_pbs_graph_queue_worker1(&self) {
        let root_process = self.world.process_at_rank(self.root_rank);

        let mut sks_serialized_len = 0;

        root_process.broadcast_into(&mut sks_serialized_len);

        let mut sks_serialized = vec![0; sks_serialized_len];

        root_process.broadcast_into(&mut sks_serialized);

        let sks: Arc<ServerKey> = Arc::new(bincode::deserialize(&sks_serialized).unwrap());

        let luts = Luts::new(&sks);

        self.async_task_graph_queue_worker::<_, IndexedCtAndLut, IndexedCt>(
            (sks, luts),
            |(sks, luts), input| input.pbs(sks, luts),
        );

        panic!()
    }
}

#[derive(Clone, Serialize, Deserialize)]
struct Luts {
    extract_message: LookupTableOwned,
    extract_carry: LookupTableOwned,
    bivar_mul_low: LookupTableOwned,
    bivar_mul_high: LookupTableOwned,
    prefix_sum_carry_propagation: LookupTableOwned,
    does_block_generate_carry: LookupTableOwned,
    does_block_generate_or_propagate: LookupTableOwned,
}

impl Luts {
    fn new(sks: &ServerKey) -> Self {
        let message_modulus = sks.message_modulus.0 as u64;

        Self {
            extract_message: sks.generate_lookup_table(|x| x % message_modulus),
            extract_carry: sks.generate_lookup_table(|x| x / message_modulus),
            bivar_mul_low: sks
                .generate_lookup_table_bivariate(|x, y| (x * y) % message_modulus)
                .acc,
            bivar_mul_high: sks
                .generate_lookup_table_bivariate(|x, y| (x * y) / message_modulus)
                .acc,
            prefix_sum_carry_propagation: sks
                .generate_lookup_table_bivariate(prefix_sum_carry_propagation)
                .acc,
            does_block_generate_carry: sks.generate_lookup_table(|x| {
                if x >= message_modulus {
                    OutputCarry::Generated as u64
                } else {
                    OutputCarry::None as u64
                }
            }),
            does_block_generate_or_propagate: sks.generate_lookup_table(|x| {
                if x >= message_modulus {
                    OutputCarry::Generated as u64
                } else if x == (message_modulus - 1) {
                    OutputCarry::Propagated as u64
                } else {
                    OutputCarry::None as u64
                }
            }),
        }
    }

    fn get(&self, lut: Lut) -> &LookupTableOwned {
        match lut {
            Lut::ExtractMessage => &self.extract_message,
            Lut::ExtractCarry => &self.extract_carry,
            Lut::BivarMulLow => &self.bivar_mul_low,
            Lut::BivarMulHigh => &self.bivar_mul_high,
            Lut::PrefixSumCarryPropagation => &self.prefix_sum_carry_propagation,
            Lut::DoesBlockGenerateCarry => &self.does_block_generate_carry,
            Lut::DoesBlockGenerateOrPropagate => &self.does_block_generate_or_propagate,
        }
    }
}
```
mpi_test/src/async_task_graph.rs (new file, 373 lines)

```rust
use crate::context::Context;
use crate::managers::{Receiving, Sending};
use async_priority_channel::{unbounded, Receiver, Sender};
use futures::executor::block_on;
use mpi::topology::Process;
use mpi::traits::*;
use mpi::Tag;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::mem::transmute;
use thread_priority::{set_current_thread_priority, ThreadPriority, ThreadPriorityValue};
use threadpool::ThreadPool;

const MASTER_TO_WORKER: Tag = 0;
const WORKER_TO_MASTER: Tag = 1;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Priority(pub i32);

pub trait TaskGraph {
    type Task;
    type Result;

    fn init(&mut self) -> Vec<(Priority, Self::Task)>;
    fn commit_result(&mut self, result: Self::Result) -> Vec<(Priority, Self::Task)>;
    // fn no_work_in_queue(&self) -> bool;
    fn is_finished(&self) -> bool;
}

struct ClusterCharge {
    available_parallelism: usize,
    charge: Vec<usize>,
}

impl Context {
    pub fn async_task_graph_queue_master<
        T: Sync + Clone + Send + 'static,
        U: TaskGraph<Task = Task, Result = Result>,
        Task: Send + 'static,
        RemoteTask: Serialize + DeserializeOwned + Send + 'static,
        Result: Serialize + DeserializeOwned + Send + 'static,
    >(
        &self,
        task_graph: &mut U,
        state: T,
        f: impl Fn(&T, Task) -> Result + Sync + Clone + Send + 'static,
        convert: impl Fn(&T, Task) -> RemoteTask + Sync + Clone + Send + 'static,
    ) {
        let (send_task, receive_task) = unbounded::<Task, Priority>();
        let (send_result, receive_result) = unbounded::<(Result, usize), Priority>();

        // let mut sent_inputs = vec![];

        let mut charge = ClusterCharge {
            available_parallelism: std::thread::available_parallelism().unwrap().get(),
            charge: vec![0; self.size],
        };

        {
            let state = state.clone();
            let n_workers = (std::thread::available_parallelism().unwrap().get() - 1).max(1);
            let priority =
                ThreadPriority::Crossplatform(ThreadPriorityValue::try_from(32).unwrap());

            launch_threadpool(
                priority,
                n_workers,
                &receive_task,
                &send_result,
                move |receive_task, send_result, state| {
                    let f = f.clone();

                    let (input, priority) = block_on(receive_task.recv()).unwrap();

                    let result = f(state, input);

                    block_on(send_result.send((result, 0), priority)).unwrap();
                },
                state,
            );
        }

        let worker_senders: Vec<_> = (1..self.size)
            .map(|rank| {
                let (send_task, receive_task) = unbounded::<Task, Priority>();
                let process_at_rank: Process<'static> =
                    unsafe { transmute(self.world.process_at_rank(rank as i32)) };

                let convert = convert.clone();

                let state = state.clone();

                std::thread::spawn(move || {
                    // set_current_thread_priority(priority).unwrap();

                    let mut sent_inputs = vec![];

                    while let Ok((task, priority)) = block_on(receive_task.recv()) {
                        let remote_task = convert(&state, task);

                        let buffer = bincode::serialize(&(remote_task, priority)).unwrap();

                        sent_inputs.push(Sending::new(buffer, &process_at_rank, MASTER_TO_WORKER))
                    }

                    for a in sent_inputs {
                        a.abort()
                    }
                });
                send_task
            })
            .collect();

        for rank in 1..self.size {
            let send_result = send_result.clone();

            let process_at_rank: Process<'static> =
                unsafe { transmute(self.world.process_at_rank(rank as i32)) };

            std::thread::spawn(move || {
                // set_current_thread_priority(priority).unwrap();
                let mut receives = VecDeque::new();

                for _ in 0..100 {
                    receives.push_back(Some(Receiving::new(&process_at_rank, WORKER_TO_MASTER)))
                }

                loop {
                    let Receiving { buffer, future } = receives.pop_front().unwrap().unwrap();

                    receives.push_back(Some(Receiving::new(&process_at_rank, WORKER_TO_MASTER)));

                    future.wait();

                    let (result, priority) = bincode::deserialize(&buffer).unwrap();

                    block_on(send_result.send((result, rank), priority)).unwrap();
                }
            });
        }

        for (priority, task) in task_graph.init() {
            self.enqueue_request(&mut charge, &send_task, priority, task, &worker_senders);
        }

        let mut empty = 0;
        let mut not_empty = 0;

        while !task_graph.is_finished() {
            if receive_result.is_empty() {
                empty += 1;
            } else {
                not_empty += 1;
            }

            let ((result, rank), _priority) = block_on(receive_result.recv()).unwrap();

            charge.charge[rank] -= 1;

            self.handle_new_result(task_graph, result, &mut charge, &send_task, &worker_senders);
        }

        dbg!(empty, not_empty);

        for i in charge.charge {
            assert_eq!(i, 0);
        }

        std::mem::forget(send_task);
    }

    fn handle_new_result<U: TaskGraph>(
        &self,
        task_graph: &mut U,
        result: U::Result,
        charge: &mut ClusterCharge,
        send_task: &Sender<U::Task, Priority>,
        sent_inputs: &[Sender<U::Task, Priority>],
    ) {
        let new_tasks = task_graph.commit_result(result);

        for (priority, task) in new_tasks {
            self.enqueue_request(charge, send_task, priority, task, sent_inputs);
        }
    }

    fn enqueue_request<Task>(
        &self,
        charge: &mut ClusterCharge,
        send_task: &Sender<Task, Priority>,
        priority: Priority,
        task: Task,
        sent_inputs: &[Sender<Task, Priority>],
    ) {
        let rank = if charge.charge[self.root_rank as usize] < charge.available_parallelism {
            self.root_rank as usize
        } else {
            charge
                .charge
                .iter()
                .enumerate()
                .min_by_key(|(_index, charge)| *charge)
                .unwrap()
                .0
        };

        charge.charge[rank] += 1;

        if rank == self.root_rank as usize {
            block_on(send_task.send(task, priority)).unwrap();
        } else {
            block_on(sent_inputs[rank - 1].send(task, priority)).unwrap();
        }
    }

    pub fn async_task_graph_queue_worker<
        T: Sync + Clone + Send + 'static,
        RemoteTask: Serialize + DeserializeOwned + Send,
        Result: Serialize + DeserializeOwned + Send,
    >(
        &self,
        state: T,
        f: impl Fn(&T, RemoteTask) -> Result + Sync + Clone + Send + 'static,
    ) {
        let f = move |state: &T, serialized_input: &Vec<u8>| {
            let (input, priority): (RemoteTask, Priority) =
                bincode::deserialize(serialized_input).unwrap();

            let result = f(state, input);

            bincode::serialize(&(result, priority)).unwrap()
        };

        let (send_task, receive_task) = crossbeam_channel::unbounded::<Vec<u8>>();
        let (send_result, receive_result) = crossbeam_channel::unbounded::<Vec<u8>>();

        {
            let state = state.clone();
            let f = f.clone();
            let n_workers = (std::thread::available_parallelism().unwrap().get() - 1).max(1);
            let priority =
                ThreadPriority::Crossplatform(ThreadPriorityValue::try_from(32).unwrap());

            launch_threadpool2(
                priority,
                n_workers,
                &receive_task,
                &send_result,
                move |receive_task, send_result, state| {
                    let f = f.clone();

                    let input = receive_task.recv().unwrap();

                    let result = f(state, &input);

                    send_result.send(result).unwrap();
                },
                state,
            );
        }
        let root_process = self.world.process_at_rank(self.root_rank);

        {
            let root_process: Process<'static> =
                unsafe { transmute(self.world.process_at_rank(self.root_rank)) };

            std::thread::spawn(move || {
                let mut receives = VecDeque::new();

                for _ in 0..100 {
                    receives.push_back(Some(Receiving::new(&root_process, MASTER_TO_WORKER)))
                }
                loop {
                    let Receiving { buffer, future } = receives.pop_front().unwrap().unwrap();

                    receives.push_back(Some(Receiving::new(&root_process, MASTER_TO_WORKER)));

                    future.wait();

                    send_task.send(buffer).unwrap();
                }
            });
        }

        let mut send: VecDeque<Sending> = VecDeque::new();

        'outer: loop {
            if let Ok(output) = receive_result.recv() {
                send.push_back(Sending::new(output, &root_process, WORKER_TO_MASTER));
            }

            while let Some(front) = send.front_mut() {
                if let Some(a) = front.a.take() {
                    match a.test() {
                        Ok(_) => {
                            let b = send.pop_front();

                            assert!(b.unwrap().a.is_none());
                        }
                        Err(front_a) => {
                            front.a = Some(front_a);
                            continue 'outer;
                        }
                    }
                }
            }
        }
    }
}

fn launch_threadpool<
    T: Clone + Send + 'static,
    U: Send + 'static,
    V: Send + 'static,
    W: Send + Ord + 'static,
    // X: Send + Ord + 'static,
>(
    priority: ThreadPriority,
    n_workers: usize,
    receive_task: &Receiver<U, W>,
    send_result: &Sender<V, Priority>,
    function: impl Fn(&Receiver<U, W>, &Sender<V, Priority>, &T) + Send + Clone + 'static,
    state: T,
) {
    let pool = ThreadPool::new(n_workers);

    for _ in 0..n_workers {
        let receive_task = receive_task.clone();
        let send_result = send_result.clone();
        let function = function.clone();

        let state = state.clone();

        pool.execute(move || {
            set_current_thread_priority(priority).unwrap();

            loop {
                function(&receive_task, &send_result, &state);
            }
        });
    }
}

fn launch_threadpool2<T: Clone + Send + 'static, U: Send + 'static, V: Send + 'static>(
    priority: ThreadPriority,
    n_workers: usize,
    receive_task: &crossbeam_channel::Receiver<U>,
    send_result: &crossbeam_channel::Sender<V>,
    function: impl Fn(&crossbeam_channel::Receiver<U>, &crossbeam_channel::Sender<V>, &T)
        + Send
        + Clone
        + 'static,
    state: T,
) {
    let pool = ThreadPool::new(n_workers);

    for _ in 0..n_workers {
        let receive_task = receive_task.clone();
        let send_result = send_result.clone();
        let function = function.clone();

        let state = state.clone();

        pool.execute(move || {
            set_current_thread_priority(priority).unwrap();

            loop {
                function(&receive_task, &send_result, &state);
            }
        });
    }
}
```
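The `TaskGraph` trait above is the whole contract between a computation and the scheduler: `init` returns the initially ready tasks, `commit_result` records a finished result and returns whatever new tasks it unlocks, and `is_finished` stops the master loop. As a minimal sketch of that contract (not part of this commit; the `DoubleAll` type and its `u64` task are made up for illustration), an implementation with no dependencies between tasks could look like this:

```rust
use crate::async_task_graph::{Priority, TaskGraph};

// Hypothetical toy task graph: every input is an independent task, so `init`
// emits all tasks at once and `commit_result` never creates follow-up work.
struct DoubleAll {
    inputs: Vec<u64>,
    outputs: Vec<u64>,
}

impl TaskGraph for DoubleAll {
    type Task = u64;
    type Result = u64;

    fn init(&mut self) -> Vec<(Priority, u64)> {
        self.inputs.iter().map(|&x| (Priority(0), x)).collect()
    }

    fn commit_result(&mut self, result: u64) -> Vec<(Priority, u64)> {
        self.outputs.push(result);
        vec![] // no dependent tasks are unlocked
    }

    fn is_finished(&self) -> bool {
        self.outputs.len() == self.inputs.len()
    }
}
```

`ListOfPbs` in `async_list.rs` below follows the same flat pattern over ciphertexts, while `FheGraph` in `async_pbs_graph.rs` uses `commit_result` to release the successors whose predecessors have all been computed.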
mpi_test/src/context.rs (new file, 36 lines)

```rust
use mpi::environment::Universe;
use mpi::topology::SimpleCommunicator;
use mpi::traits::*;
use mpi::Threading;

pub struct Context {
    pub universe: Universe,
    pub world: SimpleCommunicator,
    pub size: usize,
    pub rank: i32,
    pub root_rank: i32,
    pub is_root: bool,
}

#[allow(clippy::new_without_default)]
impl Context {
    pub fn new() -> Self {
        let (universe, _) = mpi::initialize_with_threading(Threading::Multiple).unwrap();
        let world = universe.world();

        let size = world.size() as usize;
        let rank = world.rank();
        let root_rank = 0;

        let is_root = rank == root_rank;

        Context {
            universe,
            world,
            size,
            rank,
            root_rank,
            is_root,
        }
    }
}
```
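`main.rs` and `managers.rs` are not included in this excerpt, so the following is only a sketch of how the pieces appear intended to fit together: `mpirun` starts the same binary on every rank, each rank builds a `Context`, and the examples branch internally on `is_root`. The module list, the value of `N`, and the example being called are assumptions, not taken from the diff:

```rust
// Hypothetical main.rs sketch — not part of this commit.
mod async_pbs_graph;
mod async_task_graph;
mod context;
mod examples;
mod managers;

use context::Context;

// Number of PBS used by the batch/list/flat-graph examples (assumed value).
pub const N: u64 = 1000;

fn main() {
    // Every MPI rank runs the same code; Context::new() initializes MPI with
    // Threading::Multiple and records rank/size/root information.
    let context = Context::new();

    // Each example branches on context.is_root: the root builds the task
    // graph and collects results, the other ranks loop as workers.
    context.async_mul(8);
}
```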
mpi_test/src/examples/async_batch.rs (new file, 201 lines)

```rust
use crate::context::Context;
use crate::N;
use mpi::traits::*;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};

impl Context {
    pub fn async_pbs_batch(&self) {
        let root_process = self.world.process_at_rank(self.root_rank);

        let mut cks_opt = None;

        let mut sks_serialized = vec![];
        let mut sks_serialized_len = 0;

        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            cks_opt = Some(cks);

            sks_serialized = bincode::serialize(&sks).unwrap();
            sks_serialized_len = sks_serialized.len();
        }

        root_process.broadcast_into(&mut sks_serialized_len);

        if sks_serialized.is_empty() {
            sks_serialized = vec![0; sks_serialized_len];
        }

        root_process.broadcast_into(&mut sks_serialized);

        let sks: ServerKey = bincode::deserialize(&sks_serialized).unwrap();

        let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);

        if self.is_root {
            let cks = cks_opt.as_ref().unwrap();

            let mut inputs = vec![];

            for i in 0..N {
                let ct = cks.unchecked_encrypt(i % 16);

                inputs.push(ct);
            }

            let start = Instant::now();
            let elements_per_node = N as usize / self.size;

            let serialized: Vec<_> = (1..self.size)
                .map(|dest_rank| {
                    let inputs_chunk =
                        &inputs[elements_per_node * dest_rank..elements_per_node * (dest_rank + 1)];

                    bincode::serialize(inputs_chunk).unwrap()
                })
                .collect();

            let lens: Vec<_> = serialized.iter().map(|a| a.len()).collect();

            let sent_len: Vec<_> = lens
                .iter()
                .enumerate()
                .map(|(i, a)| {
                    let dest_rank = i as i32 + 1;
                    let process = self.world.process_at_rank(dest_rank);

                    process.immediate_send(a)
                })
                .collect();

            let sent_vec: Vec<_> = serialized
                .iter()
                .enumerate()
                .map(|(i, a)| {
                    let dest_rank = i as i32 + 1;
                    let process = self.world.process_at_rank(dest_rank);

                    process.immediate_send(a)
                })
                .collect();

            for i in sent_len {
                i.wait();
            }

            for i in sent_vec {
                i.wait();
            }

            let mut outputs: Vec<_> = inputs[0..elements_per_node]
                .iter()
                .map(|ct| sks.apply_lookup_table(ct, &lookup_table))
                .collect();

            let lens: Vec<_> = (1..self.size)
                .map(|dest_rank| {
                    let process = self.world.process_at_rank(dest_rank as i32);
                    process.immediate_receive()
                })
                .collect();

            let mut results: Vec<Vec<u8>> =
                lens.into_iter().map(|len| vec![0; len.get().0]).collect();

            let sent: Vec<_> = results
                .iter_mut()
                .enumerate()
                .map(|(i, a)| {
                    let dest_rank = i as i32 + 1;
                    let process = self.world.process_at_rank(dest_rank);

                    process.immediate_receive_into(a)
                })
                .collect();

            for i in sent {
                i.wait();
            }

            for result in results.iter() {
                let outputs_chunk: Vec<Ciphertext> = bincode::deserialize(result).unwrap();

                outputs.extend(outputs_chunk);
            }

            let duration = start.elapsed();

            let duration_sec = duration.as_secs_f32();

            println!("{N} PBS in {}s", duration_sec);
            println!("{} ms/PBS", duration_sec * 1000. / N as f32);

            for (i, ct) in outputs.iter().enumerate() {
                assert_eq!(cks.decrypt_message_and_carry(ct), (i as u64 + 1) % 16);
            }

            println!("All good 2");
        } else {
            let (len, _) = root_process.receive();

            let mut input = vec![0; len];

            // let mut status;

            root_process.receive_into(input.as_mut_slice());

            let input: Vec<Ciphertext> = bincode::deserialize(&input).unwrap();

            let output: Vec<_> = input
                .iter()
                .map(|ct| sks.apply_lookup_table(ct, &lookup_table))
                .collect();

            let output = bincode::serialize(&output).unwrap();

            root_process.send(&output.len());

            root_process.send(&output);
        }
    }

    pub fn test_mpi_immediate(&self) {
        let root_process = self.world.process_at_rank(self.root_rank);

        if self.is_root {
            let process = self.world.process_at_rank(1);

            let input = vec![1, 2, 3];

            let len = [input.len()];

            let a = process.immediate_send(&len);

            let b = process.immediate_send(input.as_slice());

            // drop(b);
            let b2 = process.immediate_send(input.as_slice());

            a.wait();
            b.wait();
            b2.wait();

            // let (outputs_chunks_serialized, _status) = process.receive_vec();
        } else if self.rank == 1 {
            let (len, _) = root_process.receive();

            let mut input = vec![0; len];

            // let mut status;

            let future = root_process.immediate_receive_into(input.as_mut_slice());

            future.wait();

            dbg!(input);
        }
    }
}
```
mpi_test/src/examples/async_flat_graph.rs (new file, 65 lines)

```rust
use crate::async_pbs_graph::Node;
use crate::context::Context;
use crate::N;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

impl Context {
    pub fn async_flat_graph(&self) {
        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            let mut graph = Graph::new();

            let mut expected_outputs = vec![];

            for i in 0..N {
                let plain = i % 16;

                let encrypted = cks.unchecked_encrypt(plain);

                let f = |x| x + 2;

                let lookup_table = sks.generate_lookup_table(f);

                // dbg!(cks.decrypt_message_and_carry(&sks.apply_lookup_table(&encrypted,
                // &lookup_table)));

                let input = graph.add_node(Node::Computed(encrypted));
                let output = graph.add_node(Node::ToCompute {
                    lookup_table: lookup_table.clone(),
                });

                graph.add_edge(input, output, 1);

                expected_outputs.push((output, f(plain)));
            }

            let sks = Arc::new(sks);

            let (graph, duration) = self.async_pbs_graph_queue_master1(sks, graph);

            let duration_sec = duration.as_secs_f32();

            println!("{N} PBS in {}s", duration_sec);
            println!("{} ms/PBS", duration_sec * 1000. / N as f32);

            for (node_index, expected_decryption) in expected_outputs {
                let node = graph.node_weight(node_index).unwrap();

                let ct = match node {
                    Node::Computed(ct) => ct,
                    _ => unreachable!(),
                };

                assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
            }

            println!("All good 4");
        } else {
            self.async_pbs_graph_queue_worker1();
        }
    }
}
```
mpi_test/src/examples/async_list.rs (new file, 120 lines)

```rust
use crate::async_task_graph::{Priority, TaskGraph};
use crate::context::Context;
use crate::managers::IndexedCt;
use crate::N;
use mpi::traits::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::server_key::LookupTableOwned;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};

struct ListOfPbs {
    pub inputs: Vec<Ciphertext>,
    pub outputs: HashMap<usize, Ciphertext>,
}

impl ListOfPbs {
    fn new(inputs: Vec<Ciphertext>) -> Self {
        Self {
            inputs,
            outputs: HashMap::new(),
        }
    }
}

impl TaskGraph for ListOfPbs {
    type Task = IndexedCt;

    type Result = IndexedCt;

    fn init(&mut self) -> Vec<(Priority, IndexedCt)> {
        self.inputs
            .clone()
            .into_iter()
            .enumerate()
            .map(|(i, ct)| (Priority(0), IndexedCt { index: i, ct }))
            .collect()
    }

    fn commit_result(&mut self, result: IndexedCt) -> Vec<(Priority, IndexedCt)> {
        self.outputs.insert(result.index, result.ct);

        vec![]
    }

    fn is_finished(&self) -> bool {
        self.outputs.len() == self.inputs.len()
    }
}

impl Context {
    pub fn async_pbs_list_queue(&self) {
        if self.is_root {
            let root_process = self.world.process_at_rank(self.root_rank);

            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            let mut sks_serialized = bincode::serialize(&sks).unwrap();
            let mut sks_serialized_len = sks_serialized.len();

            let sks = Arc::new(sks);

            root_process.broadcast_into(&mut sks_serialized_len);

            root_process.broadcast_into(&mut sks_serialized);

            let lookup_table = Arc::new(sks.generate_lookup_table(|x| (x + 1) % 16));

            let inputs: Vec<_> = (0..N).map(|i| cks.unchecked_encrypt(i % 16)).collect();

            let start = Instant::now();
            let mut a = ListOfPbs::new(inputs);
            self.async_task_graph_queue_master(
                &mut a,
                (sks, lookup_table),
                |(sks, lookup_table), input| run_pbs(input, sks, lookup_table),
            );

            let duration = start.elapsed();

            let duration_sec = duration.as_secs_f32();

            println!("{N} PBS in {}s", duration_sec);
            println!("{} ms/PBS", duration_sec * 1000. / N as f32);

            for (i, ct) in a.outputs.iter() {
                assert_eq!(cks.decrypt_message_and_carry(ct), (*i as u64 + 1) % 16);
            }

            println!("All good 3");
        } else {
            let root_process = self.world.process_at_rank(self.root_rank);

            let mut sks_serialized_len = 0;

            root_process.broadcast_into(&mut sks_serialized_len);

            let mut sks_serialized = vec![0; sks_serialized_len];

            root_process.broadcast_into(&mut sks_serialized);

            let sks: Arc<ServerKey> = Arc::new(bincode::deserialize(&sks_serialized).unwrap());

            let lookup_table = Arc::new(sks.generate_lookup_table(|x| (x + 1) % 16));

            self.async_task_graph_queue_worker(
                (sks, lookup_table),
                |(sks, lookup_table), input| run_pbs(input, sks, lookup_table),
            );
        }
    }
}

fn run_pbs(input: &IndexedCt, sks: &ServerKey, lookup_table: &LookupTableOwned) -> IndexedCt {
    IndexedCt {
        ct: sks.apply_lookup_table(&input.ct, lookup_table),
        index: input.index,
    }
}
```
mpi_test/src/examples/async_mul.rs (new file, 494 lines)

```rust
use crate::async_pbs_graph::{Lut, Node};
use crate::context::Context;
use core::panic;
use itertools::{zip_eq, Itertools};
use petgraph::prelude::NodeIndex;
use petgraph::Graph;
use std::collections::BinaryHeap;
use std::sync::Arc;
use tfhe::core_crypto::commons::traits::UnsignedInteger;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, ServerKey};

impl Context {
    pub fn async_mul(&self, num_blocks: i32) {
        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            let sks = Arc::new(sks);

            let mut graph = Graph::new();
            let mut expected_outputs = vec![];

            let cut_into_block = |number| {
                let mut number = number;
                (0..num_blocks)
                    .map(|_| {
                        let new = number % 4;
                        number /= 4;
                        new
                    })
                    .collect::<Vec<_>>()
            };

            let cut_into_nodes = |graph: &mut Graph<Node, u64>, number: u64| {
                cut_into_block(number)
                    .into_iter()
                    .map(|block| {
                        graph.add_node(Node::Computed(Arc::new(cks.unchecked_encrypt(block))))
                    })
                    .collect::<Vec<_>>()
            };

            let i = 24533;
            let j = 53864;

            let in1 = cut_into_nodes(&mut graph, i);
            let in2 = cut_into_nodes(&mut graph, j);

            let result = mul_graph(&mut graph, &sks, &in1, &in2);

            for (i, j) in zip_eq(&result, cut_into_block(i.wrapping_mul(j))) {
                expected_outputs.push((*i, j));
            }

            // println!("{:?}", Dot::with_config(&graph, &[Config::NodeNoLabel]));

            let (graph, duration) = self.async_pbs_graph_queue_master1(sks.clone(), graph);

            let duration_sec = duration.as_secs_f32();

            for (node_index, expected_decryption) in expected_outputs {
                let node = graph.node_weight(node_index).unwrap();

                let ct = match node {
                    Node::Computed(ct) => ct,
                    _ => unreachable!(),
                };

                // dbg!(cks.decrypt_message_and_carry(ct), expected_decryption);
                assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
            }
            println!("All good 7");

            println!("MPI {num_blocks} block mul in {}s", duration_sec);

            panic!();
        } else {
            self.async_pbs_graph_queue_worker1();
        }
    }
}

pub fn mul_graph(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    lhs: &[NodeIndex],
    rhs: &[NodeIndex],
) -> Vec<NodeIndex> {
    let len = lhs.len();

    assert_eq!(len, rhs.len());

    let mut terms_for_mul_low: Vec<BinaryHeap<NodeWithDepth>> =
        compute_terms_for_mul_low(graph, sks, lhs, rhs)
            .into_iter()
            .map(|a| {
                a.into_iter()
                    .map(|node| NodeWithDepth { node, depth: 0 })
                    .collect()
            })
            .collect();

    terms_for_mul_low.reverse();

    assert_eq!(len, terms_for_mul_low.len());

    let mut sum_messages = vec![];

    let mut sum_carries = vec![];

    let first_list = terms_for_mul_low.pop().unwrap();

    assert_eq!(first_list.len(), 1);
    let (first_message, first_carry) = sum_blocks(graph, sks, first_list, None);

    assert!(first_carry.is_none());

    for _ in 1..(len - 1) {
        let messages = terms_for_mul_low.pop().unwrap();

        let carries = terms_for_mul_low.last_mut();

        let (message, carry) = sum_blocks(graph, sks, messages, carries);

        sum_messages.push(message);
        sum_carries.push(carry.unwrap());
    }

    let (last_message, last_carry) = sum_blocks(graph, sks, terms_for_mul_low.pop().unwrap(), None);

    assert!(terms_for_mul_low.is_empty());

    sum_messages.push(last_message);

    assert!(last_carry.is_none());

    let mut result = vec![];

    result.push(first_message);

    result.push(sum_messages.remove(0));

    assert_eq!(sum_messages.len(), sum_carries.len());

    result.extend(&add_propagate_carry(
        graph,
        sks,
        &sum_messages,
        &sum_carries,
    ));

    result
}

fn compute_terms_for_mul_low(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    lhs: &[NodeIndex],
    rhs: &[NodeIndex],
) -> Vec<Vec<NodeIndex>> {
    let message_modulus = sks.message_modulus.0 as u64;
    assert!(message_modulus <= sks.carry_modulus.0 as u64);

    assert_eq!(rhs.len(), rhs.len());
    let len = rhs.len();

    let mut message_part_terms_generator = vec![vec![]; len];

    for (i, rhs_block) in rhs.iter().enumerate() {
        for (j, lhs_block) in lhs.iter().enumerate() {
            if (i + j) < len {
                let node = graph.add_node(Node::ToCompute {
                    lookup_table: Lut::BivarMulLow,
                });

                graph.add_edge(*lhs_block, node, 1);
                graph.add_edge(*rhs_block, node, message_modulus);

                message_part_terms_generator[i + j].push(node);
            } else {
                break;
            }
        }
    }

    if message_modulus > 2 {
        for (i, rhs_block) in rhs.iter().enumerate() {
            for (j, lhs_block) in lhs.iter().enumerate() {
                if (i + j + 1) < len {
                    let node = graph.add_node(Node::ToCompute {
                        lookup_table: Lut::BivarMulHigh,
                    });

                    graph.add_edge(*lhs_block, node, 1);
                    graph.add_edge(*rhs_block, node, message_modulus);

                    message_part_terms_generator[i + j + 1].push(node);
                } else {
                    break;
                }
            }
        }
    }

    message_part_terms_generator
}

struct NodeWithDepth {
    node: NodeIndex,
    depth: u32,
}

impl PartialEq for NodeWithDepth {
    fn eq(&self, other: &Self) -> bool {
        self.depth == other.depth
    }
}

impl Eq for NodeWithDepth {}

impl PartialOrd for NodeWithDepth {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NodeWithDepth {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other.depth.cmp(&self.depth)
    }
}

fn sum_blocks(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    mut messages: BinaryHeap<NodeWithDepth>,
    mut carries: Option<&mut BinaryHeap<NodeWithDepth>>,
) -> (NodeIndex, Option<NodeIndex>) {
    assert!(!messages.is_empty());

    let message_modulus = sks.message_modulus.0 as u64;

    // We don't want a carry bigger than message_modulus
    let group_size = ((message_modulus * message_modulus - 1) / (message_modulus - 1)) as usize;

    if messages.len() == 1 {
        return (messages.pop().unwrap().node, None);
    }

    let mut sum_n_most_shallow_terms = |messages: &mut BinaryHeap<NodeWithDepth>, to_add_now| {
        let mut adding: Vec<NodeIndex> = vec![];

        let mut max_depth = 0;

        for _ in 0..to_add_now {
            let NodeWithDepth { node, depth } = messages.pop().unwrap();

            if depth > max_depth {
                max_depth = depth;
            }
            adding.push(node);
        }

        if let Some(carries) = &mut carries {
            let (sum, carry) = checked_add(graph, sks, &adding, true);

            messages.push(NodeWithDepth {
                node: sum,
                depth: max_depth + 1,
            });

            carries.push(NodeWithDepth {
                node: carry.unwrap(),
                depth: max_depth + 1,
            });
        } else {
            let (sum, carry) = checked_add(graph, sks, &adding, false);

            messages.push(NodeWithDepth {
                node: sum,
                depth: max_depth + 1,
            });

            assert!(carry.is_none());
        }
    };

    if messages.len() > group_size {
        let n = (messages.len() - group_size) / (group_size - 1);
        let y = messages.len() - group_size - n * (group_size - 1);

        assert_eq!(messages.len(), group_size + (group_size - 1) * n + y);

        if y != 0 {
            sum_n_most_shallow_terms(&mut messages, y + 1);
        }

        assert_eq!(messages.len(), group_size + (group_size - 1) * n);

        for _ in 0..n {
            sum_n_most_shallow_terms(&mut messages, group_size);
        }

        assert!(messages.len() == group_size);
    }

    let mut adding = vec![];

    while let Some(NodeWithDepth { node, .. }) = messages.pop() {
        adding.push(node);
    }

    if carries.is_some() {
        checked_add(graph, sks, &adding, true)
    } else {
        checked_add(graph, sks, &adding, false)
    }
}

fn checked_add(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    blocks_ref: &[NodeIndex],
    build_carry: bool,
) -> (NodeIndex, Option<NodeIndex>) {
    assert!(blocks_ref.len() > 1);

    let message_modulus = sks.message_modulus.0 as u64;

    // We don't want a carry bigger than message_modulus
    let group_size = (message_modulus * message_modulus - 1) / (message_modulus - 1);

    assert!(blocks_ref.len() <= group_size as usize);

    let sum = graph.add_node(Node::ToCompute {
        lookup_table: Lut::ExtractMessage,
    });

    for i in blocks_ref {
        graph.add_edge(*i, sum, 1);
    }

    let carry = if build_carry {
        let new_carry = graph.add_node(Node::ToCompute {
            lookup_table: Lut::ExtractCarry,
        });

        for i in blocks_ref {
            graph.add_edge(*i, new_carry, 1);
        }

        Some(new_carry)
    } else {
        None
    };

    (sum, carry)
}

fn add_propagate_carry(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    ct1: &[NodeIndex],
    ct2: &[NodeIndex],
) -> Vec<NodeIndex> {
    let generates_or_propagates = generate_init_carry_array(graph, ct1, ct2);

    let (input_carries, _output_carry) =
        compute_carry_propagation_parallelized_low_latency(graph, sks, generates_or_propagates);

    (0..ct1.len())
        .map(|i| {
            let node = graph.add_node(Node::ToCompute {
                lookup_table: Lut::ExtractMessage,
            });

            graph.add_edge(ct1[i], node, 1);
            graph.add_edge(ct2[i], node, 1);
            if i > 0 {
                graph.add_edge(input_carries[i - 1], node, 1);
            }
            node
        })
        .collect()
}

fn compute_carry_propagation_parallelized_low_latency(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    generates_or_propagates: Vec<NodeIndex>,
) -> (Vec<NodeIndex>, NodeIndex) {
    let modulus = sks.message_modulus.0 as u64;

    // Type annotations are required, otherwise we get confusing errors
    // "implementation of `FnOnce` is not general enough"
    let sum_function =
        |graph: &mut Graph<Node, u64>, block_carry: NodeIndex, previous_block_carry: NodeIndex| {
            let node = graph.add_node(Node::ToCompute {
                lookup_table: Lut::PrefixSumCarryPropagation,
            });

            graph.add_edge(block_carry, node, modulus);

            graph.add_edge(previous_block_carry, node, 1);

            node
        };

    let mut carries_out =
        compute_prefix_sum_hillis_steele(graph, sks, generates_or_propagates, sum_function);

    let last_block_out_carry = carries_out.pop().unwrap();
    (carries_out, last_block_out_carry)
}

fn generate_init_carry_array(
    graph: &mut Graph<Node, u64>,
    ct1: &[NodeIndex],
    ct2: &[NodeIndex],
) -> Vec<NodeIndex> {
    let generates_or_propagates: Vec<_> = ct1
        .iter()
        .zip_eq(ct2.iter())
        .enumerate()
        .map(|(i, (block1, block2))| {
            let lookup_table = if i == 0 {
                // The first block can only output a carry
                Lut::DoesBlockGenerateCarry
            } else {
                Lut::DoesBlockGenerateOrPropagate
            };

            let node = graph.add_node(Node::ToCompute { lookup_table });

            graph.add_edge(*block1, node, 1);
            graph.add_edge(*block2, node, 1);

            node
        })
        .collect();

    generates_or_propagates
}

pub(crate) fn compute_prefix_sum_hillis_steele<F>(
    graph: &mut Graph<Node, u64>,
    sks: &ServerKey,
    mut generates_or_propagates: Vec<NodeIndex>,
    sum_function: F,
) -> Vec<NodeIndex>
where
    F: for<'a> Fn(&'a mut Graph<Node, u64>, NodeIndex, NodeIndex) -> NodeIndex + Sync,
{
    debug_assert!(sks.message_modulus.0 * sks.carry_modulus.0 >= (1 << 4));

    let num_blocks = generates_or_propagates.len();
    let num_steps = generates_or_propagates.len().ceil_ilog2() as usize;

    let mut space = 1;
    let mut step_output = generates_or_propagates.clone();
    for _ in 0..num_steps {
        for (i, block) in step_output[space..num_blocks].iter_mut().enumerate() {
            let prev_block_carry = generates_or_propagates[i];
            *block = sum_function(graph, *block, prev_block_carry);
        }
        for i in space..num_blocks {
            generates_or_propagates[i].clone_from(&step_output[i]);
        }

        space *= 2;
    }

    generates_or_propagates
}

#[repr(u64)]
#[derive(PartialEq, Eq)]
pub enum OutputCarry {
    /// The block does not generate nor propagate a carry
    None = 0,
    /// The block generates a carry
    Generated = 1,
    /// The block will propagate a carry if it ever
    /// receives one
    Propagated = 2,
}

pub fn prefix_sum_carry_propagation(msb: u64, lsb: u64) -> u64 {
    if msb == OutputCarry::Propagated as u64 {
        lsb
    } else {
        msb
    }
}
```
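On clear values, the carry-propagation step built above out of `PrefixSumCarryPropagation` LUT nodes reduces to a Hillis–Steele inclusive scan with `prefix_sum_carry_propagation` as the combine operator. The following plaintext sketch (added for illustration, not part of the commit) mirrors `compute_prefix_sum_hillis_steele` on `u64` carry states instead of graph nodes:

```rust
// Plaintext illustration of the carry-propagation prefix sum (not in the diff).
// state[i] is OutputCarry::{None, Generated, Propagated} encoded as 0/1/2.
fn clear_prefix_sum_carry_propagation(mut state: Vec<u64>) -> Vec<u64> {
    // Combine: a "Propagated" block takes the carry decision of the block below it.
    let combine = |msb: u64, lsb: u64| if msb == 2 { lsb } else { msb };

    let n = state.len();
    let mut space = 1;
    while space < n {
        let snapshot = state.clone();
        for i in space..n {
            // Same update as sum_function: fold in the value `space` positions earlier.
            state[i] = combine(snapshot[i], snapshot[i - space]);
        }
        space *= 2;
    }
    // After the scan, state[i] says whether a carry enters block i + 1;
    // the last entry is the overall output carry.
    state
}
```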
mpi_test/src/examples/async_small_graph.rs (new file, 57 lines)

```rust
use crate::async_pbs_graph::Node;
use crate::context::Context;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

impl Context {
    pub fn async_small_graph(&self) {
        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            let mut graph = Graph::new();

            let encrypted = cks.unchecked_encrypt(1);

            let f = |x| (x + 1) % 16;
            let g = |x| (x + 2) % 16;

            let node1 = graph.add_node(Node::Computed(encrypted));
            let node2 = graph.add_node(Node::ToCompute {
                lookup_table: sks.generate_lookup_table(f),
            });

            let node3 = graph.add_node(Node::ToCompute {
                lookup_table: sks.generate_lookup_table(g),
            });

            graph.add_edge(node1, node2, 1);

            graph.add_edge(node2, node3, 1);
            graph.add_edge(node1, node3, 2);

            let sks = Arc::new(sks);

            let (graph, duration) = self.async_pbs_graph_queue_master1(sks, graph);

            let _duration_sec = duration.as_secs_f32();

            // println!("{N} PBS in {}s", duration_sec);
            // println!("{} ms/PBS", duration_sec * 1000. / N as f32);

            let node = graph.node_weight(node3).unwrap();

            let ct = match node {
                Node::Computed(ct) => ct,
                _ => unreachable!(),
            };

            assert_eq!(cks.decrypt_message_and_carry(ct), g(2 + f(1)));

            println!("All good 5");
        } else {
            self.async_pbs_graph_queue_worker1();
        }
    }
}
```
92
mpi_test/src/examples/async_small_mul.rs
Normal file
92
mpi_test/src/examples/async_small_mul.rs
Normal file
@@ -0,0 +1,92 @@
use crate::async_pbs_graph::Node;
use crate::context::Context;
use core::panic;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

impl Context {
    pub fn async_small_mul(&self) {
        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            let sks = Arc::new(sks);

            let bivar_mul_lut = sks.generate_lookup_table_bivariate(|a, b| (a * b) % 4).acc;

            let mut graph = Graph::new();
            let mut expected_outputs = vec![];

            for j in 0..16 {
                for i in 0..16 {
                    // Split each 4-bit operand into two 2-bit blocks (low = x % 4, high = x / 4).
                    let in1_low = graph.add_node(Node::Computed(cks.unchecked_encrypt(j % 4)));
                    let in1_high = graph.add_node(Node::Computed(cks.unchecked_encrypt(j / 4)));

                    let in2_low = graph.add_node(Node::Computed(cks.unchecked_encrypt(i % 4)));
                    let in2_high = graph.add_node(Node::Computed(cks.unchecked_encrypt(i / 4)));

                    let out_low = graph.add_node(Node::ToCompute {
                        lookup_table: bivar_mul_lut.clone(),
                    });

                    graph.add_edge(in1_low, out_low, 1);
                    graph.add_edge(in2_low, out_low, 4);

                    let out_high_0 = graph.add_node(Node::ToCompute {
                        // lookup_table: sks.generate_lookup_table(|a| (((a / 4) * (a % 4)) / 4) % 4),
                        lookup_table: sks.generate_lookup_table_bivariate(|a, b| ((a * b) / 4) % 4).acc,
                    });

                    graph.add_edge(in1_low, out_high_0, 1);
                    graph.add_edge(in2_low, out_high_0, 4);

                    let out_high_1 = graph.add_node(Node::ToCompute {
                        lookup_table: bivar_mul_lut.clone(),
                    });

                    graph.add_edge(in1_low, out_high_1, 1);
                    graph.add_edge(in2_high, out_high_1, 4);

                    let out_high_2 = graph.add_node(Node::ToCompute {
                        lookup_table: bivar_mul_lut.clone(),
                    });

                    graph.add_edge(in1_high, out_high_2, 1);
                    graph.add_edge(in2_low, out_high_2, 4);

                    let out_high = graph.add_node(Node::ToCompute {
                        lookup_table: sks.generate_lookup_table(|a| a % 4),
                    });

                    graph.add_edge(out_high_1, out_high, 1);
                    graph.add_edge(out_high_2, out_high, 1);
                    graph.add_edge(out_high_0, out_high, 1);

                    expected_outputs.push((out_low, (i * j) % 4));
                    expected_outputs.push((out_high, ((i * j) / 4) % 4));
                }
            }

            let (graph, duration) = self.async_pbs_graph_queue_master1(sks.clone(), graph);

            let _duration_sec = duration.as_secs_f32();

            for (node_index, expected_decryption) in expected_outputs {
                let node = graph.node_weight(node_index).unwrap();

                let ct = match node {
                    Node::Computed(ct) => ct,
                    _ => unreachable!(),
                };

                assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
            }
            println!("All good 6");
            panic!();
        } else {
            self.async_pbs_graph_queue_worker1();
        }
    }
}
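As a clear-value cross-check of the block decomposition wired up above (a sketch, not from the patch): the low output block is the low product mod 4, and the high block folds in the low product's carry plus the two cross terms, all mod 4.

fn main() {
    for a in 0u64..16 {
        for b in 0u64..16 {
            let (a_lo, a_hi) = (a % 4, a / 4);
            let (b_lo, b_hi) = (b % 4, b / 4);

            let out_low = (a_lo * b_lo) % 4;
            // carry of the low product + the two cross terms, kept mod 4
            let out_high = ((a_lo * b_lo) / 4 + a_lo * b_hi + a_hi * b_lo) % 4;

            assert_eq!(out_low, (a * b) % 4);
            assert_eq!(out_high, ((a * b) / 4) % 4);
        }
    }
}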
78
mpi_test/src/examples/local.rs
Normal file
@@ -0,0 +1,78 @@
use crate::context::Context;
use std::time::Instant;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn local() {
    const N: u64 = 1;

    let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

    let mut inputs = vec![];

    for i in 0..N {
        let ct = cks.unchecked_encrypt(i % 16);

        inputs.push(ct);
    }

    let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);

    let start = Instant::now();

    let _outputs: Vec<_> = inputs
        .iter()
        // .par_iter()
        .map(|ct| sks.apply_lookup_table(ct, &lookup_table))
        .collect();

    let duration = start.elapsed();

    let duration_sec = duration.as_secs_f32();

    println!("{N} PBS in {}s", duration_sec);
    println!("{} ms/PBS", duration_sec * 1000. / N as f32);
}

fn local_mul(num_blocks: usize) {
    use tfhe::integer::gen_keys_radix;

    // Generate the client key and the server key:
    let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);

    let clear_1: u64 = 255;
    let clear_2: u64 = 143;

    // Encrypt two messages
    let ctxt_1 = cks.encrypt(clear_1);
    let ctxt_2 = cks.encrypt(clear_2);

    let start = Instant::now();

    // Compute homomorphically a multiplication
    let _ct_res = sks.unchecked_mul_parallelized(&ctxt_1, &ctxt_2);

    let duration = start.elapsed();

    let duration_sec = duration.as_secs_f32();

    // Decrypt
    // let res: u64 = cks.decrypt(&ct_res);
    // assert_eq!((clear_1 * clear_2) % 256, res);

    println!("{num_blocks} block mul in {}s", duration_sec);
}

impl Context {
    pub fn run_local_on_root(&self) {
        if self.is_root {
            local();
        }
    }

    pub fn run_local_mul_on_root(&self, num_blocks: usize) {
        if self.is_root {
            local_mul(num_blocks);
        }
    }
}
9
mpi_test/src/examples/mod.rs
Normal file
@@ -0,0 +1,9 @@
pub mod async_batch;
// pub mod async_flat_graph;
// pub mod async_list;
pub mod async_mul;
// pub mod async_small_graph;
// pub mod async_small_mul;
pub mod local;
pub mod sync_pbs_batch;
pub mod test_request;
107
mpi_test/src/examples/sync_pbs_batch.rs
Normal file
@@ -0,0 +1,107 @@
use crate::context::Context;
use crate::N;
use mpi::traits::*;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};

impl Context {
    pub fn sync_pbs_batch(&self) {
        let root_process = self.world.process_at_rank(self.root_rank);

        let mut cks_opt = None;

        let mut sks_serialized = vec![];
        let mut sks_serialized_len = 0;

        if self.is_root {
            let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

            cks_opt = Some(cks);

            sks_serialized = bincode::serialize(&sks).unwrap();
            sks_serialized_len = sks_serialized.len();
        }

        root_process.broadcast_into(&mut sks_serialized_len);

        if sks_serialized.is_empty() {
            sks_serialized = vec![0; sks_serialized_len];
        }

        root_process.broadcast_into(&mut sks_serialized);

        let sks: ServerKey = bincode::deserialize(&sks_serialized).unwrap();

        let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);

        if self.is_root {
            let cks = cks_opt.as_ref().unwrap();

            let mut inputs = vec![];

            for i in 0..N {
                let ct = cks.unchecked_encrypt(i % 16);

                inputs.push(ct);
            }

            let start = Instant::now();
            // Each rank handles N / size ciphertexts; rank 0 keeps chunk 0. Any remainder
            // (when N is not a multiple of the world size) is simply not processed.
            let elements_per_node = N as usize / self.size;

            for dest_rank in 1..self.size {
                let process = self.world.process_at_rank(dest_rank as i32);

                let inputs_chunk =
                    &inputs[elements_per_node * dest_rank..elements_per_node * (dest_rank + 1)];

                let inputs_chunk_serialized = bincode::serialize(inputs_chunk).unwrap();

                process.send(&inputs_chunk_serialized);
            }

            let mut outputs: Vec<_> = inputs[0..elements_per_node]
                .iter()
                .map(|ct| sks.apply_lookup_table(ct, &lookup_table))
                .collect();

            for dest_rank in 1..self.size {
                let process = self.world.process_at_rank(dest_rank as i32);

                let (outputs_chunks_serialized, _status) = process.receive_vec();

                let outputs_chunk: Vec<Ciphertext> =
                    bincode::deserialize(&outputs_chunks_serialized).unwrap();

                outputs.extend(outputs_chunk);
            }

            let duration = start.elapsed();

            let duration_sec = duration.as_secs_f32();

            println!("{N} PBS in {}s", duration_sec);
            println!("{} ms/PBS", duration_sec * 1000. / N as f32);

            for (i, ct) in outputs.iter().enumerate() {
                assert_eq!(cks.decrypt_message_and_carry(ct), (i as u64 + 1) % 16);
            }

            println!("All good 1");
        } else {
            let (inputs_chunks_serialized, _status) = root_process.receive_vec();

            let inputs_chunk: Vec<Ciphertext> =
                bincode::deserialize(&inputs_chunks_serialized).unwrap();

            let outputs_chunk: Vec<_> = inputs_chunk
                .iter()
                .map(|ct| sks.apply_lookup_table(ct, &lookup_table))
                .collect();

            let outputs_chunk_serialized = bincode::serialize(&outputs_chunk).unwrap();

            root_process.send(&outputs_chunk_serialized);
        }
    }
}
33
mpi_test/src/examples/test_request.rs
Normal file
@@ -0,0 +1,33 @@
use crate::context::Context;
use crate::managers::{advance_receiving, Receiving, Sending};
use mpi::traits::*;

impl Context {
    pub fn test_request(&self) {
        let tag = 1;

        if self.is_root {
            let process = self.world.process_at_rank(1);

            for i in 0..3 {
                // Bind the buffer so it stays alive until the send request completes.
                let Sending { buffer: _buffer, a } = Sending::new(vec![1, 2, i], &process, tag);
                a.unwrap().wait();
            }
        } else {
            let process = self.world.process_at_rank(0);

            let mut receive = Some(Receiving::new(&process, tag));

            for _ in 0..3 {
                let buffer = loop {
                    if let Some(buffer) = advance_receiving(&mut receive) {
                        break buffer;
                    }
                };

                // Re-arm for the next message: advance_receiving leaves `receive` empty after a
                // successful receive, and the final abort() below expects a pending request.
                receive = Some(Receiving::new(&process, tag));

                dbg!(buffer);
            }
            receive.unwrap().abort();
        }
    }
}
30
mpi_test/src/main.rs
Normal file
@@ -0,0 +1,30 @@
use context::Context;

const N: u64 = 25;
fn main() {
    let context = Context::new();

    // simple_logger::init().unwrap();

    // context.run_local_on_root();

    // context.sync_pbs_batch();

    // context.async_pbs_batch();

    // context.test_request();

    // context.async_pbs_list_queue();

    // context.async_small_mul();

    context.run_local_mul_on_root(32);

    context.async_mul(32);
}

pub mod async_pbs_graph;
pub mod async_task_graph;
pub mod context;
pub mod examples;
pub mod managers;
78
mpi_test/src/managers.rs
Normal file
@@ -0,0 +1,78 @@
use mpi::request::Request;
use mpi::topology::Process;
use mpi::traits::*;
use mpi::Tag;
use serde::{Deserialize, Serialize};
use std::mem::transmute;
use tfhe::shortint::Ciphertext;

const MAX_SIZE: usize = 100_000;

pub struct Receiving {
    pub buffer: Vec<u8>,
    // The request actually borrows `buffer`; its lifetime is erased to 'static in `new` via
    // `transmute`, so the heap allocation must stay alive while the request is pending.
    pub future: Request<'static, [u8]>,
}

// impl Drop for Receiving {
//     fn drop(&mut self) {
//         panic!("Here")
//     }
// }

impl Receiving {
    pub fn new(process: &Process, tag: Tag) -> Self {
        let mut buffer = vec![0; MAX_SIZE];

        let future = process
            .immediate_receive_into_with_tag(unsafe { transmute(buffer.as_mut_slice()) }, tag);

        Self { buffer, future }
    }

    pub fn abort(self) {
        // self.future.cancel();
        // Dropping an incomplete rsmpi Request panics, so the pending receive is leaked instead.
        std::mem::forget(self.future);
    }
}

pub fn advance_receiving(receiving: &mut Option<Receiving>) -> Option<Vec<u8>> {
    let receiver = receiving.take().unwrap();

    match receiver.future.test() {
        Ok(_status) => Some(receiver.buffer),
        Err(future) => {
            *receiving = Some(Receiving {
                buffer: receiver.buffer,
                future,
            });

            None
        }
    }
}

pub struct Sending {
    pub buffer: Vec<u8>,
    pub a: Option<Request<'static, [u8]>>,
}

impl Sending {
    pub fn new(buffer: Vec<u8>, process: &Process, tag: Tag) -> Self {
        assert!(buffer.len() < MAX_SIZE);

        let a = Some(process.immediate_send_with_tag(unsafe { transmute(buffer.as_slice()) }, tag));

        Self { buffer, a }
    }

    pub fn abort(self) {
        // self.a.unwrap().cancel();
        std::mem::forget(self.a);
    }
}

#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCt {
    pub index: usize,
    pub ct: Ciphertext,
}
@@ -563,6 +563,17 @@ pub fn lwe_ciphertext_cleartext_mul_assign<Scalar, InCont>(
    slice_wrapping_scalar_mul_assign(lhs.as_mut(), rhs.0);
}

pub fn lwe_ciphertext_add_cleartext_mul_assign<Scalar, InCont>(
    lhs: &mut LweCiphertext<InCont>,
    rhs: &LweCiphertext<InCont>,
    scalar: Cleartext<Scalar>,
) where
    Scalar: UnsignedInteger,
    InCont: ContainerMut<Element = Scalar>,
{
    slice_wrapping_add_scalar_mul_assign(lhs.as_mut(), rhs.as_ref(), scalar.0);
}
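The new primitive fuses a cleartext multiply into an accumulation: every coefficient of lhs becomes lhs[i] + scalar * rhs[i] with wrapping arithmetic. A plain-slice model of that contract (a sketch with a hypothetical helper name, not the library code):

// Plain-slice model of lhs += scalar * rhs with wrapping arithmetic.
fn add_scalar_mul_assign_model(lhs: &mut [u64], rhs: &[u64], scalar: u64) {
    assert_eq!(lhs.len(), rhs.len());
    for (l, r) in lhs.iter_mut().zip(rhs.iter()) {
        *l = l.wrapping_add(r.wrapping_mul(scalar));
    }
}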

/// Multiply the left-hand side [`LWE ciphertext`](`LweCiphertext`) by the right-hand side cleartext
/// writing the result in the output [`LWE ciphertext`](`LweCiphertext`).
///
@@ -363,7 +363,7 @@ impl ServerKey {
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub struct LookupTable<C: Container<Element = u64>> {
    pub acc: GlweCiphertext<C>,
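The only change in this hunk is deriving Serialize and Deserialize for LookupTable, presumably so lookup tables can be shipped to worker ranks with bincode like the server keys and ciphertexts above. A minimal round-trip sketch (hypothetical; the LookupTableOwned path and the surrounding `sks` server key are assumptions, not taken from the patch):

let lut = sks.generate_lookup_table(|x| (x + 1) % 16);
let bytes = bincode::serialize(&lut).unwrap();
let restored: tfhe::shortint::server_key::LookupTableOwned = bincode::deserialize(&bytes).unwrap();
assert_eq!(lut, restored); // PartialEq is already derived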
@@ -207,6 +207,15 @@ impl ServerKey {
        unchecked_scalar_mul_assign(ct, scalar);
    }

    pub fn unchecked_add_scalar_mul_assign(
        &self,
        ct: &mut Ciphertext,
        scaled_ct: &Ciphertext,
        scalar: u8,
    ) {
        unchecked_add_scalar_mul_assign(ct, scaled_ct, scalar);
    }

    /// Multiply one ciphertext with a scalar in the case the carry space cannot fit the product
    /// applying the message space modulus in the process.
    ///
@@ -535,3 +544,17 @@ pub(crate) fn unchecked_scalar_mul_assign(ct: &mut Ciphertext, scalar: u8) {
        }
    }
}

pub(crate) fn unchecked_add_scalar_mul_assign(
    ct: &mut Ciphertext,
    scaled_ct: &Ciphertext,
    scalar: u8,
) {
    ct.set_noise_level(ct.noise_level() + scaled_ct.noise_level() * scalar as usize);
    ct.degree = Degree::new(ct.degree.get() + scaled_ct.degree.get() * scalar as usize);

    let scalar = u64::from(scalar);
    let cleartext_scalar = Cleartext(scalar);

    lwe_ciphertext_add_cleartext_mul_assign(&mut ct.ct, &scaled_ct.ct, cleartext_scalar);
}
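A usage sketch of the new shortint entry point (hypothetical values; it assumes the 2-bit message / 2-bit carry parameters used throughout this branch, so the resulting degree 1 + 2 * 3 = 7 still fits in the message-and-carry space):

use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn main() {
    let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
    let mut ct_acc = cks.unchecked_encrypt(1);
    let ct_x = cks.unchecked_encrypt(3);

    // ct_acc := ct_acc + 2 * ct_x, a purely linear (leveled) operation, no PBS involved.
    sks.unchecked_add_scalar_mul_assign(&mut ct_acc, &ct_x, 2);

    assert_eq!(cks.decrypt_message_and_carry(&ct_acc), 7);
}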