chore: add backward multi bit decompression keys

This commit is contained in:
Mayeul@Zama
2025-10-17 15:04:45 +02:00
committed by mayeul-zama
parent 92dcd38e30
commit 859d5e4e1f
12 changed files with 477 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
[package]
name = "generate_1_5"
edition = "2024"
license.workspace = true
version.workspace = true
[dependencies]
clap.workspace = true
# TFHE-rs
tfhe = { features = [
"boolean",
"integer",
"shortint",
"zk-pok",
"experimental-force_fft_algo_dif4",
], path = "../../../../tfhe" }
tfhe-versionable = { path = "../../../tfhe-versionable" }
# Uncomment this and remove the lines above once the current tfhe-rs version has been released
# tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", tag = "tfhe-rs-1.5.0", features = [
# "boolean",
# "integer",
# "shortint",
# "zk-pok",
# "experimental-force_fft_algo_dif4",
# ] }
# tfhe-versionable = { git = "https://github.com/zama-ai/tfhe-rs.git", tag = "tfhe-rs-1.5.0" }
tfhe-backward-compat-data = { path = "../.." }

View File

@@ -0,0 +1,82 @@
mod utils;
use std::borrow::Cow;
use std::fs::create_dir_all;
use std::path::Path;
use tfhe::boolean::engine::BooleanEngine;
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
use tfhe::shortint::engine::ShortintEngine;
use tfhe::{CompressedServerKey, Seed};
use tfhe_backward_compat_data::generate::*;
use tfhe_backward_compat_data::*;
use utils::*;
const HL_CLIENTKEY_TEST: HlClientKeyTest = HlClientKeyTest {
test_filename: Cow::Borrowed("client_key"),
parameters: INSECURE_SMALL_TEST_PARAMS_MS_MEAN_COMPENSATION,
};
const HL_COMPRESSED_SERVERKEY_TEST: HlServerKeyTest = HlServerKeyTest {
test_filename: Cow::Borrowed("compressed_server_key"),
client_key_filename: Cow::Borrowed("client_key.cbor"),
rerand_cpk_filename: None,
compressed: true,
};
const HL_SERVERKEY_WITH_COMPRESSION_TEST: HlServerKeyTest = HlServerKeyTest {
test_filename: Cow::Borrowed("server_key_with_compression"),
client_key_filename: Cow::Borrowed("client_key.cbor"),
rerand_cpk_filename: None,
compressed: false,
};
pub struct V1_5;
impl TfhersVersion for V1_5 {
const VERSION_NUMBER: &'static str = "1.5";
fn seed_prng(seed: u128) {
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
let shortint_engine = ShortintEngine::new_from_seeder(&mut seeder);
ShortintEngine::with_thread_local_mut(|local_engine| {
let _ = std::mem::replace(local_engine, shortint_engine);
});
let boolean_engine = BooleanEngine::new_from_seeder(&mut seeder);
BooleanEngine::replace_thread_local(boolean_engine);
}
fn gen_shortint_data<P: AsRef<Path>>(_base_data_dir: P) -> Vec<TestMetadata> {
Vec::new()
}
fn gen_hl_data<P: AsRef<Path>>(base_data_dir: P) -> Vec<TestMetadata> {
let dir = Self::data_dir(base_data_dir).join(HL_MODULE_NAME);
create_dir_all(&dir).unwrap();
let config =
tfhe::ConfigBuilder::with_custom_parameters(HL_CLIENTKEY_TEST.parameters.convert())
.enable_compression(INSECURE_TEST_PARAMS_TUNIFORM_COMPRESSION_MULTIBIT.convert())
.build();
let (hl_client_key, hl_server_key) = tfhe::generate_keys(config);
let compressed_server_key = CompressedServerKey::new(&hl_client_key);
store_versioned_test(&hl_client_key, &dir, &HL_CLIENTKEY_TEST.test_filename);
store_versioned_test(
&compressed_server_key,
&dir,
&HL_COMPRESSED_SERVERKEY_TEST.test_filename,
);
store_versioned_test(
&hl_server_key,
&dir,
&HL_SERVERKEY_WITH_COMPRESSION_TEST.test_filename,
);
vec![
TestMetadata::HlClientKey(HL_CLIENTKEY_TEST),
TestMetadata::HlServerKey(HL_COMPRESSED_SERVERKEY_TEST),
TestMetadata::HlServerKey(HL_SERVERKEY_WITH_COMPRESSION_TEST),
]
}
}

View File

@@ -0,0 +1,36 @@
use std::fs::remove_dir_all;
use std::path::PathBuf;
use clap::Parser;
use generate_1_5::V1_5;
use tfhe_backward_compat_data::dir_for_version;
use tfhe_backward_compat_data::generate::{
display_metadata, gen_all_data, update_metadata_for_version,
};
#[derive(Parser, Debug)]
struct Args {
/// The path where the backward data should be stored
#[arg(long)]
data_path: PathBuf,
/// Output metadata to stdout instead of writing them to the ron file
#[arg(long, action)]
stdout: bool,
}
fn main() {
let args = Args::parse();
let version_dir = dir_for_version(&args.data_path, "1.5");
// Ignore if directory does not exist
let _ = remove_dir_all(&version_dir);
let data = gen_all_data::<V1_5>(&args.data_path);
if args.stdout {
display_metadata(&data);
} else {
update_metadata_for_version(data, args.data_path);
}
}

View File

@@ -0,0 +1,248 @@
use std::path::Path;
use tfhe::core_crypto::prelude::{
CiphertextModulusLog, LweCiphertextCount, TUniform, UnsignedInteger,
};
use tfhe::shortint::parameters::list_compression::{
ClassicCompressionParameters, MultiBitCompressionParameters,
};
use tfhe::shortint::parameters::*;
use tfhe::shortint::prelude::ModulusSwitchType;
use tfhe::shortint::{MultiBitPBSParameters, PBSParameters};
use tfhe_backward_compat_data::generate::*;
use tfhe_backward_compat_data::*;
use tfhe_versionable::Versionize;
pub(crate) fn store_versioned_test<Data: Versionize + 'static, P: AsRef<Path>>(
msg: &Data,
dir: P,
test_filename: &str,
) {
generic_store_versioned_test(Versionize::versionize, msg, dir, test_filename)
}
#[allow(dead_code)]
pub(crate) fn store_versioned_auxiliary<Data: Versionize + 'static, P: AsRef<Path>>(
msg: &Data,
dir: P,
test_filename: &str,
) {
generic_store_versioned_auxiliary(Versionize::versionize, msg, dir, test_filename)
}
/// This trait allows to convert version independent parameters types defined in
/// `tfhe-backward-compat-data` to the equivalent TFHE-rs parameters for this version.
///
/// This is similar to `Into` but allows to circumvent the orphan rule.
pub(crate) trait ConvertParams<TfheRsParams> {
fn convert(self) -> TfheRsParams;
}
impl<Scalar> ConvertParams<DynamicDistribution<Scalar>> for TestDistribution
where
Scalar: UnsignedInteger,
{
fn convert(self) -> DynamicDistribution<Scalar> {
match self {
TestDistribution::Gaussian { stddev } => {
DynamicDistribution::new_gaussian_from_std_dev(StandardDev(stddev))
}
TestDistribution::TUniform { bound_log2 } => {
DynamicDistribution::TUniform(TUniform::new(bound_log2))
}
}
}
}
impl ConvertParams<ClassicPBSParameters> for TestClassicParameterSet {
fn convert(self) -> ClassicPBSParameters {
let TestClassicParameterSet {
lwe_dimension,
glwe_dimension,
polynomial_size,
lwe_noise_distribution,
glwe_noise_distribution,
pbs_base_log,
pbs_level,
ks_base_log,
ks_level,
message_modulus,
ciphertext_modulus,
carry_modulus,
max_noise_level,
log2_p_fail,
encryption_key_choice,
modulus_switch_noise_reduction_params,
} = self;
ClassicPBSParameters {
lwe_dimension: LweDimension(lwe_dimension),
glwe_dimension: GlweDimension(glwe_dimension),
polynomial_size: PolynomialSize(polynomial_size),
lwe_noise_distribution: lwe_noise_distribution.convert(),
glwe_noise_distribution: glwe_noise_distribution.convert(),
pbs_base_log: DecompositionBaseLog(pbs_base_log),
pbs_level: DecompositionLevelCount(pbs_level),
ks_base_log: DecompositionBaseLog(ks_base_log),
ks_level: DecompositionLevelCount(ks_level),
message_modulus: MessageModulus(message_modulus as u64),
carry_modulus: CarryModulus(carry_modulus as u64),
max_noise_level: MaxNoiseLevel::new(max_noise_level as u64),
log2_p_fail,
ciphertext_modulus: CiphertextModulus::try_new(ciphertext_modulus).unwrap(),
encryption_key_choice: {
match &*encryption_key_choice {
"big" => EncryptionKeyChoice::Big,
"small" => EncryptionKeyChoice::Small,
_ => panic!("Invalid encryption key choice"),
}
},
modulus_switch_noise_reduction_params: modulus_switch_noise_reduction_params.convert(),
}
}
}
impl ConvertParams<ModulusSwitchNoiseReductionParams> for TestModulusSwitchNoiseReductionParams {
fn convert(self) -> ModulusSwitchNoiseReductionParams {
let TestModulusSwitchNoiseReductionParams {
modulus_switch_zeros_count,
ms_bound,
ms_r_sigma_factor,
ms_input_variance,
} = self;
ModulusSwitchNoiseReductionParams {
modulus_switch_zeros_count: LweCiphertextCount(modulus_switch_zeros_count),
ms_bound: NoiseEstimationMeasureBound(ms_bound),
ms_r_sigma_factor: RSigmaFactor(ms_r_sigma_factor),
ms_input_variance: Variance(ms_input_variance),
}
}
}
impl ConvertParams<ModulusSwitchType> for TestModulusSwitchType {
fn convert(self) -> ModulusSwitchType {
match self {
TestModulusSwitchType::Standard => ModulusSwitchType::Standard,
TestModulusSwitchType::DriftTechniqueNoiseReduction(
test_modulus_switch_noise_reduction_params,
) => ModulusSwitchType::DriftTechniqueNoiseReduction(
test_modulus_switch_noise_reduction_params.convert(),
),
TestModulusSwitchType::CenteredMeanNoiseReduction => {
ModulusSwitchType::CenteredMeanNoiseReduction
}
}
}
}
impl ConvertParams<MultiBitPBSParameters> for TestMultiBitParameterSet {
fn convert(self) -> MultiBitPBSParameters {
let TestMultiBitParameterSet {
lwe_dimension,
glwe_dimension,
polynomial_size,
lwe_noise_distribution,
glwe_noise_distribution,
pbs_base_log,
pbs_level,
ks_base_log,
ks_level,
message_modulus,
ciphertext_modulus,
carry_modulus,
max_noise_level,
log2_p_fail,
encryption_key_choice,
grouping_factor,
} = self;
MultiBitPBSParameters {
lwe_dimension: LweDimension(lwe_dimension),
glwe_dimension: GlweDimension(glwe_dimension),
polynomial_size: PolynomialSize(polynomial_size),
lwe_noise_distribution: lwe_noise_distribution.convert(),
glwe_noise_distribution: glwe_noise_distribution.convert(),
pbs_base_log: DecompositionBaseLog(pbs_base_log),
pbs_level: DecompositionLevelCount(pbs_level),
ks_base_log: DecompositionBaseLog(ks_base_log),
ks_level: DecompositionLevelCount(ks_level),
message_modulus: MessageModulus(message_modulus as u64),
carry_modulus: CarryModulus(carry_modulus as u64),
max_noise_level: MaxNoiseLevel::new(max_noise_level as u64),
log2_p_fail,
ciphertext_modulus: CiphertextModulus::try_new(ciphertext_modulus).unwrap(),
encryption_key_choice: {
match &*encryption_key_choice {
"big" => EncryptionKeyChoice::Big,
"small" => EncryptionKeyChoice::Small,
_ => panic!("Invalid encryption key choice"),
}
},
grouping_factor: LweBskGroupingFactor(grouping_factor),
deterministic_execution: false,
}
}
}
impl ConvertParams<PBSParameters> for TestParameterSet {
fn convert(self) -> PBSParameters {
match self {
TestParameterSet::TestClassicParameterSet(test_classic_parameter_set) => {
PBSParameters::PBS(test_classic_parameter_set.convert())
}
TestParameterSet::TestMultiBitParameterSet(test_parameter_set_multi_bit) => {
PBSParameters::MultiBitPBS(test_parameter_set_multi_bit.convert())
}
TestParameterSet::TestKS32ParameterSet(_) => {
panic!("unsupported ks32 parameters for version")
}
}
}
}
impl ConvertParams<CompressionParameters> for TestCompressionParameterSet {
fn convert(self) -> CompressionParameters {
let TestCompressionParameterSet {
br_level,
br_base_log,
packing_ks_level,
packing_ks_base_log,
packing_ks_polynomial_size,
packing_ks_glwe_dimension,
lwe_per_glwe,
storage_log_modulus,
packing_ks_key_noise_distribution,
decompression_grouping_factor,
} = self;
match decompression_grouping_factor {
Some(decompression_grouping_factor) => {
CompressionParameters::MultiBit(MultiBitCompressionParameters {
br_level: DecompositionLevelCount(br_level),
br_base_log: DecompositionBaseLog(br_base_log),
packing_ks_level: DecompositionLevelCount(packing_ks_level),
packing_ks_base_log: DecompositionBaseLog(packing_ks_base_log),
packing_ks_polynomial_size: PolynomialSize(packing_ks_polynomial_size),
packing_ks_glwe_dimension: GlweDimension(packing_ks_glwe_dimension),
lwe_per_glwe: LweCiphertextCount(lwe_per_glwe),
storage_log_modulus: CiphertextModulusLog(storage_log_modulus),
packing_ks_key_noise_distribution: packing_ks_key_noise_distribution.convert(),
decompression_grouping_factor: LweBskGroupingFactor(
decompression_grouping_factor,
),
})
}
None => CompressionParameters::Classic(ClassicCompressionParameters {
br_level: DecompositionLevelCount(br_level),
br_base_log: DecompositionBaseLog(br_base_log),
packing_ks_level: DecompositionLevelCount(packing_ks_level),
packing_ks_base_log: DecompositionBaseLog(packing_ks_base_log),
packing_ks_polynomial_size: PolynomialSize(packing_ks_polynomial_size),
packing_ks_glwe_dimension: GlweDimension(packing_ks_glwe_dimension),
lwe_per_glwe: LweCiphertextCount(lwe_per_glwe),
storage_log_modulus: CiphertextModulusLog(storage_log_modulus),
packing_ks_key_noise_distribution: packing_ks_key_noise_distribution.convert(),
}),
}
}
}

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c9e22c5b9a031c6e045e9e9244a6620e17072fea4aca043716dfa735b0f87f8
size 18914

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff40b444649863c987557d7c791e5fda7c7824eae6bd3b849e0ae488e5959e7f
size 3548

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5759c7618de2fd921519f145fbde201d6db0994faf6de4864061c94e4ea34f6a
size 46269030

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e18bde6098e48b111d6230724ec5fb4807da63f4aeca9aab3b8308df8929173a
size 52053435

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:697dc1f5878a483964b75d9d2eaf5d364510eab84ad1ea46bdaeabc0a74063ea
size 92635750

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2b26204f294efc567147baaf20cf5b1d40caceb696ffd5642926ca04afee8bb
size 108393737

View File

@@ -796,4 +796,53 @@
compressed: false,
)),
),
(
tfhe_version_min: "1.5",
tfhe_module: "high_level_api",
metadata: HlClientKey((
test_filename: "client_key",
parameters: TestClassicParameterSet((
lwe_dimension: 2,
glwe_dimension: 1,
polynomial_size: 2048,
lwe_noise_distribution: TUniform(
bound_log2: 45,
),
glwe_noise_distribution: TUniform(
bound_log2: 17,
),
pbs_base_log: 23,
pbs_level: 1,
ks_base_log: 4,
ks_level: 4,
message_modulus: 4,
ciphertext_modulus: 18446744073709551616,
carry_modulus: 4,
max_noise_level: 5,
log2_p_fail: -129.15284804376165,
encryption_key_choice: "big",
modulus_switch_noise_reduction_params: CenteredMeanNoiseReduction,
)),
)),
),
(
tfhe_version_min: "1.5",
tfhe_module: "high_level_api",
metadata: HlServerKey((
test_filename: "compressed_server_key",
client_key_filename: "client_key.cbor",
rerand_cpk_filename: None,
compressed: true,
)),
),
(
tfhe_version_min: "1.5",
tfhe_module: "high_level_api",
metadata: HlServerKey((
test_filename: "server_key_with_compression",
client_key_filename: "client_key.cbor",
rerand_cpk_filename: None,
compressed: false,
)),
),
]

View File

@@ -263,6 +263,20 @@ pub const VALID_TEST_PARAMS_TUNIFORM_COMPRESSION: TestCompressionParameterSet =
decompression_grouping_factor: None,
};
pub const INSECURE_TEST_PARAMS_TUNIFORM_COMPRESSION_MULTIBIT: TestCompressionParameterSet =
TestCompressionParameterSet {
br_level: 1,
br_base_log: 22,
packing_ks_level: 3,
packing_ks_base_log: 4,
packing_ks_polynomial_size: 256,
packing_ks_glwe_dimension: 1,
lwe_per_glwe: 256,
storage_log_modulus: 12,
packing_ks_key_noise_distribution: TestDistribution::TUniform { bound_log2: 43 },
decompression_grouping_factor: Some(4),
};
/// Invalid parameter set to test the limits
pub const INVALID_TEST_PARAMS: TestClassicParameterSet = TestClassicParameterSet {
lwe_dimension: usize::MAX,