chore(gpu): update 4_1_1 params to match specialized pbs

This commit is contained in:
Guillermo Oyarzun
2025-08-26 18:21:04 +02:00
committed by Agnès Leroy
parent 34743ea304
commit a8f391a442
96 changed files with 3610 additions and 648 deletions

View File

@@ -5,8 +5,8 @@ use std::mem::MaybeUninit;
use std::{array, iter};
use tfhe::prelude::*;
use tfhe::shortint::parameters::current_params::{
V1_3_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
};
use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32};
@@ -190,10 +190,10 @@ fn main() -> Result<(), std::io::Error> {
let config = match args.multibit {
None => ConfigBuilder::default(),
Some(2) => ConfigBuilder::with_custom_parameters(
V1_3_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
),
Some(3) => ConfigBuilder::with_custom_parameters(
V1_3_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
),
Some(v) => {
panic!("Invalid multibit setting {v}");

View File

@@ -32,9 +32,9 @@ const KSK_PARAMS: [(
ClassicPBSParameters,
ShortintKeySwitchingParameters,
); 1] = [(
V1_3_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_KEYSWITCH_1_1_KS_PBS_TO_2_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_KEYSWITCH_1_1_KS_PBS_TO_2_2_KS_PBS_GAUSSIAN_2M128,
)];
fn client_server_keys() {
@@ -64,12 +64,12 @@ fn client_server_keys() {
let coverage_only: bool = matches.get_flag("coverage_only");
if multi_bit_only {
const MULTI_BIT_PARAMS: [MultiBitPBSParameters; 6] = [
V1_3_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_3_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
];
generate_pbs_multi_bit_keys(&MULTI_BIT_PARAMS);
@@ -97,7 +97,7 @@ fn client_server_keys() {
#[cfg(feature = "experimental")]
{
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 1] = [(
V1_3_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
)];
generate_wopbs_keys(&WOPBS_PARAMS);
@@ -111,21 +111,21 @@ fn client_server_keys() {
// TUniform
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
// Gaussian
V1_3_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_1_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_1_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_1_CARRY_4_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_1_CARRY_5_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_1_CARRY_6_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_2_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_3_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_3_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_3_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_4_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_5_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_1_CARRY_6_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_2_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_3_CARRY_1_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_3_CARRY_2_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
V1_4_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128,
// 2M64 as backup as 2M128 is too slow
V1_3_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
];
generate_pbs_keys(&PBS_KEYS);
@@ -133,19 +133,19 @@ fn client_server_keys() {
{
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [
(
V1_3_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
LEGACY_WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS,
),
(
V1_3_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
),
(
V1_3_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
LEGACY_WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS,
),
(
V1_3_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
V1_4_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
LEGACY_WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS,
),
];