refactor(core): factorize plan map management

This commit is contained in:
Mayeul@Zama
2025-12-11 14:23:06 +01:00
parent b7a706a3db
commit aa6284314b
5 changed files with 76 additions and 107 deletions

View File

@@ -1,7 +1,7 @@
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::plan::new_from_plan_map;
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::prelude::*;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use tfhe_ntt::prime64::Plan;
@@ -24,8 +24,10 @@ impl Ntt64 {
}
// Key is (polynomial size, modulus).
type PlanMap = RwLock<HashMap<(usize, u64), Arc<OnceLock<Arc<Plan>>>>>;
type PlanMap = crate::core_crypto::commons::plan::PlanMap<(usize, u64), Plan>;
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
fn plans() -> &'static PlanMap {
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
}
@@ -39,39 +41,16 @@ impl Ntt64 {
let n = size.0;
let modulus = modulus.get_custom_modulus() as u64;
let get_plan = || {
let plans = global_plans.read().unwrap();
let plan = plans.get(&(n, modulus)).cloned();
drop(plans);
plan.map(|p| {
p.get_or_init(|| {
Arc::new(Plan::try_new(n, modulus).unwrap_or_else(|| {
panic!("could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})")
}))
})
.clone()
let plan = new_from_plan_map(global_plans, (n, modulus), |(n, modulus)| {
Plan::try_new(n, modulus).unwrap_or_else(|| {
panic!(
"could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})"
)
})
};
});
get_plan().map_or_else(
|| {
// If we don't find a plan for the given polynomial size and modulus, we insert a
// new OnceLock, drop the write lock on the map and then let
// get_plan() initialize the OnceLock (without holding the write
// lock on the map).
let mut plans = global_plans.write().unwrap();
if let Entry::Vacant(v) = plans.entry((n, modulus)) {
v.insert(Arc::new(OnceLock::new()));
}
drop(plans);
Self {
plan: get_plan().unwrap(),
}
},
|plan| Self { plan },
)
Self { plan }
}
}

View File

@@ -17,6 +17,7 @@ pub mod math;
pub mod noise_formulas;
pub mod numeric;
pub mod parameters;
pub mod plan;
pub mod utils;
// Refactor modules

View File

@@ -0,0 +1,33 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, OnceLock, RwLock};
pub type PlanMap<Key, Value> = RwLock<HashMap<Key, Arc<OnceLock<Arc<Value>>>>>;
pub fn new_from_plan_map<Key: Eq + Hash + Copy, Value>(
values_map: &PlanMap<Key, Value>,
key: Key,
new_value: impl Fn(Key) -> Value,
) -> Arc<Value> {
let get_plan = || {
let plans = values_map.read().unwrap();
let plan = plans.get(&key).cloned();
drop(plans);
plan.map(|p| p.get_or_init(|| Arc::new(new_value(key))).clone())
};
get_plan().unwrap_or_else(|| {
// If we don't find a plan for the given size, we insert a new OnceLock,
// drop the write lock on the map and then let get_plan() initialize the OnceLock
// (without holding the write lock on the map).
let mut plans = values_map.write().unwrap();
if let Entry::Vacant(v) = plans.entry(key) {
v.insert(Arc::new(OnceLock::new()));
}
drop(plans);
get_plan().unwrap()
})
}

View File

@@ -1,10 +1,10 @@
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
use crate::core_crypto::commons::parameters::PolynomialSize;
use crate::core_crypto::commons::plan::new_from_plan_map;
use crate::core_crypto::commons::utils::izip_eq;
use core::any::TypeId;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use tfhe_fft::fft128::{f128, Plan};
@@ -46,7 +46,8 @@ impl Fft128 {
}
}
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<PlanWrapper>>>>>;
type PlanMap = crate::core_crypto::commons::plan::PlanMap<usize, PlanWrapper>;
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
fn plans() -> &'static PlanMap {
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
@@ -58,34 +59,10 @@ impl Fft128 {
let global_plans = plans();
let n = size.0;
let get_plan = || {
let plans = global_plans.read().unwrap();
let plan = plans.get(&n).cloned();
drop(plans);
plan.map(|p| {
p.get_or_init(|| Arc::new(PlanWrapper(Plan::new(n / 2))))
.clone()
})
};
let plan = new_from_plan_map(global_plans, n, |n| PlanWrapper(Plan::new(n / 2)));
get_plan().map_or_else(
|| {
// If we don't find a plan for the given size, we insert a new OnceLock,
// drop the write lock on the map and then let get_plan() initialize the OnceLock
// (without holding the write lock on the map).
let mut plans = global_plans.write().unwrap();
if let Entry::Vacant(v) = plans.entry(n) {
v.insert(Arc::new(OnceLock::new()));
}
drop(plans);
Self {
plan: get_plan().unwrap(),
}
},
|plan| Self { plan },
)
Self { plan }
}
}

View File

@@ -5,6 +5,7 @@ use crate::core_crypto::backward_compatibility::fft_impl::{
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize};
use crate::core_crypto::commons::plan::new_from_plan_map;
use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainerOwned};
use crate::core_crypto::commons::utils::izip_eq;
use crate::core_crypto::entities::*;
@@ -12,7 +13,6 @@ use aligned_vec::{avec, ABox};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use rayon::prelude::*;
use std::any::TypeId;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::mem::{align_of, size_of};
use std::sync::{Arc, OnceLock, RwLock};
@@ -99,7 +99,8 @@ impl Fft {
}
}
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<(Twisties, Plan)>>>>>;
type PlanMap = crate::core_crypto::commons::plan::PlanMap<usize, (Twisties, Plan)>;
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
fn plans() -> &'static PlanMap {
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
@@ -151,55 +152,33 @@ impl Fft {
let global_plans = plans();
let n = size.0;
let get_plan = || {
let plans = global_plans.read().unwrap();
let plan = plans.get(&n).cloned();
drop(plans);
plan.map(|p| {
p.get_or_init(|| {
#[cfg(not(feature = "experimental-force_fft_algo_dif4"))]
{
Arc::new((
Twisties::new(n / 2),
Plan::new(n / 2, Method::Measure(Duration::from_millis(10))),
))
}
#[cfg(feature = "experimental-force_fft_algo_dif4")]
{
Arc::new((
Twisties::new(n / 2),
Plan::new(
n / 2,
Method::UserProvided {
base_algo: tfhe_fft::ordered::FftAlgo::Dif4,
base_n: n / 2,
},
),
))
}
})
.clone()
})
let new = |n| {
#[cfg(not(feature = "experimental-force_fft_algo_dif4"))]
{
(
Twisties::new(n / 2),
Plan::new(n / 2, Method::Measure(Duration::from_millis(10))),
)
}
#[cfg(feature = "experimental-force_fft_algo_dif4")]
{
(
Twisties::new(n / 2),
Plan::new(
n / 2,
Method::UserProvided {
base_algo: tfhe_fft::ordered::FftAlgo::Dif4,
base_n: n / 2,
},
),
)
}
};
get_plan().map_or_else(
|| {
// If we don't find a plan for the given size, we insert a new OnceLock,
// drop the write lock on the map and then let get_plan() initialize the OnceLock
// (without holding the write lock on the map).
let mut plans = global_plans.write().unwrap();
if let Entry::Vacant(v) = plans.entry(n) {
v.insert(Arc::new(OnceLock::new()));
}
drop(plans);
let plan = new_from_plan_map(global_plans, n, new);
Self {
plan: get_plan().unwrap(),
}
},
|plan| Self { plan },
)
Self { plan }
}
}