mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
refactor(core): factorize plan map management
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::plan::new_from_plan_map;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::prelude::*;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
use tfhe_ntt::prime64::Plan;
|
||||
@@ -24,8 +24,10 @@ impl Ntt64 {
|
||||
}
|
||||
|
||||
// Key is (polynomial size, modulus).
|
||||
type PlanMap = RwLock<HashMap<(usize, u64), Arc<OnceLock<Arc<Plan>>>>>;
|
||||
type PlanMap = crate::core_crypto::commons::plan::PlanMap<(usize, u64), Plan>;
|
||||
|
||||
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
|
||||
|
||||
fn plans() -> &'static PlanMap {
|
||||
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
@@ -39,39 +41,16 @@ impl Ntt64 {
|
||||
|
||||
let n = size.0;
|
||||
let modulus = modulus.get_custom_modulus() as u64;
|
||||
let get_plan = || {
|
||||
let plans = global_plans.read().unwrap();
|
||||
let plan = plans.get(&(n, modulus)).cloned();
|
||||
drop(plans);
|
||||
|
||||
plan.map(|p| {
|
||||
p.get_or_init(|| {
|
||||
Arc::new(Plan::try_new(n, modulus).unwrap_or_else(|| {
|
||||
panic!("could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})")
|
||||
}))
|
||||
})
|
||||
.clone()
|
||||
let plan = new_from_plan_map(global_plans, (n, modulus), |(n, modulus)| {
|
||||
Plan::try_new(n, modulus).unwrap_or_else(|| {
|
||||
panic!(
|
||||
"could not generate an NTT plan for the given (size, modulus) ({n}, {modulus})"
|
||||
)
|
||||
})
|
||||
};
|
||||
});
|
||||
|
||||
get_plan().map_or_else(
|
||||
|| {
|
||||
// If we don't find a plan for the given polynomial size and modulus, we insert a
|
||||
// new OnceLock, drop the write lock on the map and then let
|
||||
// get_plan() initialize the OnceLock (without holding the write
|
||||
// lock on the map).
|
||||
let mut plans = global_plans.write().unwrap();
|
||||
if let Entry::Vacant(v) = plans.entry((n, modulus)) {
|
||||
v.insert(Arc::new(OnceLock::new()));
|
||||
}
|
||||
drop(plans);
|
||||
|
||||
Self {
|
||||
plan: get_plan().unwrap(),
|
||||
}
|
||||
},
|
||||
|plan| Self { plan },
|
||||
)
|
||||
Self { plan }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ pub mod math;
|
||||
pub mod noise_formulas;
|
||||
pub mod numeric;
|
||||
pub mod parameters;
|
||||
pub mod plan;
|
||||
pub mod utils;
|
||||
|
||||
// Refactor modules
|
||||
|
||||
33
tfhe/src/core_crypto/commons/plan.rs
Normal file
33
tfhe/src/core_crypto/commons/plan.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
|
||||
pub type PlanMap<Key, Value> = RwLock<HashMap<Key, Arc<OnceLock<Arc<Value>>>>>;
|
||||
|
||||
pub fn new_from_plan_map<Key: Eq + Hash + Copy, Value>(
|
||||
values_map: &PlanMap<Key, Value>,
|
||||
key: Key,
|
||||
new_value: impl Fn(Key) -> Value,
|
||||
) -> Arc<Value> {
|
||||
let get_plan = || {
|
||||
let plans = values_map.read().unwrap();
|
||||
let plan = plans.get(&key).cloned();
|
||||
drop(plans);
|
||||
|
||||
plan.map(|p| p.get_or_init(|| Arc::new(new_value(key))).clone())
|
||||
};
|
||||
|
||||
get_plan().unwrap_or_else(|| {
|
||||
// If we don't find a plan for the given size, we insert a new OnceLock,
|
||||
// drop the write lock on the map and then let get_plan() initialize the OnceLock
|
||||
// (without holding the write lock on the map).
|
||||
let mut plans = values_map.write().unwrap();
|
||||
if let Entry::Vacant(v) = plans.entry(key) {
|
||||
v.insert(Arc::new(OnceLock::new()));
|
||||
}
|
||||
drop(plans);
|
||||
|
||||
get_plan().unwrap()
|
||||
})
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{CastFrom, CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::PolynomialSize;
|
||||
use crate::core_crypto::commons::plan::new_from_plan_map;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use core::any::TypeId;
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
use tfhe_fft::fft128::{f128, Plan};
|
||||
@@ -46,7 +46,8 @@ impl Fft128 {
|
||||
}
|
||||
}
|
||||
|
||||
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<PlanWrapper>>>>>;
|
||||
type PlanMap = crate::core_crypto::commons::plan::PlanMap<usize, PlanWrapper>;
|
||||
|
||||
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
|
||||
fn plans() -> &'static PlanMap {
|
||||
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
@@ -58,34 +59,10 @@ impl Fft128 {
|
||||
let global_plans = plans();
|
||||
|
||||
let n = size.0;
|
||||
let get_plan = || {
|
||||
let plans = global_plans.read().unwrap();
|
||||
let plan = plans.get(&n).cloned();
|
||||
drop(plans);
|
||||
|
||||
plan.map(|p| {
|
||||
p.get_or_init(|| Arc::new(PlanWrapper(Plan::new(n / 2))))
|
||||
.clone()
|
||||
})
|
||||
};
|
||||
let plan = new_from_plan_map(global_plans, n, |n| PlanWrapper(Plan::new(n / 2)));
|
||||
|
||||
get_plan().map_or_else(
|
||||
|| {
|
||||
// If we don't find a plan for the given size, we insert a new OnceLock,
|
||||
// drop the write lock on the map and then let get_plan() initialize the OnceLock
|
||||
// (without holding the write lock on the map).
|
||||
let mut plans = global_plans.write().unwrap();
|
||||
if let Entry::Vacant(v) = plans.entry(n) {
|
||||
v.insert(Arc::new(OnceLock::new()));
|
||||
}
|
||||
drop(plans);
|
||||
|
||||
Self {
|
||||
plan: get_plan().unwrap(),
|
||||
}
|
||||
},
|
||||
|plan| Self { plan },
|
||||
)
|
||||
Self { plan }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::core_crypto::backward_compatibility::fft_impl::{
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::CastInto;
|
||||
use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize};
|
||||
use crate::core_crypto::commons::plan::new_from_plan_map;
|
||||
use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainerOwned};
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -12,7 +13,6 @@ use aligned_vec::{avec, ABox};
|
||||
use dyn_stack::{PodStack, SizeOverflow, StackReq};
|
||||
use rayon::prelude::*;
|
||||
use std::any::TypeId;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::mem::{align_of, size_of};
|
||||
use std::sync::{Arc, OnceLock, RwLock};
|
||||
@@ -99,7 +99,8 @@ impl Fft {
|
||||
}
|
||||
}
|
||||
|
||||
type PlanMap = RwLock<HashMap<usize, Arc<OnceLock<Arc<(Twisties, Plan)>>>>>;
|
||||
type PlanMap = crate::core_crypto::commons::plan::PlanMap<usize, (Twisties, Plan)>;
|
||||
|
||||
pub(crate) static PLANS: OnceLock<PlanMap> = OnceLock::new();
|
||||
fn plans() -> &'static PlanMap {
|
||||
PLANS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
@@ -151,55 +152,33 @@ impl Fft {
|
||||
let global_plans = plans();
|
||||
|
||||
let n = size.0;
|
||||
let get_plan = || {
|
||||
let plans = global_plans.read().unwrap();
|
||||
let plan = plans.get(&n).cloned();
|
||||
drop(plans);
|
||||
|
||||
plan.map(|p| {
|
||||
p.get_or_init(|| {
|
||||
#[cfg(not(feature = "experimental-force_fft_algo_dif4"))]
|
||||
{
|
||||
Arc::new((
|
||||
Twisties::new(n / 2),
|
||||
Plan::new(n / 2, Method::Measure(Duration::from_millis(10))),
|
||||
))
|
||||
}
|
||||
#[cfg(feature = "experimental-force_fft_algo_dif4")]
|
||||
{
|
||||
Arc::new((
|
||||
Twisties::new(n / 2),
|
||||
Plan::new(
|
||||
n / 2,
|
||||
Method::UserProvided {
|
||||
base_algo: tfhe_fft::ordered::FftAlgo::Dif4,
|
||||
base_n: n / 2,
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
})
|
||||
.clone()
|
||||
})
|
||||
let new = |n| {
|
||||
#[cfg(not(feature = "experimental-force_fft_algo_dif4"))]
|
||||
{
|
||||
(
|
||||
Twisties::new(n / 2),
|
||||
Plan::new(n / 2, Method::Measure(Duration::from_millis(10))),
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "experimental-force_fft_algo_dif4")]
|
||||
{
|
||||
(
|
||||
Twisties::new(n / 2),
|
||||
Plan::new(
|
||||
n / 2,
|
||||
Method::UserProvided {
|
||||
base_algo: tfhe_fft::ordered::FftAlgo::Dif4,
|
||||
base_n: n / 2,
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
get_plan().map_or_else(
|
||||
|| {
|
||||
// If we don't find a plan for the given size, we insert a new OnceLock,
|
||||
// drop the write lock on the map and then let get_plan() initialize the OnceLock
|
||||
// (without holding the write lock on the map).
|
||||
let mut plans = global_plans.write().unwrap();
|
||||
if let Entry::Vacant(v) = plans.entry(n) {
|
||||
v.insert(Arc::new(OnceLock::new()));
|
||||
}
|
||||
drop(plans);
|
||||
let plan = new_from_plan_map(global_plans, n, new);
|
||||
|
||||
Self {
|
||||
plan: get_plan().unwrap(),
|
||||
}
|
||||
},
|
||||
|plan| Self { plan },
|
||||
)
|
||||
Self { plan }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user