feat(hpu): Adding support for modulus switch mean compensation

Including the pfail 2e-128 parameter set.

Note: The HPU mockup still does not support mean compensation.
This commit is contained in:
Helder Campos
2025-05-30 19:02:51 +01:00
parent fe5542f39e
commit 25362b2db2
13 changed files with 157 additions and 11 deletions

View File

@@ -13,6 +13,7 @@
bpip_use = true bpip_use = true
bpip_use_opportunism = true bpip_use_opportunism = true
bpip_timeout = 100_000 bpip_timeout = 100_000
mod_switch_mean_comp = true
[board] [board]
ct_mem = 32768 ct_mem = 32768

View File

@@ -49,7 +49,8 @@ offset= 0x10
owner="Parameter" owner="Parameter"
read_access="Read" read_access="Read"
write_access="None" write_access="None"
default={Param="VERSION"} field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
[section.info.register.ntt_architecture] [section.info.register.ntt_architecture]
description="NTT architecture" description="NTT architecture"
@@ -254,3 +255,15 @@ description="BPIP configuration"
read_access="Read" read_access="Read"
write_access="Write" write_access="Write"
default={Cst=0xffffffff} default={Cst=0xffffffff}
# =====================================================================================================================
[section.keyswitch]
offset= 0x3000
description="Keyswitch Configuration"
[section.keyswitch.register.config]
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
owner="User"
read_access="Read"
write_access="Write"
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}

View File

@@ -11,6 +11,7 @@
bpip_use = true bpip_use = true
bpip_use_opportunism = true bpip_use_opportunism = true
bpip_timeout = 100_000 bpip_timeout = 100_000
mod_switch_mean_comp = true
[board] [board]
ct_mem = 4096 ct_mem = 4096

View File

@@ -16,6 +16,7 @@
bpip_use = true bpip_use = true
bpip_use_opportunism = true bpip_use_opportunism = true
bpip_timeout = 100_000 bpip_timeout = 100_000
mod_switch_mean_comp = true
[board] [board]
ct_mem = 32768 ct_mem = 32768

View File

@@ -49,7 +49,8 @@ offset= 0x10
owner="Parameter" owner="Parameter"
read_access="Read" read_access="Read"
write_access="None" write_access="None"
default={Param="VERSION"} field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
[section.info.register.ntt_architecture] [section.info.register.ntt_architecture]
description="NTT architecture" description="NTT architecture"
@@ -254,3 +255,15 @@ description="BPIP configuration"
read_access="Read" read_access="Read"
write_access="Write" write_access="Write"
default={Cst=0xffffffff} default={Cst=0xffffffff}
# =====================================================================================================================
[section.keyswitch]
offset= 0x3000
description="Keyswitch Configuration"
[section.keyswitch.register.config]
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
owner="User"
read_access="Read"
write_access="Write"
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:cb9ebedd0987130c4f6e1ef09f279d92f083815c1383da4b257198a33ab4881e oid sha256:4c17c71cedd183daeabf93649b61b294619445b695754606066051b10ecef27c
size 80293531 size 83001412

View File

@@ -165,7 +165,26 @@ impl HpuBackend {
config.rtl.bpip_timeout, config.rtl.bpip_timeout,
); );
let ks_config_reg = regmap
.register()
.get("keyswitch::config")
.expect("Unknown register, check regmap definition");
hpu_hw.write_reg(
*ks_config_reg.offset() as u64,
ks_config_reg.from_field(
[(
"mod_switch_mean_comp",
config.rtl.mod_switch_mean_comp as u32,
)]
.into(),
),
);
info!("{params:?}"); info!("{params:?}");
debug!(
"Keyswitch registers {:?}",
rtl::runtime::InfoKeyswitch::from_rtl(&mut hpu_hw, &regmap)
);
debug!( debug!(
"Isc registers {:?}", "Isc registers {:?}",
rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, &regmap) rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, &regmap)

View File

@@ -77,6 +77,8 @@ pub struct RtlConfig {
pub bpip_use_opportunism: bool, pub bpip_use_opportunism: bool,
/// Timeout value to start Bpip even if batch isn't full /// Timeout value to start Bpip even if batch isn't full
pub bpip_timeout: u32, pub bpip_timeout: u32,
/// Use modulus switch mean compensation
pub mod_switch_mean_comp: bool,
} }
/// On-board memory configuration /// On-board memory configuration

View File

@@ -416,6 +416,21 @@ pub const MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C: HpuPBSParameters = HpuPBSPa
ciphertext_width: 64, ciphertext_width: 64,
}; };
pub const MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47: HpuPBSParameters = HpuPBSParameters {
lwe_dimension: 879,
glwe_dimension: 1,
polynomial_size: 2048,
lwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(3),
glwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(17),
pbs_base_log: 23,
pbs_level: 1,
ks_base_log: 2,
ks_level: 8,
message_width: 2,
carry_width: 2,
ciphertext_width: 64,
};
impl FromRtl for HpuPBSParameters { impl FromRtl for HpuPBSParameters {
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self { fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
let pbs_app = regmap let pbs_app = regmap
@@ -456,6 +471,7 @@ impl FromRtl for HpuPBSParameters {
11 => MSG2_CARRY2_TUNIFORM, 11 => MSG2_CARRY2_TUNIFORM,
12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA, 12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA,
13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C, 13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C,
14 => MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47,
_ => panic!("Unknown TfheAppName encoding"), _ => panic!("Unknown TfheAppName encoding"),
} }
} }

View File

@@ -932,3 +932,33 @@ impl ErrorHpu {
self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64); self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64);
} }
} }
#[derive(Debug, Default)]
pub struct InfoKeyswitch {
/// Use modulus switch mean compensation
mod_switch_mean_comp: bool,
}
impl FromRtl for InfoKeyswitch {
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
// Info structure have method to update
// Instead of redefine parsing here, use a default construct and update methods
let mut infos = Self::default();
infos.update(ffi_hw, regmap);
infos
}
}
impl InfoKeyswitch {
pub fn update_mod_switch_mean_comp(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
let reg = regmap
.register()
.get("keyswitch::config")
.expect("Unknown register, check regmap definition");
self.mod_switch_mean_comp = ffi_hw.read_reg(*reg.offset() as u64) != 0;
}
pub fn update(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
self.update_mod_switch_mean_comp(ffi_hw, regmap);
}
}

View File

@@ -0,0 +1,48 @@
[pbs_params]
lwe_dimension=879
glwe_dimension=1
polynomial_size=2048
lwe_noise_distribution={TUniformBound= 3}
glwe_noise_distribution={TUniformBound= 17}
pbs_base_log= 23
pbs_level= 1
ks_base_log= 2
ks_level= 8
message_width= 2
carry_width= 2
ciphertext_width= 64
opportunistic=true
[ntt_params]
core_arch= {GF64=[5,6]}
min_pbs_nb= 11
batch_pbs_nb= 12
total_pbs_nb= 32
ct_width= 64
radix= 2
stg_nb= 11
prime_modulus= "GF64"
psi= 64
delta= 5
[ks_params]
width= 21
lbx= 3
lby= 64
lbz= 3
[pc_params]
ksk_pc= 16
ksk_bytes_w= 32
bsk_pc= 8
bsk_bytes_w= 32
pem_pc= 2
pem_bytes_w= 32
glwe_bytes_w= 32
[regf_params]
reg_nb= 64
coef_nb= 32
[isc_params]
min_iop_size= 4
depth= 64

View File

@@ -29,8 +29,6 @@ use modules::{DdrMem, HbmBank, RegisterEvent, RegisterMap, UCore, HBM_BANK_NB};
use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem; use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem;
use tfhe::tfhe_hpu_backend::prelude::*; use tfhe::tfhe_hpu_backend::prelude::*;
use serde_json;
pub struct HpuSim { pub struct HpuSim {
config: HpuConfig, config: HpuConfig,
params: MockupParameters, params: MockupParameters,
@@ -178,7 +176,7 @@ impl HpuSim {
} }
RegisterReq::PbsParams => { RegisterReq::PbsParams => {
self.ipc.register_ack(RegisterAck::PbsParams( self.ipc.register_ack(RegisterAck::PbsParams(
self.params.rtl_params.pbs_params.clone(), self.params.rtl_params.pbs_params,
)); ));
} }
} }
@@ -672,7 +670,7 @@ impl HpuSim {
// Compute Lut properties // Compute Lut properties
let (modulus_sup, box_size, fn_stride) = { let (modulus_sup, box_size, fn_stride) = {
let pbs_p = &self.params.rtl_params.pbs_params; let pbs_p = &self.params.rtl_params.pbs_params;
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width; let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
let box_size = pbs_p.polynomial_size / modulus_sup; let box_size = pbs_p.polynomial_size / modulus_sup;
// Max valid degree for a ciphertext when using the LUT we generate // Max valid degree for a ciphertext when using the LUT we generate
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1 // If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
@@ -725,7 +723,7 @@ impl HpuSim {
let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus); let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus);
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, &bsk); blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, bsk);
assert_eq!( assert_eq!(
dst_rid.0, dst_rid.0,
@@ -736,7 +734,7 @@ impl HpuSim {
// Compute ManyLut function stride // Compute ManyLut function stride
let fn_stride = { let fn_stride = {
let pbs_p = &self.params.rtl_params.pbs_params; let pbs_p = &self.params.rtl_params.pbs_params;
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width; let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
let box_size = pbs_p.polynomial_size / modulus_sup; let box_size = pbs_p.polynomial_size / modulus_sup;
// Max valid degree for a ciphertext when using the LUT we generate // Max valid degree for a ciphertext when using the LUT we generate
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1 // If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1

View File

@@ -127,7 +127,7 @@ impl RegisterMap {
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32 (ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
} }
"info::ntt_modulo" => { "info::ntt_modulo" => {
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus.clone() as u8) as u32 MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
} }
"info::application" => { "info::application" => {
@@ -147,6 +147,10 @@ impl RegisterMap {
APPLICATION_NAME_OFS + 11 APPLICATION_NAME_OFS + 11
} else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params { } else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 12 APPLICATION_NAME_OFS + 12
} else if MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 13
} else if MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47 == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 14
} else { } else {
// Custom simulation parameters set // Custom simulation parameters set
// -> Return 1 without NAME_OFS // -> Return 1 without NAME_OFS