mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
chore(ci): check ks32 parameters with lattice estimator
A small refactoring has been done to handle ciphertext modulus in a more convenient way.
This commit is contained in:
@@ -4,17 +4,50 @@ use std::io::Write;
|
||||
use std::path::Path;
|
||||
use tfhe::boolean::parameters::{BooleanParameters, VEC_BOOLEAN_PARAM};
|
||||
use tfhe::core_crypto::commons::parameters::{GlweDimension, LweDimension, PolynomialSize};
|
||||
use tfhe::core_crypto::prelude::{DynamicDistribution, TUniform, UnsignedInteger};
|
||||
use tfhe::core_crypto::prelude::{
|
||||
CiphertextModulus, DynamicDistribution, TUniform, UnsignedInteger,
|
||||
};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::current_params::{
|
||||
VEC_ALL_CLASSIC_PBS_PARAMETERS, VEC_ALL_COMPACT_PUBLIC_KEY_ENCRYPTION_PARAMETERS,
|
||||
VEC_ALL_COMPRESSION_PARAMETERS, VEC_ALL_MULTI_BIT_PBS_PARAMETERS,
|
||||
VEC_ALL_NOISE_SQUASHING_PARAMETERS,
|
||||
VEC_ALL_COMPRESSION_PARAMETERS, VEC_ALL_HPU_PARAMETERS, VEC_ALL_KS32_PARAMETERS,
|
||||
VEC_ALL_MULTI_BIT_PBS_PARAMETERS, VEC_ALL_NOISE_SQUASHING_PARAMETERS,
|
||||
};
|
||||
use tfhe::shortint::parameters::{
|
||||
CompactPublicKeyEncryptionParameters, CompressionParameters, NoiseSquashingParameters,
|
||||
ShortintParameterSet,
|
||||
};
|
||||
use tfhe::shortint::AtomicPatternParameters;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Hash)]
|
||||
pub enum ParamModulus {
|
||||
NativeU128,
|
||||
Other(u128),
|
||||
}
|
||||
impl ParamModulus {
|
||||
fn from_ciphertext_modulus<Scalar: UnsignedInteger>(
|
||||
ct_modulus: CiphertextModulus<Scalar>,
|
||||
) -> Self {
|
||||
let scalar_bits = ct_modulus.associated_scalar_bits();
|
||||
assert!(scalar_bits <= 128, "ciphertext modulus is too large");
|
||||
|
||||
if ct_modulus.is_native_modulus() {
|
||||
if scalar_bits == 128 {
|
||||
ParamModulus::NativeU128
|
||||
} else {
|
||||
ParamModulus::Other(1u128.checked_shl(scalar_bits as u32).unwrap())
|
||||
}
|
||||
} else {
|
||||
ParamModulus::Other(ct_modulus.get_custom_modulus())
|
||||
}
|
||||
}
|
||||
|
||||
fn as_f64(&self) -> f64 {
|
||||
match self {
|
||||
Self::NativeU128 => 2.0_f64.powi(128),
|
||||
Self::Other(u) => *u as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ParamDetails<T: UnsignedInteger> {
|
||||
fn lwe_dimension(&self) -> LweDimension;
|
||||
@@ -22,7 +55,8 @@ pub trait ParamDetails<T: UnsignedInteger> {
|
||||
fn lwe_noise_distribution(&self) -> DynamicDistribution<T>;
|
||||
fn glwe_noise_distribution(&self) -> DynamicDistribution<T>;
|
||||
fn polynomial_size(&self) -> PolynomialSize;
|
||||
fn log_ciphertext_modulus(&self) -> usize;
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus;
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus;
|
||||
}
|
||||
|
||||
impl ParamDetails<u32> for BooleanParameters {
|
||||
@@ -45,12 +79,16 @@ impl ParamDetails<u32> for BooleanParameters {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
fn log_ciphertext_modulus(&self) -> usize {
|
||||
32
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::Other(1u128.checked_shl(u32::BITS).unwrap())
|
||||
}
|
||||
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::Other(1u128.checked_shl(u32::BITS).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ParamDetails<u64> for ShortintParameterSet {
|
||||
impl ParamDetails<u64> for AtomicPatternParameters {
|
||||
fn lwe_dimension(&self) -> LweDimension {
|
||||
self.lwe_dimension()
|
||||
}
|
||||
@@ -70,9 +108,17 @@ impl ParamDetails<u64> for ShortintParameterSet {
|
||||
self.polynomial_size()
|
||||
}
|
||||
|
||||
fn log_ciphertext_modulus(&self) -> usize {
|
||||
assert!(self.ciphertext_modulus().is_native_modulus());
|
||||
64
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
match self {
|
||||
Self::Standard(p) => ParamModulus::from_ciphertext_modulus(p.ciphertext_modulus()),
|
||||
Self::KeySwitch32(p) => {
|
||||
ParamModulus::from_ciphertext_modulus(p.post_keyswitch_ciphertext_modulus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::from_ciphertext_modulus(self.ciphertext_modulus())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,9 +144,14 @@ impl ParamDetails<u64> for CompactPublicKeyEncryptionParameters {
|
||||
panic!("polynomial_size not applicable for compact public-key encryption parameters")
|
||||
}
|
||||
|
||||
fn log_ciphertext_modulus(&self) -> usize {
|
||||
assert!(self.ciphertext_modulus.is_native_modulus());
|
||||
64
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::from_ciphertext_modulus(self.ciphertext_modulus)
|
||||
}
|
||||
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
panic!(
|
||||
"glwe_ciphertext_modulus not applicable for compact public-key encryption parameters"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,8 +175,12 @@ impl ParamDetails<u64> for CompressionParameters {
|
||||
self.packing_ks_polynomial_size
|
||||
}
|
||||
|
||||
fn log_ciphertext_modulus(&self) -> usize {
|
||||
64
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
panic!("lwe_ciphertext_modulus not applicable for compression parameters")
|
||||
}
|
||||
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::from_ciphertext_modulus(CiphertextModulus::<u64>::new_native())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,9 +205,12 @@ impl ParamDetails<u128> for NoiseSquashingParameters {
|
||||
self.polynomial_size
|
||||
}
|
||||
|
||||
fn log_ciphertext_modulus(&self) -> usize {
|
||||
assert!(self.ciphertext_modulus.is_native_modulus());
|
||||
u128::BITS as usize
|
||||
fn lwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
panic!("lwe_ciphertext_modulus not applicable for NoiseSquashingParameters")
|
||||
}
|
||||
|
||||
fn glwe_ciphertext_modulus(&self) -> ParamModulus {
|
||||
ParamModulus::from_ciphertext_modulus(self.ciphertext_modulus)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,27 +222,26 @@ enum ParametersFormat {
|
||||
}
|
||||
|
||||
type NoiseDistributionString = String;
|
||||
type LogCiphertextModulus = usize;
|
||||
|
||||
#[derive(Eq, PartialEq, Hash)]
|
||||
struct ParamGroupKey {
|
||||
lwe_dimension: LweDimension,
|
||||
log_ciphertext_modulus: LogCiphertextModulus,
|
||||
ciphertext_modulus: ParamModulus,
|
||||
noise_distribution: NoiseDistributionString,
|
||||
// TODO might not need to be hashed since LWE and GLWE share the same security check
|
||||
parameters_format: ParametersFormat,
|
||||
}
|
||||
|
||||
fn format_modulus_as_string(log_ciphertext_modulus: usize) -> String {
|
||||
if log_ciphertext_modulus > 128 {
|
||||
panic!("Exponent too large");
|
||||
fn format_modulus_as_string(modulus: ParamModulus) -> String {
|
||||
match modulus {
|
||||
ParamModulus::NativeU128 => {
|
||||
// What are you gonna do, call the police ?
|
||||
"340282366920938463463374607431768211456".to_string()
|
||||
}
|
||||
ParamModulus::Other(u) => {
|
||||
format!("{u}")
|
||||
}
|
||||
}
|
||||
if log_ciphertext_modulus == 128 {
|
||||
// What are you gonna do, call the police ?
|
||||
return "340282366920938463463374607431768211456".to_string();
|
||||
}
|
||||
|
||||
format!("{}", 1u128 << log_ciphertext_modulus)
|
||||
}
|
||||
|
||||
///Function to print in the lattice_estimator format the parameters
|
||||
@@ -197,14 +254,14 @@ pub fn format_lwe_parameters_to_lattice_estimator<U: UnsignedInteger, T: ParamDe
|
||||
match param.lwe_noise_distribution() {
|
||||
DynamicDistribution::Gaussian(distrib) => {
|
||||
let modular_std_dev =
|
||||
param.log_ciphertext_modulus() as f64 + distrib.standard_dev().0.log2();
|
||||
(param.lwe_ciphertext_modulus().as_f64() * distrib.standard_dev().0).log2();
|
||||
|
||||
format!(
|
||||
"{}_LWE = LWE.Parameters(\n n = {},\n q ={},\n Xs=ND.Uniform(0,1), \n \
|
||||
Xe=ND.DiscreteGaussian({}),\n tag=('{}_lwe',) \n)\n\n",
|
||||
name,
|
||||
param.lwe_dimension().0,
|
||||
format_modulus_as_string(param.log_ciphertext_modulus()),
|
||||
format_modulus_as_string(param.lwe_ciphertext_modulus()),
|
||||
2.0_f64.powf(modular_std_dev),
|
||||
similar_params.join("_lwe', '")
|
||||
)
|
||||
@@ -215,7 +272,7 @@ pub fn format_lwe_parameters_to_lattice_estimator<U: UnsignedInteger, T: ParamDe
|
||||
Xe=ND.DiscreteGaussian({}),\n tag=('{}_lwe',) \n)\n\n",
|
||||
name,
|
||||
param.lwe_dimension().0,
|
||||
format_modulus_as_string(param.log_ciphertext_modulus()),
|
||||
format_modulus_as_string(param.lwe_ciphertext_modulus()),
|
||||
tuniform_equivalent_gaussian_std_dev(&distrib),
|
||||
similar_params.join("_lwe', '")
|
||||
)
|
||||
@@ -233,7 +290,7 @@ pub fn format_glwe_parameters_to_lattice_estimator<U: UnsignedInteger, T: ParamD
|
||||
match param.glwe_noise_distribution() {
|
||||
DynamicDistribution::Gaussian(distrib) => {
|
||||
let modular_std_dev =
|
||||
param.log_ciphertext_modulus() as f64 + distrib.standard_dev().0.log2();
|
||||
(param.glwe_ciphertext_modulus().as_f64() * distrib.standard_dev().0).log2();
|
||||
|
||||
format!(
|
||||
"{}_GLWE = LWE.Parameters(\n n = {},\n q = {},\n Xs=ND.Uniform(0,1), \n \
|
||||
@@ -243,7 +300,7 @@ pub fn format_glwe_parameters_to_lattice_estimator<U: UnsignedInteger, T: ParamD
|
||||
.glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(param.polynomial_size())
|
||||
.0,
|
||||
format_modulus_as_string(param.log_ciphertext_modulus()),
|
||||
format_modulus_as_string(param.glwe_ciphertext_modulus()),
|
||||
2.0_f64.powf(modular_std_dev),
|
||||
similar_params.join("_glwe', '")
|
||||
)
|
||||
@@ -257,7 +314,7 @@ pub fn format_glwe_parameters_to_lattice_estimator<U: UnsignedInteger, T: ParamD
|
||||
.glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(param.polynomial_size())
|
||||
.0,
|
||||
format_modulus_as_string(param.log_ciphertext_modulus()),
|
||||
format_modulus_as_string(param.glwe_ciphertext_modulus()),
|
||||
tuniform_equivalent_gaussian_std_dev(&distrib),
|
||||
similar_params.join("_glwe', '")
|
||||
)
|
||||
@@ -294,7 +351,7 @@ fn write_all_params_in_file<U: UnsignedInteger, T: ParamDetails<U> + Copy + Name
|
||||
ParametersFormat::LweGlwe => vec![
|
||||
ParamGroupKey {
|
||||
lwe_dimension: params.lwe_dimension(),
|
||||
log_ciphertext_modulus: params.log_ciphertext_modulus(),
|
||||
ciphertext_modulus: params.lwe_ciphertext_modulus(),
|
||||
noise_distribution: params.lwe_noise_distribution().to_string(),
|
||||
parameters_format: ParametersFormat::Lwe,
|
||||
},
|
||||
@@ -302,14 +359,14 @@ fn write_all_params_in_file<U: UnsignedInteger, T: ParamDetails<U> + Copy + Name
|
||||
lwe_dimension: params
|
||||
.glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(params.polynomial_size()),
|
||||
log_ciphertext_modulus: params.log_ciphertext_modulus(),
|
||||
ciphertext_modulus: params.glwe_ciphertext_modulus(),
|
||||
noise_distribution: params.glwe_noise_distribution().to_string(),
|
||||
parameters_format: ParametersFormat::Glwe,
|
||||
},
|
||||
],
|
||||
ParametersFormat::Lwe => vec![ParamGroupKey {
|
||||
lwe_dimension: params.lwe_dimension(),
|
||||
log_ciphertext_modulus: params.log_ciphertext_modulus(),
|
||||
ciphertext_modulus: params.lwe_ciphertext_modulus(),
|
||||
noise_distribution: params.lwe_noise_distribution().to_string(),
|
||||
parameters_format: ParametersFormat::Lwe,
|
||||
}],
|
||||
@@ -317,7 +374,7 @@ fn write_all_params_in_file<U: UnsignedInteger, T: ParamDetails<U> + Copy + Name
|
||||
lwe_dimension: params
|
||||
.glwe_dimension()
|
||||
.to_equivalent_lwe_dimension(params.polynomial_size()),
|
||||
log_ciphertext_modulus: params.log_ciphertext_modulus(),
|
||||
ciphertext_modulus: params.glwe_ciphertext_modulus(),
|
||||
noise_distribution: params.glwe_noise_distribution().to_string(),
|
||||
parameters_format: ParametersFormat::Glwe,
|
||||
}],
|
||||
@@ -388,7 +445,7 @@ fn main() {
|
||||
|
||||
let classic_pbs: Vec<_> = VEC_ALL_CLASSIC_PBS_PARAMETERS
|
||||
.into_iter()
|
||||
.map(|p| (ShortintParameterSet::from(*p.0), Some(p.1)))
|
||||
.map(|p| (AtomicPatternParameters::from(*p.0), Some(p.1)))
|
||||
.collect();
|
||||
write_all_params_in_file(
|
||||
"shortint_classic_parameters_lattice_estimator.sage",
|
||||
@@ -398,7 +455,7 @@ fn main() {
|
||||
|
||||
let multi_bit_pbs: Vec<_> = VEC_ALL_MULTI_BIT_PBS_PARAMETERS
|
||||
.into_iter()
|
||||
.map(|p| (ShortintParameterSet::from(*p.0), Some(p.1)))
|
||||
.map(|p| (AtomicPatternParameters::from(*p.0), Some(p.1)))
|
||||
.collect();
|
||||
write_all_params_in_file(
|
||||
"shortint_multi_bit_parameters_lattice_estimator.sage",
|
||||
@@ -436,13 +493,23 @@ fn main() {
|
||||
ParametersFormat::Glwe,
|
||||
);
|
||||
|
||||
// TODO perform this gathering later
|
||||
// let wopbs = ALL_PARAMETER_VEC_WOPBS
|
||||
// .iter()
|
||||
// .map(|p| ShortintParameterSet::from(*p))
|
||||
// .collect::<Vec<_>>();
|
||||
// write_all_params_in_file(
|
||||
// "shortint_wopbs_parameters_lattice_estimator.sage",
|
||||
// &wopbs,
|
||||
// );
|
||||
let ks32_params: Vec<_> = VEC_ALL_KS32_PARAMETERS
|
||||
.into_iter()
|
||||
.map(|p| (AtomicPatternParameters::from(*p.0), Some(p.1)))
|
||||
.collect();
|
||||
write_all_params_in_file(
|
||||
"shortint_ks32_parameters_lattice_estimator.sage",
|
||||
&ks32_params,
|
||||
ParametersFormat::LweGlwe,
|
||||
);
|
||||
|
||||
let hpu_params: Vec<_> = VEC_ALL_HPU_PARAMETERS
|
||||
.into_iter()
|
||||
.map(|p| (AtomicPatternParameters::from(*p.0), Some(p.1)))
|
||||
.collect();
|
||||
write_all_params_in_file(
|
||||
"shortint_hpu_parameters_lattice_estimator.sage",
|
||||
&hpu_params,
|
||||
ParametersFormat::LweGlwe,
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user