From aa6284314b7c6ffcbb38e63ea51d4fb022ac2b09 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Thu, 11 Dec 2025 14:23:06 +0100 Subject: [PATCH] refactor(core): factorize plan map management --- .../src/core_crypto/commons/math/ntt/ntt64.rs | 43 +++-------- tfhe/src/core_crypto/commons/mod.rs | 1 + tfhe/src/core_crypto/commons/plan.rs | 33 +++++++++ .../fft_impl/fft128/math/fft/mod.rs | 33 ++------- .../fft_impl/fft64/math/fft/mod.rs | 73 +++++++------------ 5 files changed, 76 insertions(+), 107 deletions(-) create mode 100644 tfhe/src/core_crypto/commons/plan.rs diff --git a/tfhe/src/core_crypto/commons/math/ntt/ntt64.rs b/tfhe/src/core_crypto/commons/math/ntt/ntt64.rs index 2eaf09deb..5ba0fa83d 100644 --- a/tfhe/src/core_crypto/commons/math/ntt/ntt64.rs +++ b/tfhe/src/core_crypto/commons/math/ntt/ntt64.rs @@ -1,7 +1,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind; +use crate::core_crypto::commons::plan::new_from_plan_map; use crate::core_crypto::commons::utils::izip_eq; use crate::core_crypto::prelude::*; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use tfhe_ntt::prime64::Plan; @@ -24,8 +24,10 @@ impl Ntt64 { } // Key is (polynomial size, modulus). -type PlanMap = RwLock>>>>; +type PlanMap = crate::core_crypto::commons::plan::PlanMap<(usize, u64), Plan>; + pub(crate) static PLANS: OnceLock = OnceLock::new(); + fn plans() -> &'static PlanMap { PLANS.get_or_init(|| RwLock::new(HashMap::new())) } @@ -39,39 +41,16 @@ impl Ntt64 { let n = size.0; let modulus = modulus.get_custom_modulus() as u64; - let get_plan = || { - let plans = global_plans.read().unwrap(); - let plan = plans.get(&(n, modulus)).cloned(); - drop(plans); - plan.map(|p| { - p.get_or_init(|| { - Arc::new(Plan::try_new(n, modulus).unwrap_or_else(|| { - panic!("could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})") - })) - }) - .clone() + let plan = new_from_plan_map(global_plans, (n, modulus), |(n, modulus)| { + Plan::try_new(n, modulus).unwrap_or_else(|| { + panic!( + "could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})" + ) }) - }; + }); - get_plan().map_or_else( - || { - // If we don't find a plan for the given polynomial size and modulus, we insert a - // new OnceLock, drop the write lock on the map and then let - // get_plan() initialize the OnceLock (without holding the write - // lock on the map). - let mut plans = global_plans.write().unwrap(); - if let Entry::Vacant(v) = plans.entry((n, modulus)) { - v.insert(Arc::new(OnceLock::new())); - } - drop(plans); - - Self { - plan: get_plan().unwrap(), - } - }, - |plan| Self { plan }, - ) + Self { plan } } } diff --git a/tfhe/src/core_crypto/commons/mod.rs b/tfhe/src/core_crypto/commons/mod.rs index 140a63ac3..120850fd2 100644 --- a/tfhe/src/core_crypto/commons/mod.rs +++ b/tfhe/src/core_crypto/commons/mod.rs @@ -17,6 +17,7 @@ pub mod math; pub mod noise_formulas; pub mod numeric; pub mod parameters; +pub mod plan; pub mod utils; // Refactor modules diff --git a/tfhe/src/core_crypto/commons/plan.rs b/tfhe/src/core_crypto/commons/plan.rs new file mode 100644 index 000000000..ded3f0ccc --- /dev/null +++ b/tfhe/src/core_crypto/commons/plan.rs @@ -0,0 +1,33 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::{Arc, OnceLock, RwLock}; + +pub type PlanMap = RwLock>>>>; + +pub fn new_from_plan_map( + values_map: &PlanMap, + key: Key, + new_value: impl Fn(Key) -> Value, +) -> Arc { + let get_plan = || { + let plans = values_map.read().unwrap(); + let plan = plans.get(&key).cloned(); + drop(plans); + + plan.map(|p| p.get_or_init(|| Arc::new(new_value(key))).clone()) + }; + + get_plan().unwrap_or_else(|| { + // If we don't find a plan for the given size, we insert a new OnceLock, + // drop the write lock on the map and then let get_plan() initialize the OnceLock + // (without holding the write lock on the map). + let mut plans = values_map.write().unwrap(); + if let Entry::Vacant(v) = plans.entry(key) { + v.insert(Arc::new(OnceLock::new())); + } + drop(plans); + + get_plan().unwrap() + }) +} diff --git a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs index 47d5135c7..475abf1b9 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/math/fft/mod.rs @@ -1,10 +1,10 @@ use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger}; use crate::core_crypto::commons::parameters::PolynomialSize; +use crate::core_crypto::commons::plan::new_from_plan_map; use crate::core_crypto::commons::utils::izip_eq; use core::any::TypeId; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use tfhe_fft::fft128::{f128, Plan}; @@ -46,7 +46,8 @@ impl Fft128 { } } -type PlanMap = RwLock>>>>; +type PlanMap = crate::core_crypto::commons::plan::PlanMap; + pub(crate) static PLANS: OnceLock = OnceLock::new(); fn plans() -> &'static PlanMap { PLANS.get_or_init(|| RwLock::new(HashMap::new())) @@ -58,34 +59,10 @@ impl Fft128 { let global_plans = plans(); let n = size.0; - let get_plan = || { - let plans = global_plans.read().unwrap(); - let plan = plans.get(&n).cloned(); - drop(plans); - plan.map(|p| { - p.get_or_init(|| Arc::new(PlanWrapper(Plan::new(n / 2)))) - .clone() - }) - }; + let plan = new_from_plan_map(global_plans, n, |n| PlanWrapper(Plan::new(n / 2))); - get_plan().map_or_else( - || { - // If we don't find a plan for the given size, we insert a new OnceLock, - // drop the write lock on the map and then let get_plan() initialize the OnceLock - // (without holding the write lock on the map). - let mut plans = global_plans.write().unwrap(); - if let Entry::Vacant(v) = plans.entry(n) { - v.insert(Arc::new(OnceLock::new())); - } - drop(plans); - - Self { - plan: get_plan().unwrap(), - } - }, - |plan| Self { plan }, - ) + Self { plan } } } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs index 2cb864c98..ff8e6686a 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs @@ -5,6 +5,7 @@ use crate::core_crypto::backward_compatibility::fft_impl::{ use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::CastInto; use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize}; +use crate::core_crypto::commons::plan::new_from_plan_map; use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainerOwned}; use crate::core_crypto::commons::utils::izip_eq; use crate::core_crypto::entities::*; @@ -12,7 +13,6 @@ use aligned_vec::{avec, ABox}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; use rayon::prelude::*; use std::any::TypeId; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::mem::{align_of, size_of}; use std::sync::{Arc, OnceLock, RwLock}; @@ -99,7 +99,8 @@ impl Fft { } } -type PlanMap = RwLock>>>>; +type PlanMap = crate::core_crypto::commons::plan::PlanMap; + pub(crate) static PLANS: OnceLock = OnceLock::new(); fn plans() -> &'static PlanMap { PLANS.get_or_init(|| RwLock::new(HashMap::new())) @@ -151,55 +152,33 @@ impl Fft { let global_plans = plans(); let n = size.0; - let get_plan = || { - let plans = global_plans.read().unwrap(); - let plan = plans.get(&n).cloned(); - drop(plans); - plan.map(|p| { - p.get_or_init(|| { - #[cfg(not(feature = "experimental-force_fft_algo_dif4"))] - { - Arc::new(( - Twisties::new(n / 2), - Plan::new(n / 2, Method::Measure(Duration::from_millis(10))), - )) - } - #[cfg(feature = "experimental-force_fft_algo_dif4")] - { - Arc::new(( - Twisties::new(n / 2), - Plan::new( - n / 2, - Method::UserProvided { - base_algo: tfhe_fft::ordered::FftAlgo::Dif4, - base_n: n / 2, - }, - ), - )) - } - }) - .clone() - }) + let new = |n| { + #[cfg(not(feature = "experimental-force_fft_algo_dif4"))] + { + ( + Twisties::new(n / 2), + Plan::new(n / 2, Method::Measure(Duration::from_millis(10))), + ) + } + #[cfg(feature = "experimental-force_fft_algo_dif4")] + { + ( + Twisties::new(n / 2), + Plan::new( + n / 2, + Method::UserProvided { + base_algo: tfhe_fft::ordered::FftAlgo::Dif4, + base_n: n / 2, + }, + ), + ) + } }; - get_plan().map_or_else( - || { - // If we don't find a plan for the given size, we insert a new OnceLock, - // drop the write lock on the map and then let get_plan() initialize the OnceLock - // (without holding the write lock on the map). - let mut plans = global_plans.write().unwrap(); - if let Entry::Vacant(v) = plans.entry(n) { - v.insert(Arc::new(OnceLock::new())); - } - drop(plans); + let plan = new_from_plan_map(global_plans, n, new); - Self { - plan: get_plan().unwrap(), - } - }, - |plan| Self { plan }, - ) + Self { plan } } }