feat(hpu): Adding support for modulus switch mean compensation

Including the pfail 2e-128 parameter set.

Note: The HPU mockup still does not support mean compensation.
This commit is contained in:
Helder Campos
2025-05-30 19:02:51 +01:00
parent fe5542f39e
commit 25362b2db2
13 changed files with 157 additions and 11 deletions

View File

@@ -13,6 +13,7 @@
bpip_use = true
bpip_use_opportunism = true
bpip_timeout = 100_000
mod_switch_mean_comp = true
[board]
ct_mem = 32768

View File

@@ -49,7 +49,8 @@ offset= 0x10
owner="Parameter"
read_access="Read"
write_access="None"
default={Param="VERSION"}
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
[section.info.register.ntt_architecture]
description="NTT architecture"
@@ -254,3 +255,15 @@ description="BPIP configuration"
read_access="Read"
write_access="Write"
default={Cst=0xffffffff}
# =====================================================================================================================
[section.keyswitch]
offset= 0x3000
description="Keyswitch Configuration"
[section.keyswitch.register.config]
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
owner="User"
read_access="Read"
write_access="Write"
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}

View File

@@ -11,6 +11,7 @@
bpip_use = true
bpip_use_opportunism = true
bpip_timeout = 100_000
mod_switch_mean_comp = true
[board]
ct_mem = 4096

View File

@@ -16,6 +16,7 @@
bpip_use = true
bpip_use_opportunism = true
bpip_timeout = 100_000
mod_switch_mean_comp = true
[board]
ct_mem = 32768

View File

@@ -49,7 +49,8 @@ offset= 0x10
owner="Parameter"
read_access="Read"
write_access="None"
default={Param="VERSION"}
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
[section.info.register.ntt_architecture]
description="NTT architecture"
@@ -254,3 +255,15 @@ description="BPIP configuration"
read_access="Read"
write_access="Write"
default={Cst=0xffffffff}
# =====================================================================================================================
[section.keyswitch]
offset= 0x3000
description="Keyswitch Configuration"
[section.keyswitch.register.config]
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
owner="User"
read_access="Read"
write_access="Write"
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb9ebedd0987130c4f6e1ef09f279d92f083815c1383da4b257198a33ab4881e
size 80293531
oid sha256:4c17c71cedd183daeabf93649b61b294619445b695754606066051b10ecef27c
size 83001412

View File

@@ -165,7 +165,26 @@ impl HpuBackend {
config.rtl.bpip_timeout,
);
let ks_config_reg = regmap
.register()
.get("keyswitch::config")
.expect("Unknown register, check regmap definition");
hpu_hw.write_reg(
*ks_config_reg.offset() as u64,
ks_config_reg.from_field(
[(
"mod_switch_mean_comp",
config.rtl.mod_switch_mean_comp as u32,
)]
.into(),
),
);
info!("{params:?}");
debug!(
"Keyswitch registers {:?}",
rtl::runtime::InfoKeyswitch::from_rtl(&mut hpu_hw, &regmap)
);
debug!(
"Isc registers {:?}",
rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, &regmap)

View File

@@ -77,6 +77,8 @@ pub struct RtlConfig {
pub bpip_use_opportunism: bool,
/// Timeout value to start Bpip even if batch isn't full
pub bpip_timeout: u32,
/// Use modulus switch mean compensation
pub mod_switch_mean_comp: bool,
}
/// On-board memory configuration

View File

@@ -416,6 +416,21 @@ pub const MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C: HpuPBSParameters = HpuPBSPa
ciphertext_width: 64,
};
pub const MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47: HpuPBSParameters = HpuPBSParameters {
lwe_dimension: 879,
glwe_dimension: 1,
polynomial_size: 2048,
lwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(3),
glwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(17),
pbs_base_log: 23,
pbs_level: 1,
ks_base_log: 2,
ks_level: 8,
message_width: 2,
carry_width: 2,
ciphertext_width: 64,
};
impl FromRtl for HpuPBSParameters {
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
let pbs_app = regmap
@@ -456,6 +471,7 @@ impl FromRtl for HpuPBSParameters {
11 => MSG2_CARRY2_TUNIFORM,
12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA,
13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C,
14 => MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47,
_ => panic!("Unknown TfheAppName encoding"),
}
}

View File

@@ -932,3 +932,33 @@ impl ErrorHpu {
self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64);
}
}
#[derive(Debug, Default)]
pub struct InfoKeyswitch {
/// Use modulus switch mean compensation
mod_switch_mean_comp: bool,
}
impl FromRtl for InfoKeyswitch {
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
// Info structure have method to update
// Instead of redefine parsing here, use a default construct and update methods
let mut infos = Self::default();
infos.update(ffi_hw, regmap);
infos
}
}
impl InfoKeyswitch {
pub fn update_mod_switch_mean_comp(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
let reg = regmap
.register()
.get("keyswitch::config")
.expect("Unknown register, check regmap definition");
self.mod_switch_mean_comp = ffi_hw.read_reg(*reg.offset() as u64) != 0;
}
pub fn update(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
self.update_mod_switch_mean_comp(ffi_hw, regmap);
}
}

View File

@@ -0,0 +1,48 @@
[pbs_params]
lwe_dimension=879
glwe_dimension=1
polynomial_size=2048
lwe_noise_distribution={TUniformBound= 3}
glwe_noise_distribution={TUniformBound= 17}
pbs_base_log= 23
pbs_level= 1
ks_base_log= 2
ks_level= 8
message_width= 2
carry_width= 2
ciphertext_width= 64
opportunistic=true
[ntt_params]
core_arch= {GF64=[5,6]}
min_pbs_nb= 11
batch_pbs_nb= 12
total_pbs_nb= 32
ct_width= 64
radix= 2
stg_nb= 11
prime_modulus= "GF64"
psi= 64
delta= 5
[ks_params]
width= 21
lbx= 3
lby= 64
lbz= 3
[pc_params]
ksk_pc= 16
ksk_bytes_w= 32
bsk_pc= 8
bsk_bytes_w= 32
pem_pc= 2
pem_bytes_w= 32
glwe_bytes_w= 32
[regf_params]
reg_nb= 64
coef_nb= 32
[isc_params]
min_iop_size= 4
depth= 64

View File

@@ -29,8 +29,6 @@ use modules::{DdrMem, HbmBank, RegisterEvent, RegisterMap, UCore, HBM_BANK_NB};
use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem;
use tfhe::tfhe_hpu_backend::prelude::*;
use serde_json;
pub struct HpuSim {
config: HpuConfig,
params: MockupParameters,
@@ -178,7 +176,7 @@ impl HpuSim {
}
RegisterReq::PbsParams => {
self.ipc.register_ack(RegisterAck::PbsParams(
self.params.rtl_params.pbs_params.clone(),
self.params.rtl_params.pbs_params,
));
}
}
@@ -672,7 +670,7 @@ impl HpuSim {
// Compute Lut properties
let (modulus_sup, box_size, fn_stride) = {
let pbs_p = &self.params.rtl_params.pbs_params;
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
let box_size = pbs_p.polynomial_size / modulus_sup;
// Max valid degree for a ciphertext when using the LUT we generate
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
@@ -725,7 +723,7 @@ impl HpuSim {
let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus);
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, &bsk);
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, bsk);
assert_eq!(
dst_rid.0,
@@ -736,7 +734,7 @@ impl HpuSim {
// Compute ManyLut function stride
let fn_stride = {
let pbs_p = &self.params.rtl_params.pbs_params;
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
let box_size = pbs_p.polynomial_size / modulus_sup;
// Max valid degree for a ciphertext when using the LUT we generate
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1

View File

@@ -127,7 +127,7 @@ impl RegisterMap {
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
}
"info::ntt_modulo" => {
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus.clone() as u8) as u32
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
}
"info::application" => {
@@ -147,6 +147,10 @@ impl RegisterMap {
APPLICATION_NAME_OFS + 11
} else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 12
} else if MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 13
} else if MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47 == self.rtl_params.pbs_params {
APPLICATION_NAME_OFS + 14
} else {
// Custom simulation parameters set
// -> Return 1 without NAME_OFS