Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-07 22:04:10 -05:00
refactor(zk): run pke_v2 verification inside dedicated thread pools
Reducing the number of available threads actually improves performance.
Committed by Nicolas Sarlin
Parent: e7de363d0c
Commit: 5a62301968
@@ -8,7 +8,9 @@ use crate::serialization::{
 use core::ops::{Index, IndexMut};
 use rand::{Rng, RngCore};
 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+use rayon::{ThreadPool, ThreadPoolBuilder};
 use std::fmt::Display;
+use std::sync::{Arc, OnceLock};
 use tfhe_versionable::Versionize;
 
 #[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)]
@@ -287,6 +289,85 @@ where
 
 pub const HASH_METADATA_LEN_BYTES: usize = 256;
 
+// The verifier is meant to be executed on a large server with a high number of cores. However, some
+// arkworks operations do not scale well in that case, and we actually see decreased performance
+// for pke_v2 with a higher thread count. This is supposed to be a temporary fix.
+//
+// See this issue for more information: https://github.com/arkworks-rs/algebra/issues/976
+
+/// Number of threads used to run the verification.
+///
+/// This value has been determined empirically by running the `pke_v2_verify` benchmark on an AWS
+/// hpc7 96xlarge. Values between 30 and 50 seem to give good results, but 32 is on the lower side
+/// (more throughput), is a power of 2, and is a divisor of 192 (the number of CPU cores of the hpc7).
+const VERIF_MAX_THREADS_COUNT: usize = 32;
+
+/// Holds a ThreadPool; the number of verification tasks running on it is tracked through the
+/// Arc's strong reference count.
+type VerificationPool = Arc<OnceLock<Option<ThreadPool>>>;
+
+/// The list of pools used for verification
+static VERIF_POOLS: OnceLock<Vec<VerificationPool>> = OnceLock::new();
+
+/// Initialize the list of pools, based on the number of available CPU cores
+fn get_or_init_pools() -> &'static Vec<VerificationPool> {
+    VERIF_POOLS.get_or_init(|| {
+        let total_cores = rayon::current_num_threads();
+
+        // If the number of available cores is smaller than the pool size, default to one pool.
+        let pools_count = total_cores.div_ceil(VERIF_MAX_THREADS_COUNT).max(1);
+
+        (0..pools_count)
+            .map(|_| Arc::new(OnceLock::new()))
+            .collect()
+    })
+}
+
+/// Run the target function in a dedicated rayon thread pool with a limited number of threads.
+///
+/// When multiple calls of this function are made in parallel, each of them is executed in a
+/// dedicated pool, if there are enough free cores on the CPU.
+fn run_in_pool<OP, R>(f: OP) -> R
+where
+    OP: FnOnce() -> R + Send,
+    R: Send,
+{
+    let pools = get_or_init_pools();
+
+    // Select the least loaded pool
+    let mut min_load = usize::MAX;
+    let mut pool_index = 0;
+
+    for (i, pool) in pools.iter().enumerate() {
+        let load = Arc::strong_count(pool);
+
+        if load < min_load {
+            min_load = load;
+            pool_index = i;
+
+            if load == 0 {
+                break;
+            }
+        }
+    }
+    let selected_pool = pools[pool_index].clone();
+
+    if let Some(pool) = selected_pool
+        .get_or_init(|| {
+            ThreadPoolBuilder::new()
+                .num_threads(VERIF_MAX_THREADS_COUNT)
+                .build()
+                .ok()
+        })
+        .as_ref()
+    {
+        pool.install(f)
+    } else {
+        // If the pool creation failed (e.g. on wasm), we run it in the main thread pool.
+        // Since this is just an optimization, it is not worth panicking over.
+        f()
+    }
+}
+
 pub mod binary;
 pub mod index;
 pub mod pke;
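A note on sizing (not part of the diff): with rayon reporting 192 available threads, get_or_init_pools creates 192.div_ceil(32) = 6 pools, so up to six verifications can run concurrently without sharing a pool. The sketch below is a minimal, invented way to observe the capping from outside; it assumes run_in_pool and VERIF_MAX_THREADS_COUNT from the patch above are in scope, and demo_pool_capping with its toy workload is hypothetical:

// Minimal sketch, assuming `run_in_pool` from the patch above is in scope.
// Each OS thread submits one task; inside the closure,
// rayon::current_num_threads() reports the size of the dedicated pool
// (at most VERIF_MAX_THREADS_COUNT), not the size of the global pool.
fn demo_pool_capping() {
    let handles: Vec<_> = (0..4)
        .map(|i| {
            std::thread::spawn(move || {
                run_in_pool(|| {
                    println!(
                        "task {i}: running on a pool of {} threads",
                        rayon::current_num_threads()
                    );
                })
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}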
@@ -1829,7 +1829,18 @@ fn compute_a_theta<G: Curve>(
 }
 
 #[allow(clippy::result_unit_err)]
-pub fn verify<G: Curve>(
+pub fn verify<G: Curve + Send + Sync>(
+    proof: &Proof<G>,
+    public: (&PublicParams<G>, &PublicCommit<G>),
+    metadata: &[u8],
+) -> Result<(), ()> {
+    // By running it in a limited thread pool, we make sure that the rayon overhead stays minimal
+    // compared to the actual verification work.
+    run_in_pool(|| verify_inner(proof, public, metadata))
+}
+
+#[allow(clippy::result_unit_err)]
+pub fn verify_inner<G: Curve>(
     proof: &Proof<G>,
     public: (&PublicParams<G>, &PublicCommit<G>),
     metadata: &[u8],
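For callers, nothing changes apart from the added Send + Sync bound on G: verify keeps its public signature, and the pool dispatch stays internal. A hedged sketch of what a batch caller might look like (verify_batch, the proofs slice, and its layout are hypothetical, not part of the diff; it also assumes Proof, PublicParams, and PublicCommit are Sync so the slice can be iterated in parallel):

// Hypothetical batch verification built on the public `verify` above.
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

fn verify_batch<G: Curve + Send + Sync>(
    proofs: &[(Proof<G>, PublicParams<G>, PublicCommit<G>)],
    metadata: &[u8],
) -> bool {
    proofs
        .par_iter()
        // Each parallel task calls `verify`, which routes the heavy work to
        // the least-loaded dedicated pool via `run_in_pool`.
        .all(|(proof, params, commit)| verify(proof, (params, commit), metadata).is_ok())
}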