Files
tfhe-rs/tfhe/examples/utilities/shortint_key_sizes.rs
2025-02-18 13:19:28 +01:00

292 lines
9.4 KiB
Rust

#[path = "../../benches/utilities.rs"]
mod utilities;
use crate::utilities::{write_to_json, CryptoParametersRecord, OperatorType};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::current_params::*;
use tfhe::shortint::parameters::*;
use tfhe::shortint::{
ClassicPBSParameters, ClientKey, CompactPrivateKey, CompressedCompactPublicKey,
CompressedKeySwitchingKey, CompressedServerKey, PBSParameters,
};
fn write_result(file: &mut File, name: &str, value: usize) {
let line = format!("{name},{value}\n");
let error_message = format!("cannot write {name} result into file");
file.write_all(line.as_bytes()).expect(&error_message);
}
fn client_server_key_sizes(results_file: &Path) {
let shortint_params_vec: Vec<PBSParameters> = vec![
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(),
V1_0_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128.into(),
V1_0_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128.into(),
V1_0_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128.into(),
V1_0_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128.into(),
V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64.into(),
V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64.into(),
V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64.into(),
V1_0_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64.into(),
V1_0_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64.into(),
V1_0_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64.into(),
];
File::create(results_file).expect("create results file failed");
let mut file = OpenOptions::new()
.append(true)
.open(results_file)
.expect("cannot open results file");
let operator = OperatorType::Atomic;
println!("Generating shortint (ClientKey, ServerKey)");
for (i, params) in shortint_params_vec.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}",
i + 1,
shortint_params_vec.len(),
params.name().to_lowercase()
);
let keys = KEY_CACHE.get_from_param(params);
let cks = keys.client_key();
let sks = keys.server_key();
let ksk_size = sks.key_switching_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_ksk", params.name());
write_result(&mut file, &test_name, ksk_size);
write_to_json::<u64, _>(
&test_name,
params,
params.name(),
"KSK",
&operator,
0,
vec![],
);
println!(
"Element in KSK: {}, size in bytes: {}",
sks.key_switching_key_size_elements(),
ksk_size,
);
let bsk_size = sks.bootstrapping_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_bsk", params.name());
write_result(&mut file, &test_name, bsk_size);
write_to_json::<u64, _>(
&test_name,
params,
params.name(),
"BSK",
&operator,
0,
vec![],
);
println!(
"Element in BSK: {}, size in bytes: {}",
sks.bootstrapping_key_size_elements(),
bsk_size,
);
let sks_compressed = CompressedServerKey::new(cks);
let bsk_compressed_size = sks_compressed
.bootstrapping_key
.bootstrapping_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_bsk_compressed", params.name());
write_result(&mut file, &test_name, bsk_compressed_size);
write_to_json::<u64, _>(
&test_name,
params,
params.name(),
"BSK",
&operator,
0,
vec![],
);
println!(
"Element in BSK compressed: {}, size in bytes: {}",
sks_compressed
.bootstrapping_key
.bootstrapping_key_size_elements(),
bsk_compressed_size,
);
// Clear keys as we go to avoid filling the RAM
KEY_CACHE.clear_in_memory_cache()
}
}
fn measure_serialized_size<T: serde::Serialize, P: Into<CryptoParametersRecord<u64>> + Clone>(
to_serialize: &T,
param: P,
param_name: &str,
test_name_suffix: &str,
display_name: &str,
file: &mut File,
) {
let serialized = bincode::serialize(to_serialize).unwrap();
let size = serialized.len();
let test_name = format!("shortint_key_sizes_{}_{}", param_name, test_name_suffix);
write_result(file, &test_name, size);
write_to_json::<u64, _>(
&test_name,
param.clone(),
param_name,
display_name,
&OperatorType::Atomic,
0,
vec![],
);
println!(
"{} {} -> size: {} bytes",
test_name_suffix, param_name, size,
);
}
fn tuniform_key_set_sizes(results_file: &Path) {
File::create(results_file).expect("create results file failed");
let mut file = OpenOptions::new()
.append(true)
.open(results_file)
.expect("cannot open results file");
println!("Measuring shortint key sizes:");
let param_fhe = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let param_fhe_name = param_fhe.name();
let cks = ClientKey::new(param_fhe);
let compressed_sks = CompressedServerKey::new(&cks);
let sks = compressed_sks.decompress();
measure_serialized_size(
&sks.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk",
"KSK",
&mut file,
);
measure_serialized_size(
&compressed_sks.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk_compressed",
"KSK",
&mut file,
);
measure_serialized_size(
&sks.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk",
"BSK",
&mut file,
);
measure_serialized_size(
&compressed_sks.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk_compressed",
"BSK",
&mut file,
);
let param_pke = PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let param_pke_name = stringify!(PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128);
let compact_private_key = CompactPrivateKey::new(param_pke);
let compressed_pk = CompressedCompactPublicKey::new(&compact_private_key);
let pk = compressed_pk.decompress();
measure_serialized_size(&pk, param_pke, param_pke_name, "cpk", "CPK", &mut file);
measure_serialized_size(
&compressed_pk,
param_pke,
param_pke_name,
"cpk_compressed",
"CPK",
&mut file,
);
let param_compression = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let param_compression_name = stringify!(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128);
let params_tuple = (
param_compression,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
);
let private_compression_key = cks.new_compression_private_key(param_compression);
let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);
measure_serialized_size(
&compression_key,
params_tuple,
param_compression_name,
"compression_key",
"CompressionKey",
&mut file,
);
measure_serialized_size(
&decompression_key,
params_tuple,
param_compression_name,
"decompression_key",
"CompressionKey",
&mut file,
);
let param_casting = PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
let param_casting_name = stringify!(PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128);
let compressed_casting_key = CompressedKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &compressed_sks),
param_casting,
);
let casting_key = compressed_casting_key.decompress();
measure_serialized_size(
&casting_key.into_raw_parts().0,
param_casting,
param_casting_name,
"casting_key",
"CastKey",
&mut file,
);
measure_serialized_size(
&compressed_casting_key.into_raw_parts().0,
param_casting,
param_casting_name,
"casting_key_compressed",
"CastKey",
&mut file,
);
}
fn main() {
let work_dir = std::env::current_dir().unwrap();
println!("work_dir: {}", std::env::current_dir().unwrap().display());
// Change workdir so that the location of the keycache matches the one for tests
let mut new_work_dir = work_dir;
new_work_dir.push("tfhe");
std::env::set_current_dir(new_work_dir).unwrap();
let results_file = Path::new("shortint_key_sizes.csv");
client_server_key_sizes(results_file);
tuniform_key_set_sizes(results_file);
}