Compare commits

...

51 Commits

Author SHA1 Message Date
Mayeul@Zama
f4cf43da36 f 2024-02-26 14:32:44 +01:00
Mayeul@Zama
cdce78d3b8 f 2024-02-26 14:06:29 +01:00
Mayeul@Zama
46309a3da1 f 2024-02-26 11:19:10 +01:00
Mayeul@Zama
6de6ac0fc3 f 2024-02-23 19:00:26 +01:00
Mayeul@Zama
746cca0135 optimize sum, not working 2024-02-23 18:58:05 +01:00
Mayeul@Zama
425436dee2 add priority for results queue 2024-02-23 16:42:42 +01:00
Mayeul@Zama
28416693ca keep elements for last sum to limit dependencies 2024-02-23 15:53:07 +01:00
Mayeul@Zama
04d9314c07 remove old test 2024-02-23 15:53:07 +01:00
Mayeul@Zama
9f9162c42f f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
99bbd0ed7d f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
07b2796904 multisum on other threads 2024-02-23 15:53:07 +01:00
Mayeul@Zama
b0c44aba2f f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
28798813c0 receive and send in separate threads on worker machines 2024-02-23 15:53:07 +01:00
Mayeul@Zama
add6bf8e5d remove useless function 2024-02-23 15:53:07 +01:00
Mayeul@Zama
367783c7dd serialization on different threads 2024-02-23 15:53:07 +01:00
Mayeul@Zama
3c5732ae3a fix name 2024-02-23 15:53:07 +01:00
Mayeul@Zama
1ceceb2e6f f fix multiple receivers 2024-02-23 15:53:07 +01:00
Mayeul@Zama
abc05e141a multiple receivers 2024-02-23 15:53:07 +01:00
Mayeul@Zama
0c95bb9024 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
fdf2451a5e use max size 2024-02-23 15:53:07 +01:00
Mayeul@Zama
3c240a8709 add logging 2024-02-23 15:53:07 +01:00
Mayeul@Zama
76ce4bd477 remove dep 2024-02-23 15:53:07 +01:00
Mayeul@Zama
6f4cbfb108 luts are indexed 2024-02-23 15:53:07 +01:00
Mayeul@Zama
160cd437a8 f rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
b5cb6b3d74 f rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
87f2bacaa0 f rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
9e60a39b44 f rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
d8e06acbf6 f rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
9af66d9d60 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
1c3cb60b56 to remove 2024-02-23 15:53:07 +01:00
Mayeul@Zama
8fea689097 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
54f454b9ea f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
d3cb8aa111 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
6654bdbe0c use fast carry propagation 2024-02-23 15:53:07 +01:00
Mayeul@Zama
ae3b33f644 add fast carry propagation 2024-02-23 15:53:07 +01:00
Mayeul@Zama
1c7b1e7fd9 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
c8414772c0 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
1ea776976e add priority 2024-02-23 15:53:07 +01:00
Mayeul@Zama
1c310f28db f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
04cfb1e009 rename 2024-02-23 15:53:07 +01:00
Mayeul@Zama
fe6015f4b1 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
bd172c342d f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
0bcab98438 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
0bdd133be7 f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
ec05c66ea2 make multisum faster 2024-02-23 15:53:07 +01:00
Mayeul@Zama
9340135e31 add add_scalar_mul op 2024-02-23 15:53:07 +01:00
Mayeul@Zama
83e461873f factorize 2024-02-23 15:53:07 +01:00
Mayeul@Zama
a2ad55fedd f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
105466a14c f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
c4e9bd836a f 2024-02-23 15:53:07 +01:00
Mayeul@Zama
d676639200 add mpi test 2024-02-23 15:53:07 +01:00
21 changed files with 2305 additions and 2 deletions

View File

@@ -1,6 +1,13 @@
[workspace]
resolver = "2"
members = ["tfhe", "tasks", "apps/trivium", "concrete-csprng", "backends/tfhe-cuda-backend"]
members = [
"tfhe",
"tasks",
"apps/trivium",
"concrete-csprng",
"backends/tfhe-cuda-backend",
"mpi_test",
]
[profile.bench]
lto = "fat"

33
mpi_test/Cargo.toml Normal file
View File

@@ -0,0 +1,33 @@
[package]
name = "mpi_test"
version = "0.4.0"
edition = "2021"
license = "BSD-3-Clause-Clear"
description = "Experiments distributing TFHE-rs PBS evaluation over MPI."
homepage = "https://zama.ai/"
documentation = "https://docs.zama.ai/tfhe-rs"
repository = "https://github.com/zama-ai/tfhe-rs"
readme = "README.md"
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"]
rust-version = "1.72"
[dependencies]
mpi = { path = "../../rsmpi" }
tfhe = { path = "../tfhe", features = [
"shortint",
"integer",
"x86_64",
"internal-keycache",
] }
bincode = "1.3"
threadpool = "1.8"
crossbeam-channel = "0.5"
serde = { version = "1.0", features = ["derive", "rc"] }
petgraph = "0.6.4"
itertools = "*"
thread-priority = "0.15.1"
async-priority-channel = "0.2.0"
futures = "0.3.30"
logging_timer = "1.1.0"
simple_logger = "4.3.3"
smallvec = { version = "1.13.1", features = ["serde"] }

11
mpi_test/run.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
source /etc/profile.d/modules.sh
module load mpi/openmpi-x86_64
# export LD_LIBRARY_PATH="/usr/lib64/mpich/lib/"
export LD_LIBRARY_PATH="/usr/lib64/openmpi/lib/"
export RUSTFLAGS="-Ctarget-cpu=native"
cargo build --profile=release
mpirun -n 6 ../target/release/mpi_test

View File

@@ -0,0 +1,445 @@
use crate::async_task_graph::{Priority, TaskGraph};
use crate::context::Context;
use crate::examples::async_mul::{prefix_sum_carry_propagation, OutputCarry};
use logging_timer::time;
use mpi::traits::*;
use petgraph::algo::is_cyclic_directed;
use petgraph::stable_graph::NodeIndex;
use petgraph::visit::EdgeRef;
use petgraph::Direction::{Incoming, Outgoing};
use petgraph::Graph;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tfhe::shortint::server_key::LookupTableOwned;
use tfhe::shortint::{Ciphertext, ServerKey};
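/// A ciphertext tagged with the index of the graph node it belongs to, so results
/// can be matched back to the task graph after crossing process boundaries.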
#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCt {
index: usize,
ct: Ciphertext,
}
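/// Identifier of a server-side lookup table. Only this small enum travels over MPI;
/// every process rebuilds the actual tables locally from the broadcast ServerKey (see `Luts::new`).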
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum Lut {
ExtractMessage,
ExtractCarry,
BivarMulLow,
BivarMulHigh,
PrefixSumCarryPropagation,
DoesBlockGenerateCarry,
DoesBlockGenerateOrPropagate,
}
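/// A ready-to-run task: the weighted input ciphertexts of a node and the lookup table
/// to apply to their sum.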
#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCtsAndLut {
index: usize,
cts_and_weights: SmallVec<[(u64, Arc<Ciphertext>); 5]>,
lut: Lut,
}
impl IndexedCtsAndLut {
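/// Fold the weighted inputs into a single ciphertext with unchecked scalar
/// multiply-accumulate, leaving only the lookup table to apply.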
fn multisum(self, sks: &ServerKey) -> IndexedCtAndLut {
let IndexedCtsAndLut {
index,
cts_and_weights,
lut,
} = self;
let mut cts_and_weights = cts_and_weights.into_iter();
let (first_scalar, first_ct) = cts_and_weights.next().unwrap();
let mut multisum_result = sks.unchecked_scalar_mul(&first_ct, first_scalar as u8);
for (scalar, ct) in cts_and_weights {
sks.unchecked_add_scalar_mul_assign(&mut multisum_result, &ct, scalar as u8);
}
IndexedCtAndLut {
index,
ct: multisum_result,
lut,
}
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCtAndLut {
index: usize,
ct: Ciphertext,
lut: Lut,
}
impl IndexedCtAndLut {
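/// Apply the node's lookup table (a programmable bootstrap) to the multisum result.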
fn pbs(self, sks: &ServerKey, luts: &Luts) -> IndexedCt {
IndexedCt {
ct: sks.apply_lookup_table(&self.ct, luts.get(self.lut)),
index: self.index,
}
}
}
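/// State of a graph node: already computed, currently queued for bootstrapping,
/// or still waiting for its inputs.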
#[derive(Clone)]
pub enum Node {
Computed(Arc<Ciphertext>),
BootsrapQueued,
ToCompute { lookup_table: Lut },
}
impl Node {
fn ct(&self) -> Option<&Arc<Ciphertext>> {
match self {
Node::Computed(ct) => Some(ct),
_ => None,
}
}
}
impl std::fmt::Debug for Node {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Computed(_) => f.debug_tuple("Computed").finish(),
Self::BootsrapQueued => write!(f, "BootsrapQueued"),
Self::ToCompute { .. } => f.debug_struct("ToCompute").finish(),
}
}
}
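/// The FHE computation as a dependency graph: node weights carry the scheduling
/// priority and state, edge weights the scalar multipliers applied in the
/// successor's multisum.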
pub struct FheGraph {
graph: Graph<(Priority, Node), u64>,
not_computed_nodes_count: usize,
}
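// Records, for every node reachable backwards from a sink, the length of its longest
// path to a sink; this depth becomes the node's scheduling priority so long
// dependency chains are started first.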
fn insert_predecessors_recursively(
graph: &Graph<Node, u64>,
successors_max_depths: &mut HashMap<usize, i32>,
node_index: NodeIndex,
) {
if successors_max_depths.contains_key(&node_index.index()) {
return;
}
if graph
.neighbors_directed(node_index, Outgoing)
.all(|successor| successors_max_depths.contains_key(&successor.index()))
{
let max_successors_depth = graph
.neighbors_directed(node_index, Outgoing)
.map(|successor| successors_max_depths[&successor.index()])
.max();
successors_max_depths.insert(node_index.index(), max_successors_depth.unwrap_or(0) + 1);
for predecessor in graph.neighbors_directed(node_index, Incoming) {
insert_predecessors_recursively(graph, successors_max_depths, predecessor);
}
}
}
impl FheGraph {
pub fn new(graph: Graph<Node, u64>) -> Self {
let not_computed_nodes_count = graph
.node_weights()
.filter(|node| !matches!(&node, Node::Computed(_)))
.count();
let mut successors_max_depth = HashMap::new();
for node_index in graph.node_indices() {
if graph.edges_directed(node_index, Outgoing).next().is_none() {
insert_predecessors_recursively(&graph, &mut successors_max_depth, node_index);
}
}
dbg!(&successors_max_depth.values().max());
let graph = graph.map(
|node_index, node| {
(
Priority(successors_max_depth[&node_index.index()]),
node.clone(),
)
},
|_, edge| *edge,
);
Self {
graph,
not_computed_nodes_count,
}
}
fn test_graph_init(&self) {
assert!(!is_cyclic_directed(&self.graph));
for i in self.graph.node_indices() {
if self.graph.neighbors_directed(i, Incoming).next().is_none() {
assert!(matches!(
&self.graph.node_weight(i),
Some((_, Node::Computed(_)))
))
} else {
assert!(matches!(
&self.graph.node_weight(i),
Some((_, Node::ToCompute { .. }))
))
}
}
}
fn assert_finishable(&self) {
assert!(!is_cyclic_directed(&self.graph));
for i in self.graph.node_indices() {
if self.graph.neighbors_directed(i, Incoming).next().is_none() {
assert!(matches!(
&self.graph.node_weight(i),
Some((_, Node::Computed(_)))
))
}
}
}
#[time]
fn predecessors_list(&self, index: NodeIndex) -> SmallVec<[(u64, Arc<Ciphertext>); 5]> {
self.graph
.edges_directed(index, Incoming)
.map(|edge| {
(
*edge.weight(),
self.graph[edge.source()].1.ct().unwrap().clone(),
)
})
.collect()
}
#[time]
fn build_task(&mut self, index: NodeIndex) -> (Priority, IndexedCtsAndLut) {
let cts_and_weights = self.predecessors_list(index);
let lut = match self.graph.node_weight(index) {
Some((_, Node::ToCompute { lookup_table })) => lookup_table.to_owned(),
_ => unreachable!(),
};
self.graph.node_weight_mut(index).unwrap().1 = Node::BootsrapQueued;
(
Priority(0),
IndexedCtsAndLut {
index: index.index(),
cts_and_weights,
lut,
},
)
}
}
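/// The scheduler view of the FHE graph: `init` emits every node whose predecessors
/// are all encrypted constants, and each committed bootstrap result unlocks the
/// successors whose inputs are now complete.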
impl TaskGraph for FheGraph {
type Task = IndexedCtsAndLut;
type Result = IndexedCt;
#[time]
fn init(&mut self) -> Vec<(Priority, IndexedCtsAndLut)> {
self.test_graph_init();
let nodes_to_compute: Vec<_> = self
.graph
.node_indices()
.filter(|&i| {
let to_compute =
matches!(self.graph.node_weight(i).unwrap().1, Node::ToCompute { .. });
let all_predecessors_computed =
self.graph
.neighbors_directed(i, Incoming)
.all(|predecessor| {
matches!(
self.graph.node_weight(predecessor).unwrap().1,
Node::Computed(_)
)
});
to_compute && all_predecessors_computed
})
.collect();
nodes_to_compute
.into_iter()
.map(|index| self.build_task(index))
.collect()
}
#[time]
fn commit_result(&mut self, result: IndexedCt) -> Vec<(Priority, IndexedCtsAndLut)> {
self.not_computed_nodes_count -= 1;
// dbg!(self.not_computed_nodes_count);
let IndexedCt { index, ct } = result;
let index = NodeIndex::new(index);
let node_mut = self.graph.node_weight_mut(index).unwrap();
assert!(matches!(node_mut.1, Node::BootsrapQueued));
node_mut.1 = Node::Computed(Arc::new(ct));
let nodes_to_compute: Vec<_> = self
.graph
.neighbors_directed(index, Outgoing)
.filter(|&i| {
assert!(matches!(
self.graph.node_weight(i).unwrap().1,
Node::ToCompute { .. }
));
let all_predecessors_computed =
self.graph
.neighbors_directed(i, Incoming)
.all(|predecessor| {
matches!(
self.graph.node_weight(predecessor).unwrap().1,
Node::Computed(_)
)
});
all_predecessors_computed
})
.collect();
nodes_to_compute
.into_iter()
.map(|index| self.build_task(index))
.collect()
}
fn is_finished(&self) -> bool {
self.not_computed_nodes_count == 0
}
}
impl Context {
pub fn async_pbs_graph_queue_master1(
&self,
sks: Arc<ServerKey>,
graph: Graph<Node, u64>,
) -> (Graph<Node, u64>, Duration) {
let luts = Luts::new(&sks);
let root_process = self.world.process_at_rank(self.root_rank);
let mut sks_serialized = bincode::serialize(sks.as_ref()).unwrap();
let mut sks_serialized_len = sks_serialized.len();
let mut graph = FheGraph::new(graph);
graph.assert_finishable();
root_process.broadcast_into(&mut sks_serialized_len);
root_process.broadcast_into(sks_serialized.as_mut_slice());
let start = Instant::now();
self.async_task_graph_queue_master::<_, _, IndexedCtsAndLut, IndexedCtAndLut, IndexedCt>(
&mut graph,
(sks, luts),
move |(sks, luts), input| input.multisum(sks).pbs(sks, luts),
move |(sks, _), task| task.multisum(sks),
);
let duration = start.elapsed();
(
graph.graph.map(|_, node| node.1.clone(), |_, edge| *edge),
duration,
)
}
pub fn async_pbs_graph_queue_worker1(&self) {
let root_process = self.world.process_at_rank(self.root_rank);
let mut sks_serialized_len = 0;
root_process.broadcast_into(&mut sks_serialized_len);
let mut sks_serialized = vec![0; sks_serialized_len];
root_process.broadcast_into(&mut sks_serialized);
let sks: Arc<ServerKey> = Arc::new(bincode::deserialize(&sks_serialized).unwrap());
let luts = Luts::new(&sks);
self.async_task_graph_queue_worker::<_, IndexedCtAndLut, IndexedCt>(
(sks, luts),
|(sks, luts), input| input.pbs(sks, luts),
);
panic!()
}
}
#[derive(Clone, Serialize, Deserialize)]
struct Luts {
extract_message: LookupTableOwned,
extract_carry: LookupTableOwned,
bivar_mul_low: LookupTableOwned,
bivar_mul_high: LookupTableOwned,
prefix_sum_carry_propagation: LookupTableOwned,
does_block_generate_carry: LookupTableOwned,
does_block_generate_or_propagate: LookupTableOwned,
}
impl Luts {
fn new(sks: &ServerKey) -> Self {
let message_modulus = sks.message_modulus.0 as u64;
Self {
extract_message: sks.generate_lookup_table(|x| x % message_modulus),
extract_carry: sks.generate_lookup_table(|x| x / message_modulus),
bivar_mul_low: sks
.generate_lookup_table_bivariate(|x, y| (x * y) % message_modulus)
.acc,
bivar_mul_high: sks
.generate_lookup_table_bivariate(|x, y| (x * y) / message_modulus)
.acc,
prefix_sum_carry_propagation: sks
.generate_lookup_table_bivariate(prefix_sum_carry_propagation)
.acc,
does_block_generate_carry: sks.generate_lookup_table(|x| {
if x >= message_modulus {
OutputCarry::Generated as u64
} else {
OutputCarry::None as u64
}
}),
does_block_generate_or_propagate: sks.generate_lookup_table(|x| {
if x >= message_modulus {
OutputCarry::Generated as u64
} else if x == (message_modulus - 1) {
OutputCarry::Propagated as u64
} else {
OutputCarry::None as u64
}
}),
}
}
fn get(&self, lut: Lut) -> &LookupTableOwned {
match lut {
Lut::ExtractMessage => &self.extract_message,
Lut::ExtractCarry => &self.extract_carry,
Lut::BivarMulLow => &self.bivar_mul_low,
Lut::BivarMulHigh => &self.bivar_mul_high,
Lut::PrefixSumCarryPropagation => &self.prefix_sum_carry_propagation,
Lut::DoesBlockGenerateCarry => &self.does_block_generate_carry,
Lut::DoesBlockGenerateOrPropagate => &self.does_block_generate_or_propagate,
}
}
}

View File

@@ -0,0 +1,373 @@
use crate::context::Context;
use crate::managers::{Receiving, Sending};
use async_priority_channel::{unbounded, Receiver, Sender};
use futures::executor::block_on;
use mpi::topology::Process;
use mpi::traits::*;
use mpi::Tag;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::mem::transmute;
use thread_priority::{set_current_thread_priority, ThreadPriority, ThreadPriorityValue};
use threadpool::ThreadPool;
const MASTER_TO_WORKER: Tag = 0;
const WORKER_TO_MASTER: Tag = 1;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Priority(pub i32);
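/// A dynamically unfolding task graph: `init` yields the initially ready tasks and
/// every `commit_result` may unlock new ones, until `is_finished` reports that
/// everything has been computed.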
pub trait TaskGraph {
type Task;
type Result;
fn init(&mut self) -> Vec<(Priority, Self::Task)>;
fn commit_result(&mut self, result: Self::Result) -> Vec<(Priority, Self::Task)>;
// fn no_work_in_queue(&self) -> bool;
fn is_finished(&self) -> bool;
}
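/// Number of tasks currently in flight on each MPI rank; new tasks go to the root
/// while it has spare local parallelism, otherwise to the least loaded rank.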
struct ClusterCharge {
available_parallelism: usize,
charge: Vec<usize>,
}
impl Context {
pub fn async_task_graph_queue_master<
T: Sync + Clone + Send + 'static,
U: TaskGraph<Task = Task, Result = Result>,
Task: Send + 'static,
RemoteTask: Serialize + DeserializeOwned + Send + 'static,
Result: Serialize + DeserializeOwned + Send + 'static,
>(
&self,
task_graph: &mut U,
state: T,
f: impl Fn(&T, Task) -> Result + Sync + Clone + Send + 'static,
convert: impl Fn(&T, Task) -> RemoteTask + Sync + Clone + Send + 'static,
) {
let (send_task, receive_task) = unbounded::<Task, Priority>();
let (send_result, receive_result) = unbounded::<(Result, usize), Priority>();
// let mut sent_inputs = vec![];
let mut charge = ClusterCharge {
available_parallelism: std::thread::available_parallelism().unwrap().get(),
charge: vec![0; self.size],
};
{
let state = state.clone();
let n_workers = (std::thread::available_parallelism().unwrap().get() - 1).max(1);
let priority =
ThreadPriority::Crossplatform(ThreadPriorityValue::try_from(32).unwrap());
launch_threadpool(
priority,
n_workers,
&receive_task,
&send_result,
move |receive_task, send_result, state| {
let f = f.clone();
let (input, priority) = block_on(receive_task.recv()).unwrap();
let result = f(state, input);
block_on(send_result.send((result, 0), priority)).unwrap();
},
state,
);
}
let worker_senders: Vec<_> = (1..self.size)
.map(|rank| {
let (send_task, receive_task) = unbounded::<Task, Priority>();
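// The Process returned by process_at_rank borrows self.world; its lifetime is
// transmuted to 'static so it can move into the spawned sender thread (assuming
// the MPI world outlives these threads).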
let process_at_rank: Process<'static> =
unsafe { transmute(self.world.process_at_rank(rank as i32)) };
let convert = convert.clone();
let state = state.clone();
std::thread::spawn(move || {
// set_current_thread_priority(priority).unwrap();
let mut sent_inputs = vec![];
while let Ok((task, priority)) = block_on(receive_task.recv()) {
let remote_task = convert(&state, task);
let buffer = bincode::serialize(&(remote_task, priority)).unwrap();
sent_inputs.push(Sending::new(buffer, &process_at_rank, MASTER_TO_WORKER))
}
for a in sent_inputs {
a.abort()
}
});
send_task
})
.collect();
for rank in 1..self.size {
let send_result = send_result.clone();
let process_at_rank: Process<'static> =
unsafe { transmute(self.world.process_at_rank(rank as i32)) };
std::thread::spawn(move || {
// set_current_thread_priority(priority).unwrap();
let mut receives = VecDeque::new();
for _ in 0..100 {
receives.push_back(Some(Receiving::new(&process_at_rank, WORKER_TO_MASTER)))
}
loop {
let Receiving { buffer, future } = receives.pop_front().unwrap().unwrap();
receives.push_back(Some(Receiving::new(&process_at_rank, WORKER_TO_MASTER)));
future.wait();
let (result, priority) = bincode::deserialize(&buffer).unwrap();
block_on(send_result.send((result, rank), priority)).unwrap();
}
});
}
for (priority, task) in task_graph.init() {
self.enqueue_request(&mut charge, &send_task, priority, task, &worker_senders);
}
let mut empty = 0;
let mut not_empty = 0;
while !task_graph.is_finished() {
if receive_result.is_empty() {
empty += 1;
} else {
not_empty += 1;
}
let ((result, rank), _priority) = block_on(receive_result.recv()).unwrap();
charge.charge[rank] -= 1;
self.handle_new_result(task_graph, result, &mut charge, &send_task, &worker_senders);
}
dbg!(empty, not_empty);
for i in charge.charge {
assert_eq!(i, 0);
}
std::mem::forget(send_task);
}
fn handle_new_result<U: TaskGraph>(
&self,
task_graph: &mut U,
result: U::Result,
charge: &mut ClusterCharge,
send_task: &Sender<U::Task, Priority>,
sent_inputs: &[Sender<U::Task, Priority>],
) {
let new_tasks = task_graph.commit_result(result);
for (priority, task) in new_tasks {
self.enqueue_request(charge, send_task, priority, task, sent_inputs);
}
}
fn enqueue_request<Task>(
&self,
charge: &mut ClusterCharge,
send_task: &Sender<Task, Priority>,
priority: Priority,
task: Task,
sent_inputs: &[Sender<Task, Priority>],
) {
let rank = if charge.charge[self.root_rank as usize] < charge.available_parallelism {
self.root_rank as usize
} else {
charge
.charge
.iter()
.enumerate()
.min_by_key(|(_index, charge)| *charge)
.unwrap()
.0
};
charge.charge[rank] += 1;
if rank == self.root_rank as usize {
block_on(send_task.send(task, priority)).unwrap();
} else {
block_on(sent_inputs[rank - 1].send(task, priority)).unwrap();
}
}
pub fn async_task_graph_queue_worker<
T: Sync + Clone + Send + 'static,
RemoteTask: Serialize + DeserializeOwned + Send,
Result: Serialize + DeserializeOwned + Send,
>(
&self,
state: T,
f: impl Fn(&T, RemoteTask) -> Result + Sync + Clone + Send + 'static,
) {
let f = move |state: &T, serialized_input: &Vec<u8>| {
let (input, priority): (RemoteTask, Priority) =
bincode::deserialize(serialized_input).unwrap();
let result = f(state, input);
bincode::serialize(&(result, priority)).unwrap()
};
let (send_task, receive_task) = crossbeam_channel::unbounded::<Vec<u8>>();
let (send_result, receive_result) = crossbeam_channel::unbounded::<Vec<u8>>();
{
let state = state.clone();
let f = f.clone();
let n_workers = (std::thread::available_parallelism().unwrap().get() - 1).max(1);
let priority =
ThreadPriority::Crossplatform(ThreadPriorityValue::try_from(32).unwrap());
launch_threadpool2(
priority,
n_workers,
&receive_task,
&send_result,
move |receive_task, send_result, state| {
let f = f.clone();
let input = receive_task.recv().unwrap();
let result = f(state, &input);
send_result.send(result).unwrap();
},
state,
);
}
let root_process = self.world.process_at_rank(self.root_rank);
{
let root_process: Process<'static> =
unsafe { transmute(self.world.process_at_rank(self.root_rank)) };
std::thread::spawn(move || {
let mut receives = VecDeque::new();
for _ in 0..100 {
receives.push_back(Some(Receiving::new(&root_process, MASTER_TO_WORKER)))
}
loop {
let Receiving { buffer, future } = receives.pop_front().unwrap().unwrap();
receives.push_back(Some(Receiving::new(&root_process, MASTER_TO_WORKER)));
future.wait();
send_task.send(buffer).unwrap();
}
});
}
let mut send: VecDeque<Sending> = VecDeque::new();
'outer: loop {
if let Ok(output) = receive_result.recv() {
send.push_back(Sending::new(output, &root_process, WORKER_TO_MASTER));
}
while let Some(front) = send.front_mut() {
if let Some(a) = front.a.take() {
match a.test() {
Ok(_) => {
let b = send.pop_front();
assert!(b.unwrap().a.is_none());
}
Err(front_a) => {
front.a = Some(front_a);
continue 'outer;
}
}
}
}
}
}
}
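// Spawns n_workers threads at the given OS priority; each loops forever, running
// `function` over the shared task/result channels.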
fn launch_threadpool<
T: Clone + Send + 'static,
U: Send + 'static,
V: Send + 'static,
W: Send + Ord + 'static,
// X: Send + Ord + 'static,
>(
priority: ThreadPriority,
n_workers: usize,
receive_task: &Receiver<U, W>,
send_result: &Sender<V, Priority>,
function: impl Fn(&Receiver<U, W>, &Sender<V, Priority>, &T) + Send + Clone + 'static,
state: T,
) {
let pool = ThreadPool::new(n_workers);
for _ in 0..n_workers {
let receive_task = receive_task.clone();
let send_result = send_result.clone();
let function = function.clone();
let state = state.clone();
pool.execute(move || {
set_current_thread_priority(priority).unwrap();
loop {
function(&receive_task, &send_result, &state);
}
});
}
}
fn launch_threadpool2<T: Clone + Send + 'static, U: Send + 'static, V: Send + 'static>(
priority: ThreadPriority,
n_workers: usize,
receive_task: &crossbeam_channel::Receiver<U>,
send_result: &crossbeam_channel::Sender<V>,
function: impl Fn(&crossbeam_channel::Receiver<U>, &crossbeam_channel::Sender<V>, &T)
+ Send
+ Clone
+ 'static,
state: T,
) {
let pool = ThreadPool::new(n_workers);
for _ in 0..n_workers {
let receive_task = receive_task.clone();
let send_result = send_result.clone();
let function = function.clone();
let state = state.clone();
pool.execute(move || {
set_current_thread_priority(priority).unwrap();
loop {
function(&receive_task, &send_result, &state);
}
});
}
}

36
mpi_test/src/context.rs Normal file
View File

@@ -0,0 +1,36 @@
use mpi::environment::Universe;
use mpi::topology::SimpleCommunicator;
use mpi::traits::*;
use mpi::Threading;
pub struct Context {
pub universe: Universe,
pub world: SimpleCommunicator,
pub size: usize,
pub rank: i32,
pub root_rank: i32,
pub is_root: bool,
}
#[allow(clippy::new_without_default)]
impl Context {
pub fn new() -> Self {
let (universe, _) = mpi::initialize_with_threading(Threading::Multiple).unwrap();
let world = universe.world();
let size = world.size() as usize;
let rank = world.rank();
let root_rank = 0;
let is_root = rank == root_rank;
Context {
universe,
world,
size,
rank,
root_rank,
is_root,
}
}
}

View File

@@ -0,0 +1,201 @@
use crate::context::Context;
use crate::N;
use mpi::traits::*;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};
impl Context {
pub fn async_pbs_batch(&self) {
let root_process = self.world.process_at_rank(self.root_rank);
let mut cks_opt = None;
let mut sks_serialized = vec![];
let mut sks_serialized_len = 0;
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
cks_opt = Some(cks);
sks_serialized = bincode::serialize(&sks).unwrap();
sks_serialized_len = sks_serialized.len();
}
root_process.broadcast_into(&mut sks_serialized_len);
if sks_serialized.is_empty() {
sks_serialized = vec![0; sks_serialized_len];
}
root_process.broadcast_into(&mut sks_serialized);
let sks: ServerKey = bincode::deserialize(&sks_serialized).unwrap();
let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);
if self.is_root {
let cks = cks_opt.as_ref().unwrap();
let mut inputs = vec![];
for i in 0..N {
let ct = cks.unchecked_encrypt(i % 16);
inputs.push(ct);
}
let start = Instant::now();
let elements_per_node = N as usize / self.size;
let serialized: Vec<_> = (1..self.size)
.map(|dest_rank| {
let inputs_chunk =
&inputs[elements_per_node * dest_rank..elements_per_node * (dest_rank + 1)];
bincode::serialize(inputs_chunk).unwrap()
})
.collect();
let lens: Vec<_> = serialized.iter().map(|a| a.len()).collect();
let sent_len: Vec<_> = lens
.iter()
.enumerate()
.map(|(i, a)| {
let dest_rank = i as i32 + 1;
let process = self.world.process_at_rank(dest_rank);
process.immediate_send(a)
})
.collect();
let sent_vec: Vec<_> = serialized
.iter()
.enumerate()
.map(|(i, a)| {
let dest_rank = i as i32 + 1;
let process = self.world.process_at_rank(dest_rank);
process.immediate_send(a)
})
.collect();
for i in sent_len {
i.wait();
}
for i in sent_vec {
i.wait();
}
let mut outputs: Vec<_> = inputs[0..elements_per_node]
.iter()
.map(|ct| sks.apply_lookup_table(ct, &lookup_table))
.collect();
let lens: Vec<_> = (1..self.size)
.map(|dest_rank| {
let process = self.world.process_at_rank(dest_rank as i32);
process.immediate_receive()
})
.collect();
let mut results: Vec<Vec<u8>> =
lens.into_iter().map(|len| vec![0; len.get().0]).collect();
let sent: Vec<_> = results
.iter_mut()
.enumerate()
.map(|(i, a)| {
let dest_rank = i as i32 + 1;
let process = self.world.process_at_rank(dest_rank);
process.immediate_receive_into(a)
})
.collect();
for i in sent {
i.wait();
}
for result in results.iter() {
let outputs_chunk: Vec<Ciphertext> = bincode::deserialize(result).unwrap();
outputs.extend(outputs_chunk);
}
let duration = start.elapsed();
let duration_sec = duration.as_secs_f32();
println!("{N} PBS in {}s", duration_sec);
println!("{} ms/PBS", duration_sec * 1000. / N as f32);
for (i, ct) in outputs.iter().enumerate() {
assert_eq!(cks.decrypt_message_and_carry(ct), (i as u64 + 1) % 16);
}
println!("All good 2");
} else {
let (len, _) = root_process.receive();
let mut input = vec![0; len];
// let mut status;
root_process.receive_into(input.as_mut_slice());
let input: Vec<Ciphertext> = bincode::deserialize(&input).unwrap();
let output: Vec<_> = input
.iter()
.map(|ct| sks.apply_lookup_table(ct, &lookup_table))
.collect();
let output = bincode::serialize(&output).unwrap();
root_process.send(&output.len());
root_process.send(&output);
}
}
pub fn test_mpi_immediate(&self) {
let root_process = self.world.process_at_rank(self.root_rank);
if self.is_root {
let process = self.world.process_at_rank(1);
let input = vec![1, 2, 3];
let len = [input.len()];
let a = process.immediate_send(&len);
let b = process.immediate_send(input.as_slice());
// drop(b);
let b2 = process.immediate_send(input.as_slice());
a.wait();
b.wait();
b2.wait();
// let (outputs_chunks_serialized, _status) = process.receive_vec();
} else if self.rank == 1 {
let (len, _) = root_process.receive();
let mut input = vec![0; len];
// let mut status;
let future = root_process.immediate_receive_into(input.as_mut_slice());
future.wait();
dbg!(input);
}
}
}

View File

@@ -0,0 +1,65 @@
use crate::async_pbs_graph::Node;
use crate::context::Context;
use crate::N;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
impl Context {
pub fn async_flat_graph(&self) {
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let mut graph = Graph::new();
let mut expected_outputs = vec![];
for i in 0..N {
let plain = i % 16;
let encrypted = cks.unchecked_encrypt(plain);
let f = |x| x + 2;
let lookup_table = sks.generate_lookup_table(f);
// dbg!(cks.decrypt_message_and_carry(&sks.apply_lookup_table(&encrypted,
// &lookup_table)));
let input = graph.add_node(Node::Computed(encrypted));
let output = graph.add_node(Node::ToCompute {
lookup_table: lookup_table.clone(),
});
graph.add_edge(input, output, 1);
expected_outputs.push((output, f(plain)));
}
let sks = Arc::new(sks);
let (graph, duration) = self.async_pbs_graph_queue_master1(sks, graph);
let duration_sec = duration.as_secs_f32();
println!("{N} PBS in {}s", duration_sec);
println!("{} ms/PBS", duration_sec * 1000. / N as f32);
for (node_index, expected_decryption) in expected_outputs {
let node = graph.node_weight(node_index).unwrap();
let ct = match node {
Node::Computed(ct) => ct,
_ => unreachable!(),
};
assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
}
println!("All good 4");
} else {
self.async_pbs_graph_queue_worker1();
}
}
}

View File

@@ -0,0 +1,120 @@
use crate::async_task_graph::{Priority, TaskGraph};
use crate::context::Context;
use crate::managers::IndexedCt;
use crate::N;
use mpi::traits::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::server_key::LookupTableOwned;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};
struct ListOfPbs {
pub inputs: Vec<Ciphertext>,
pub outputs: HashMap<usize, Ciphertext>,
}
impl ListOfPbs {
fn new(inputs: Vec<Ciphertext>) -> Self {
Self {
inputs,
outputs: HashMap::new(),
}
}
}
impl TaskGraph for ListOfPbs {
type Task = IndexedCt;
type Result = IndexedCt;
fn init(&mut self) -> Vec<(Priority, IndexedCt)> {
self.inputs
.clone()
.into_iter()
.enumerate()
.map(|(i, ct)| (Priority(0), IndexedCt { index: i, ct }))
.collect()
}
fn commit_result(&mut self, result: IndexedCt) -> Vec<(Priority, IndexedCt)> {
self.outputs.insert(result.index, result.ct);
vec![]
}
fn is_finished(&self) -> bool {
self.outputs.len() == self.inputs.len()
}
}
impl Context {
pub fn async_pbs_list_queue(&self) {
if self.is_root {
let root_process = self.world.process_at_rank(self.root_rank);
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let mut sks_serialized = bincode::serialize(&sks).unwrap();
let mut sks_serialized_len = sks_serialized.len();
let sks = Arc::new(sks);
root_process.broadcast_into(&mut sks_serialized_len);
root_process.broadcast_into(&mut sks_serialized);
let lookup_table = Arc::new(sks.generate_lookup_table(|x| (x + 1) % 16));
let inputs: Vec<_> = (0..N).map(|i| cks.unchecked_encrypt(i % 16)).collect();
let start = Instant::now();
let mut a = ListOfPbs::new(inputs);
self.async_task_graph_queue_master(
&mut a,
(sks, lookup_table),
|(sks, lookup_table), input| run_pbs(input, sks, lookup_table),
);
let duration = start.elapsed();
let duration_sec = duration.as_secs_f32();
println!("{N} PBS in {}s", duration_sec);
println!("{} ms/PBS", duration_sec * 1000. / N as f32);
for (i, ct) in a.outputs.iter() {
assert_eq!(cks.decrypt_message_and_carry(ct), (*i as u64 + 1) % 16);
}
println!("All good 3");
} else {
let root_process = self.world.process_at_rank(self.root_rank);
let mut sks_serialized_len = 0;
root_process.broadcast_into(&mut sks_serialized_len);
let mut sks_serialized = vec![0; sks_serialized_len];
root_process.broadcast_into(&mut sks_serialized);
let sks: Arc<ServerKey> = Arc::new(bincode::deserialize(&sks_serialized).unwrap());
let lookup_table = Arc::new(sks.generate_lookup_table(|x| (x + 1) % 16));
self.async_task_graph_queue_worker(
(sks, lookup_table),
|(sks, lookup_table), input| run_pbs(input, sks, lookup_table),
);
}
}
}
fn run_pbs(input: &IndexedCt, sks: &ServerKey, lookup_table: &LookupTableOwned) -> IndexedCt {
IndexedCt {
ct: sks.apply_lookup_table(&input.ct, lookup_table),
index: input.index,
}
}

View File

@@ -0,0 +1,494 @@
use crate::async_pbs_graph::{Lut, Node};
use crate::context::Context;
use core::panic;
use itertools::{zip_eq, Itertools};
use petgraph::prelude::NodeIndex;
use petgraph::Graph;
use std::collections::BinaryHeap;
use std::sync::Arc;
use tfhe::core_crypto::commons::traits::UnsignedInteger;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, ServerKey};
impl Context {
pub fn async_mul(&self, num_blocks: i32) {
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let sks = Arc::new(sks);
let mut graph = Graph::new();
let mut expected_outputs = vec![];
let cut_into_block = |number| {
let mut number = number;
(0..num_blocks)
.map(|_| {
let new = number % 4;
number /= 4;
new
})
.collect::<Vec<_>>()
};
let cut_into_nodes = |graph: &mut Graph<Node, u64>, number: u64| {
cut_into_block(number)
.into_iter()
.map(|block| {
graph.add_node(Node::Computed(Arc::new(cks.unchecked_encrypt(block))))
})
.collect::<Vec<_>>()
};
let i = 24533;
let j = 53864;
let in1 = cut_into_nodes(&mut graph, i);
let in2 = cut_into_nodes(&mut graph, j);
let result = mul_graph(&mut graph, &sks, &in1, &in2);
for (i, j) in zip_eq(&result, cut_into_block(i.wrapping_mul(j))) {
expected_outputs.push((*i, j));
}
// println!("{:?}", Dot::with_config(&graph, &[Config::NodeNoLabel]));
let (graph, duration) = self.async_pbs_graph_queue_master1(sks.clone(), graph);
let duration_sec = duration.as_secs_f32();
for (node_index, expected_decryption) in expected_outputs {
let node = graph.node_weight(node_index).unwrap();
let ct = match node {
Node::Computed(ct) => ct,
_ => unreachable!(),
};
// dbg!(cks.decrypt_message_and_carry(ct), expected_decryption);
assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
}
println!("All good 7");
println!("MPI {num_blocks} block mul in {}s", duration_sec);
panic!();
} else {
self.async_pbs_graph_queue_worker1();
}
}
}
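// Builds the dependency graph of a radix multiplication: per-column partial products
// (low and high halves), a carry-bounded tree of column sums, and a final
// Hillis-Steele carry propagation.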
pub fn mul_graph(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
lhs: &[NodeIndex],
rhs: &[NodeIndex],
) -> Vec<NodeIndex> {
let len = lhs.len();
assert_eq!(len, rhs.len());
let mut terms_for_mul_low: Vec<BinaryHeap<NodeWithDepth>> =
compute_terms_for_mul_low(graph, sks, lhs, rhs)
.into_iter()
.map(|a| {
a.into_iter()
.map(|node| NodeWithDepth { node, depth: 0 })
.collect()
})
.collect();
terms_for_mul_low.reverse();
assert_eq!(len, terms_for_mul_low.len());
let mut sum_messages = vec![];
let mut sum_carries = vec![];
let first_list = terms_for_mul_low.pop().unwrap();
assert_eq!(first_list.len(), 1);
let (first_message, first_carry) = sum_blocks(graph, sks, first_list, None);
assert!(first_carry.is_none());
for _ in 1..(len - 1) {
let messages = terms_for_mul_low.pop().unwrap();
let carries = terms_for_mul_low.last_mut();
let (message, carry) = sum_blocks(graph, sks, messages, carries);
sum_messages.push(message);
sum_carries.push(carry.unwrap());
}
let (last_message, last_carry) = sum_blocks(graph, sks, terms_for_mul_low.pop().unwrap(), None);
assert!(terms_for_mul_low.is_empty());
sum_messages.push(last_message);
assert!(last_carry.is_none());
let mut result = vec![];
result.push(first_message);
result.push(sum_messages.remove(0));
assert_eq!(sum_messages.len(), sum_carries.len());
result.extend(&add_propagate_carry(
graph,
sks,
&sum_messages,
&sum_carries,
));
result
}
fn compute_terms_for_mul_low(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
lhs: &[NodeIndex],
rhs: &[NodeIndex],
) -> Vec<Vec<NodeIndex>> {
let message_modulus = sks.message_modulus.0 as u64;
assert!(message_modulus <= sks.carry_modulus.0 as u64);
assert_eq!(lhs.len(), rhs.len());
let len = rhs.len();
let mut message_part_terms_generator = vec![vec![]; len];
for (i, rhs_block) in rhs.iter().enumerate() {
for (j, lhs_block) in lhs.iter().enumerate() {
if (i + j) < len {
let node = graph.add_node(Node::ToCompute {
lookup_table: Lut::BivarMulLow,
});
graph.add_edge(*lhs_block, node, 1);
graph.add_edge(*rhs_block, node, message_modulus);
message_part_terms_generator[i + j].push(node);
} else {
break;
}
}
}
if message_modulus > 2 {
for (i, rhs_block) in rhs.iter().enumerate() {
for (j, lhs_block) in lhs.iter().enumerate() {
if (i + j + 1) < len {
let node = graph.add_node(Node::ToCompute {
lookup_table: Lut::BivarMulHigh,
});
graph.add_edge(*lhs_block, node, 1);
graph.add_edge(*rhs_block, node, message_modulus);
message_part_terms_generator[i + j + 1].push(node);
} else {
break;
}
}
}
}
message_part_terms_generator
}
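// A node index tagged with the depth of the PBS chain that produced it; Ord is
// reversed so a BinaryHeap pops the shallowest node first.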
struct NodeWithDepth {
node: NodeIndex,
depth: u32,
}
impl PartialEq for NodeWithDepth {
fn eq(&self, other: &Self) -> bool {
self.depth == other.depth
}
}
impl Eq for NodeWithDepth {}
impl PartialOrd for NodeWithDepth {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for NodeWithDepth {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
other.depth.cmp(&self.depth)
}
}
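// Sums one column of partial products. At most `group_size` blocks can be added
// before the carry would exceed message_modulus, so larger columns are reduced in
// rounds, always folding the shallowest terms first.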
fn sum_blocks(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
mut messages: BinaryHeap<NodeWithDepth>,
mut carries: Option<&mut BinaryHeap<NodeWithDepth>>,
) -> (NodeIndex, Option<NodeIndex>) {
assert!(!messages.is_empty());
let message_modulus = sks.message_modulus.0 as u64;
// We don't want a carry bigger than message_modulus
let group_size = ((message_modulus * message_modulus - 1) / (message_modulus - 1)) as usize;
if messages.len() == 1 {
return (messages.pop().unwrap().node, None);
}
let mut sum_n_most_shallow_terms = |messages: &mut BinaryHeap<NodeWithDepth>, to_add_now| {
let mut adding: Vec<NodeIndex> = vec![];
let mut max_depth = 0;
for _ in 0..to_add_now {
let NodeWithDepth { node, depth } = messages.pop().unwrap();
if depth > max_depth {
max_depth = depth;
}
adding.push(node);
}
if let Some(carries) = &mut carries {
let (sum, carry) = checked_add(graph, sks, &adding, true);
messages.push(NodeWithDepth {
node: sum,
depth: max_depth + 1,
});
carries.push(NodeWithDepth {
node: carry.unwrap(),
depth: max_depth + 1,
});
} else {
let (sum, carry) = checked_add(graph, sks, &adding, false);
messages.push(NodeWithDepth {
node: sum,
depth: max_depth + 1,
});
assert!(carry.is_none());
}
};
if messages.len() > group_size {
let n = (messages.len() - group_size) / (group_size - 1);
let y = messages.len() - group_size - n * (group_size - 1);
assert_eq!(messages.len(), group_size + (group_size - 1) * n + y);
if y != 0 {
sum_n_most_shallow_terms(&mut messages, y + 1);
}
assert_eq!(messages.len(), group_size + (group_size - 1) * n);
for _ in 0..n {
sum_n_most_shallow_terms(&mut messages, group_size);
}
assert!(messages.len() == group_size);
}
let mut adding = vec![];
while let Some(NodeWithDepth { node, .. }) = messages.pop() {
adding.push(node);
}
if carries.is_some() {
checked_add(graph, sks, &adding, true)
} else {
checked_add(graph, sks, &adding, false)
}
}
fn checked_add(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
blocks_ref: &[NodeIndex],
build_carry: bool,
) -> (NodeIndex, Option<NodeIndex>) {
assert!(blocks_ref.len() > 1);
let message_modulus = sks.message_modulus.0 as u64;
// We don't want a carry bigger than message_modulus
let group_size = (message_modulus * message_modulus - 1) / (message_modulus - 1);
assert!(blocks_ref.len() <= group_size as usize);
let sum = graph.add_node(Node::ToCompute {
lookup_table: Lut::ExtractMessage,
});
for i in blocks_ref {
graph.add_edge(*i, sum, 1);
}
let carry = if build_carry {
let new_carry = graph.add_node(Node::ToCompute {
lookup_table: Lut::ExtractCarry,
});
for i in blocks_ref {
graph.add_edge(*i, new_carry, 1);
}
Some(new_carry)
} else {
None
};
(sum, carry)
}
fn add_propagate_carry(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
ct1: &[NodeIndex],
ct2: &[NodeIndex],
) -> Vec<NodeIndex> {
let generates_or_propagates = generate_init_carry_array(graph, ct1, ct2);
let (input_carries, _output_carry) =
compute_carry_propagation_parallelized_low_latency(graph, sks, generates_or_propagates);
(0..ct1.len())
.map(|i| {
let node = graph.add_node(Node::ToCompute {
lookup_table: Lut::ExtractMessage,
});
graph.add_edge(ct1[i], node, 1);
graph.add_edge(ct2[i], node, 1);
if i > 0 {
graph.add_edge(input_carries[i - 1], node, 1);
}
node
})
.collect()
}
fn compute_carry_propagation_parallelized_low_latency(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
generates_or_propagates: Vec<NodeIndex>,
) -> (Vec<NodeIndex>, NodeIndex) {
let modulus = sks.message_modulus.0 as u64;
// Type annotations are required, otherwise we get confusing errors
// "implementation of `FnOnce` is not general enough"
let sum_function =
|graph: &mut Graph<Node, u64>, block_carry: NodeIndex, previous_block_carry: NodeIndex| {
let node = graph.add_node(Node::ToCompute {
lookup_table: Lut::PrefixSumCarryPropagation,
});
graph.add_edge(block_carry, node, modulus);
graph.add_edge(previous_block_carry, node, 1);
node
};
let mut carries_out =
compute_prefix_sum_hillis_steele(graph, sks, generates_or_propagates, sum_function);
let last_block_out_carry = carries_out.pop().unwrap();
(carries_out, last_block_out_carry)
}
fn generate_init_carry_array(
graph: &mut Graph<Node, u64>,
ct1: &[NodeIndex],
ct2: &[NodeIndex],
) -> Vec<NodeIndex> {
let generates_or_propagates: Vec<_> = ct1
.iter()
.zip_eq(ct2.iter())
.enumerate()
.map(|(i, (block1, block2))| {
let lookup_table = if i == 0 {
// The first block can only output a carry
Lut::DoesBlockGenerateCarry
} else {
Lut::DoesBlockGenerateOrPropagate
};
let node = graph.add_node(Node::ToCompute { lookup_table });
graph.add_edge(*block1, node, 1);
graph.add_edge(*block2, node, 1);
node
})
.collect();
generates_or_propagates
}
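// Hillis-Steele parallel prefix scan over the generate/propagate blocks:
// ceil(log2(n)) rounds, each round combining a block with the block `space`
// positions before it.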
pub(crate) fn compute_prefix_sum_hillis_steele<F>(
graph: &mut Graph<Node, u64>,
sks: &ServerKey,
mut generates_or_propagates: Vec<NodeIndex>,
sum_function: F,
) -> Vec<NodeIndex>
where
F: for<'a> Fn(&'a mut Graph<Node, u64>, NodeIndex, NodeIndex) -> NodeIndex + Sync,
{
debug_assert!(sks.message_modulus.0 * sks.carry_modulus.0 >= (1 << 4));
let num_blocks = generates_or_propagates.len();
let num_steps = generates_or_propagates.len().ceil_ilog2() as usize;
let mut space = 1;
let mut step_output = generates_or_propagates.clone();
for _ in 0..num_steps {
for (i, block) in step_output[space..num_blocks].iter_mut().enumerate() {
let prev_block_carry = generates_or_propagates[i];
*block = sum_function(graph, *block, prev_block_carry);
}
for i in space..num_blocks {
generates_or_propagates[i].clone_from(&step_output[i]);
}
space *= 2;
}
generates_or_propagates
}
#[repr(u64)]
#[derive(PartialEq, Eq)]
pub enum OutputCarry {
/// The block does not generate nor propagate a carry
None = 0,
/// The block generates a carry
Generated = 1,
/// The block will propagate a carry if it ever
/// receives one
Propagated = 2,
}
pub fn prefix_sum_carry_propagation(msb: u64, lsb: u64) -> u64 {
if msb == OutputCarry::Propagated as u64 {
lsb
} else {
msb
}
}

View File

@@ -0,0 +1,57 @@
use crate::async_pbs_graph::Node;
use crate::context::Context;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
impl Context {
pub fn async_small_graph(&self) {
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let mut graph = Graph::new();
let encrypted = cks.unchecked_encrypt(1);
let f = |x| (x + 1) % 16;
let g = |x| (x + 2) % 16;
let node1 = graph.add_node(Node::Computed(encrypted));
let node2 = graph.add_node(Node::ToCompute {
lookup_table: sks.generate_lookup_table(f),
});
let node3 = graph.add_node(Node::ToCompute {
lookup_table: sks.generate_lookup_table(g),
});
graph.add_edge(node1, node2, 1);
graph.add_edge(node2, node3, 1);
graph.add_edge(node1, node3, 2);
let sks = Arc::new(sks);
let (graph, duration) = self.async_pbs_graph_queue_master1(sks, graph);
let _duration_sec = duration.as_secs_f32();
// println!("{N} PBS in {}s", duration_sec);
// println!("{} ms/PBS", duration_sec * 1000. / N as f32);
let node = graph.node_weight(node3).unwrap();
let ct = match node {
Node::Computed(ct) => ct,
_ => unreachable!(),
};
assert_eq!(cks.decrypt_message_and_carry(ct), g(2 + f(1)));
println!("All good 5");
} else {
self.async_pbs_graph_queue_worker1();
}
}
}

View File

@@ -0,0 +1,92 @@
use crate::async_pbs_graph::Node;
use crate::context::Context;
use core::panic;
use petgraph::Graph;
use std::sync::Arc;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
impl Context {
pub fn async_small_mul(&self) {
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let sks = Arc::new(sks);
let bivar_mul_lut = sks.generate_lookup_table_bivariate(|a, b| (a * b) % 4).acc;
let mut graph = Graph::new();
let mut expected_outputs = vec![];
for j in 0..16 {
for i in 0..16 {
let in1_low = graph.add_node(Node::Computed(cks.unchecked_encrypt(j % 4)));
let in1_high = graph.add_node(Node::Computed(cks.unchecked_encrypt(j / 4)));
let in2_low = graph.add_node(Node::Computed(cks.unchecked_encrypt(i % 4)));
let in2_high = graph.add_node(Node::Computed(cks.unchecked_encrypt(i / 4)));
let out_low = graph.add_node(Node::ToCompute {
lookup_table: bivar_mul_lut.clone(),
});
graph.add_edge(in1_low, out_low, 1);
graph.add_edge(in2_low, out_low, 4);
let out_high_0 = graph.add_node(Node::ToCompute {
lookup_table: //sks.generate_lookup_table(|a| (((a / 4) * (a % 4)) / 4) % 4),
sks.generate_lookup_table_bivariate(|a, b| ((a * b) / 4)%4).acc
});
graph.add_edge(in1_low, out_high_0, 1);
graph.add_edge(in2_low, out_high_0, 4);
let out_high_1 = graph.add_node(Node::ToCompute {
lookup_table: bivar_mul_lut.clone(),
});
graph.add_edge(in1_low, out_high_1, 1);
graph.add_edge(in2_high, out_high_1, 4);
let out_high_2 = graph.add_node(Node::ToCompute {
lookup_table: bivar_mul_lut.clone(),
});
graph.add_edge(in1_high, out_high_2, 1);
graph.add_edge(in2_low, out_high_2, 4);
let out_high = graph.add_node(Node::ToCompute {
lookup_table: sks.generate_lookup_table(|a| a % 4),
});
graph.add_edge(out_high_1, out_high, 1);
graph.add_edge(out_high_2, out_high, 1);
graph.add_edge(out_high_0, out_high, 1);
expected_outputs.push((out_low, (i * j) % 4));
expected_outputs.push((out_high, ((i * j) / 4) % 4));
}
}
let (graph, duration) = self.async_pbs_graph_queue_master1(sks.clone(), graph);
let _duration_sec = duration.as_secs_f32();
for (node_index, expected_decryption) in expected_outputs {
let node = graph.node_weight(node_index).unwrap();
let ct = match node {
Node::Computed(ct) => ct,
_ => unreachable!(),
};
assert_eq!(cks.decrypt_message_and_carry(ct), expected_decryption);
}
println!("All good 6");
panic!();
} else {
self.async_pbs_graph_queue_worker1();
}
}
}

View File

@@ -0,0 +1,78 @@
use crate::context::Context;
use std::time::Instant;
use tfhe::shortint::gen_keys;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
fn local() {
const N: u64 = 1;
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
let mut inputs = vec![];
for i in 0..N {
let ct = cks.unchecked_encrypt(i % 16);
inputs.push(ct);
}
let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);
let start = Instant::now();
let _outputs: Vec<_> = inputs
.iter()
// .par_iter()
.map(|ct| sks.apply_lookup_table(ct, &lookup_table))
.collect();
let duration = start.elapsed();
let duration_sec = duration.as_secs_f32();
println!("{N} PBS in {}s", duration_sec);
println!("{} ms/PBS", duration_sec * 1000. / N as f32);
}
fn local_mul(num_blocks: usize) {
use tfhe::integer::gen_keys_radix;
// Generate the client key and the server key:
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
let clear_1: u64 = 255;
let clear_2: u64 = 143;
// Encrypt two messages
let ctxt_1 = cks.encrypt(clear_1);
let ctxt_2 = cks.encrypt(clear_2);
let start = Instant::now();
// Compute homomorphically a multiplication
let _ct_res = sks.unchecked_mul_parallelized(&ctxt_1, &ctxt_2);
let duration = start.elapsed();
let duration_sec = duration.as_secs_f32();
// Decrypt
// let res: u64 = cks.decrypt(&ct_res);
// assert_eq!((clear_1 * clear_2) % 256, res);
println!("{num_blocks} block mul in {}s", duration_sec);
}
impl Context {
pub fn run_local_on_root(&self) {
if self.is_root {
local();
}
}
pub fn run_local_mul_on_root(&self, num_blocks: usize) {
if self.is_root {
local_mul(num_blocks);
}
}
}

View File

@@ -0,0 +1,9 @@
pub mod async_batch;
// pub mod async_flat_graph;
// pub mod async_list;
pub mod async_mul;
// pub mod async_small_graph;
// pub mod async_small_mul;
pub mod local;
pub mod sync_pbs_batch;
pub mod test_request;

View File

@@ -0,0 +1,107 @@
use crate::context::Context;
use crate::N;
use mpi::traits::*;
use std::time::Instant;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::{gen_keys, Ciphertext, ServerKey};
impl Context {
pub fn sync_pbs_batch(&self) {
let root_process = self.world.process_at_rank(self.root_rank);
let mut cks_opt = None;
let mut sks_serialized = vec![];
let mut sks_serialized_len = 0;
if self.is_root {
let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
cks_opt = Some(cks);
sks_serialized = bincode::serialize(&sks).unwrap();
sks_serialized_len = sks_serialized.len();
}
root_process.broadcast_into(&mut sks_serialized_len);
if sks_serialized.is_empty() {
sks_serialized = vec![0; sks_serialized_len];
}
root_process.broadcast_into(&mut sks_serialized);
let sks: ServerKey = bincode::deserialize(&sks_serialized).unwrap();
let lookup_table = sks.generate_lookup_table(|x| (x + 1) % 16);
if self.is_root {
let cks = cks_opt.as_ref().unwrap();
let mut inputs = vec![];
for i in 0..N {
let ct = cks.unchecked_encrypt(i % 16);
inputs.push(ct);
}
let start = Instant::now();
let elements_per_node = N as usize / self.size;
for dest_rank in 1..self.size {
let process = self.world.process_at_rank(dest_rank as i32);
let inputs_chunk =
&inputs[elements_per_node * dest_rank..elements_per_node * (dest_rank + 1)];
let inputs_chunk_serialized = bincode::serialize(inputs_chunk).unwrap();
process.send(&inputs_chunk_serialized);
}
let mut outputs: Vec<_> = inputs[0..elements_per_node]
.iter()
.map(|ct| sks.apply_lookup_table(ct, &lookup_table))
.collect();
for dest_rank in 1..self.size {
let process = self.world.process_at_rank(dest_rank as i32);
let (outputs_chunks_serialized, _status) = process.receive_vec();
let outputs_chunk: Vec<Ciphertext> =
bincode::deserialize(&outputs_chunks_serialized).unwrap();
outputs.extend(outputs_chunk);
}
let duration = start.elapsed();
let duration_sec = duration.as_secs_f32();
println!("{N} PBS in {}s", duration_sec);
println!("{} ms/PBS", duration_sec * 1000. / N as f32);
for (i, ct) in outputs.iter().enumerate() {
assert_eq!(cks.decrypt_message_and_carry(ct), (i as u64 + 1) % 16);
}
println!("All good 1");
} else {
let (inputs_chunks_serialized, _status) = root_process.receive_vec();
let inputs_chunk: Vec<Ciphertext> =
bincode::deserialize(&inputs_chunks_serialized).unwrap();
let outputs_chunk: Vec<_> = inputs_chunk
.iter()
.map(|ct| sks.apply_lookup_table(ct, &lookup_table))
.collect();
let outputs_chunk_serialized = bincode::serialize(&outputs_chunk).unwrap();
root_process.send(&outputs_chunk_serialized);
}
}
}

View File

@@ -0,0 +1,33 @@
use crate::context::Context;
use crate::managers::{advance_receiving, Receiving, Sending};
use mpi::traits::*;
impl Context {
pub fn test_request(&self) {
let tag = 1;
if self.is_root {
let process = self.world.process_at_rank(1);
for i in 0..3 {
let Sending { buffer: _, a } = Sending::new(vec![1, 2, i], &process, tag);
a.unwrap().wait();
}
} else {
let process = self.world.process_at_rank(0);
let mut receive = Some(Receiving::new(&process, tag));
for _ in 0..3 {
let buffer = loop {
if let Some(buffer) = advance_receiving(&mut receive) {
break buffer;
}
};
dbg!(buffer);
}
receive.unwrap().abort();
}
}
}

30
mpi_test/src/main.rs Normal file
View File

@@ -0,0 +1,30 @@
use context::Context;
const N: u64 = 25;
fn main() {
let context = Context::new();
// simple_logger::init().unwrap();
// context.run_local_on_root();
// context.sync_pbs_batch();
// context.async_pbs_batch();
// context.test_request();
// context.async_pbs_list_queue();
// context.async_small_mul();
context.run_local_mul_on_root(32);
context.async_mul(32);
}
pub mod async_pbs_graph;
pub mod async_task_graph;
pub mod context;
pub mod examples;
pub mod managers;

78
mpi_test/src/managers.rs Normal file
View File

@@ -0,0 +1,78 @@
use mpi::request::Request;
use mpi::topology::Process;
use mpi::traits::*;
use mpi::Tag;
use serde::{Deserialize, Serialize};
use std::mem::transmute;
use tfhe::shortint::Ciphertext;
const MAX_SIZE: usize = 100_000;
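/// An in-flight MPI receive into a fixed-size buffer; the buffer slice is transmuted
/// to 'static so the request can live alongside it, and `abort` leaks the request
/// rather than cancelling it.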
pub struct Receiving {
pub buffer: Vec<u8>,
pub future: Request<'static, [u8]>,
}
// impl Drop for Receiving {
// fn drop(&mut self) {
// panic!("Here")
// }
// }
impl Receiving {
pub fn new(process: &Process, tag: Tag) -> Self {
let mut buffer = vec![0; MAX_SIZE];
let future = process
.immediate_receive_into_with_tag(unsafe { transmute(buffer.as_mut_slice()) }, tag);
Self { buffer, future }
}
pub fn abort(self) {
// self.future.cancel();
std::mem::forget(self.future);
}
}
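// Polls a pending receive: returns the filled buffer once the request has completed,
// otherwise re-arms `receiving` and returns None.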
pub fn advance_receiving(receiving: &mut Option<Receiving>) -> Option<Vec<u8>> {
let receiver = receiving.take().unwrap();
match receiver.future.test() {
Ok(_status) => Some(receiver.buffer),
Err(future) => {
*receiving = Some(Receiving {
buffer: receiver.buffer,
future,
});
None
}
}
}
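/// An in-flight MPI send; payloads must stay under MAX_SIZE so the matching
/// `Receiving` buffer on the other side is always large enough.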
pub struct Sending {
pub buffer: Vec<u8>,
pub a: Option<Request<'static, [u8]>>,
}
impl Sending {
pub fn new(buffer: Vec<u8>, process: &Process, tag: Tag) -> Self {
assert!(buffer.len() < MAX_SIZE);
let a = Some(process.immediate_send_with_tag(unsafe { transmute(buffer.as_slice()) }, tag));
Self { buffer, a }
}
pub fn abort(self) {
// self.a.unwrap().cancel();
std::mem::forget(self.a);
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct IndexedCt {
pub index: usize,
pub ct: Ciphertext,
}

View File

@@ -563,6 +563,17 @@ pub fn lwe_ciphertext_cleartext_mul_assign<Scalar, InCont>(
slice_wrapping_scalar_mul_assign(lhs.as_mut(), rhs.0);
}
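/// Add the right-hand side [`LWE ciphertext`](`LweCiphertext`) multiplied by a cleartext
/// to the left-hand side [`LWE ciphertext`](`LweCiphertext`), updating it in place.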
pub fn lwe_ciphertext_add_cleartext_mul_assign<Scalar, InCont>(
lhs: &mut LweCiphertext<InCont>,
rhs: &LweCiphertext<InCont>,
scalar: Cleartext<Scalar>,
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
slice_wrapping_add_scalar_mul_assign(lhs.as_mut(), rhs.as_ref(), scalar.0);
}
/// Multiply the left-hand side [`LWE ciphertext`](`LweCiphertext`) by the right-hand side cleartext
/// writing the result in the output [`LWE ciphertext`](`LweCiphertext`).
///

View File

@@ -363,7 +363,7 @@ impl ServerKey {
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub struct LookupTable<C: Container<Element = u64>> {
pub acc: GlweCiphertext<C>,

View File

@@ -207,6 +207,15 @@ impl ServerKey {
unchecked_scalar_mul_assign(ct, scalar);
}
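/// Compute `ct += scaled_ct * scalar` without checking that the message and carry
/// space can hold the result; the caller is responsible for making sure it fits.
///
/// A minimal usage sketch (untested here, built only from APIs that already appear
/// in this patch):
///
/// ```ignore
/// use tfhe::shortint::gen_keys;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
/// let mut ct = cks.unchecked_encrypt(1);
/// let ct2 = cks.unchecked_encrypt(2);
/// sks.unchecked_add_scalar_mul_assign(&mut ct, &ct2, 3);
/// // ct now holds 1 + 2 * 3 = 7 in the combined message and carry space
/// assert_eq!(cks.decrypt_message_and_carry(&ct), 7);
/// ```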
pub fn unchecked_add_scalar_mul_assign(
&self,
ct: &mut Ciphertext,
scaled_ct: &Ciphertext,
scalar: u8,
) {
unchecked_add_scalar_mul_assign(ct, scaled_ct, scalar);
}
/// Multiply one ciphertext with a scalar in the case the carry space cannot fit the product
/// applying the message space modulus in the process.
///
@@ -535,3 +544,17 @@ pub(crate) fn unchecked_scalar_mul_assign(ct: &mut Ciphertext, scalar: u8) {
}
}
}
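// Accounts for the added `scaled_ct * scalar` term in the noise level and degree
// before doing the in-place multiply-accumulate on the underlying LWE ciphertexts.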
pub(crate) fn unchecked_add_scalar_mul_assign(
ct: &mut Ciphertext,
scaled_ct: &Ciphertext,
scalar: u8,
) {
ct.set_noise_level(ct.noise_level() + scaled_ct.noise_level() * scalar as usize);
ct.degree = Degree::new(ct.degree.get() + scaled_ct.degree.get() * scalar as usize);
let scalar = u64::from(scalar);
let cleartext_scalar = Cleartext(scalar);
lwe_ciphertext_add_cleartext_mul_assign(&mut ct.ct, &scaled_ct.ct, cleartext_scalar);
}