mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(hpu): Adding support for modulus switch mean compensation
Including the pfail 2e-128 parameter set. Note: The HPU mockup still does not support mean compensation.
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
bpip_use = true
|
bpip_use = true
|
||||||
bpip_use_opportunism = true
|
bpip_use_opportunism = true
|
||||||
bpip_timeout = 100_000
|
bpip_timeout = 100_000
|
||||||
|
mod_switch_mean_comp = true
|
||||||
|
|
||||||
[board]
|
[board]
|
||||||
ct_mem = 32768
|
ct_mem = 32768
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ offset= 0x10
|
|||||||
owner="Parameter"
|
owner="Parameter"
|
||||||
read_access="Read"
|
read_access="Read"
|
||||||
write_access="None"
|
write_access="None"
|
||||||
default={Param="VERSION"}
|
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
|
||||||
|
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
|
||||||
|
|
||||||
[section.info.register.ntt_architecture]
|
[section.info.register.ntt_architecture]
|
||||||
description="NTT architecture"
|
description="NTT architecture"
|
||||||
@@ -254,3 +255,15 @@ description="BPIP configuration"
|
|||||||
read_access="Read"
|
read_access="Read"
|
||||||
write_access="Write"
|
write_access="Write"
|
||||||
default={Cst=0xffffffff}
|
default={Cst=0xffffffff}
|
||||||
|
|
||||||
|
# =====================================================================================================================
|
||||||
|
[section.keyswitch]
|
||||||
|
offset= 0x3000
|
||||||
|
description="Keyswitch Configuration"
|
||||||
|
|
||||||
|
[section.keyswitch.register.config]
|
||||||
|
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
|
||||||
|
owner="User"
|
||||||
|
read_access="Read"
|
||||||
|
write_access="Write"
|
||||||
|
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
bpip_use = true
|
bpip_use = true
|
||||||
bpip_use_opportunism = true
|
bpip_use_opportunism = true
|
||||||
bpip_timeout = 100_000
|
bpip_timeout = 100_000
|
||||||
|
mod_switch_mean_comp = true
|
||||||
|
|
||||||
[board]
|
[board]
|
||||||
ct_mem = 4096
|
ct_mem = 4096
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
bpip_use = true
|
bpip_use = true
|
||||||
bpip_use_opportunism = true
|
bpip_use_opportunism = true
|
||||||
bpip_timeout = 100_000
|
bpip_timeout = 100_000
|
||||||
|
mod_switch_mean_comp = true
|
||||||
|
|
||||||
[board]
|
[board]
|
||||||
ct_mem = 32768
|
ct_mem = 32768
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ offset= 0x10
|
|||||||
owner="Parameter"
|
owner="Parameter"
|
||||||
read_access="Read"
|
read_access="Read"
|
||||||
write_access="None"
|
write_access="None"
|
||||||
default={Param="VERSION"}
|
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
|
||||||
|
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
|
||||||
|
|
||||||
[section.info.register.ntt_architecture]
|
[section.info.register.ntt_architecture]
|
||||||
description="NTT architecture"
|
description="NTT architecture"
|
||||||
@@ -254,3 +255,15 @@ description="BPIP configuration"
|
|||||||
read_access="Read"
|
read_access="Read"
|
||||||
write_access="Write"
|
write_access="Write"
|
||||||
default={Cst=0xffffffff}
|
default={Cst=0xffffffff}
|
||||||
|
|
||||||
|
# =====================================================================================================================
|
||||||
|
[section.keyswitch]
|
||||||
|
offset= 0x3000
|
||||||
|
description="Keyswitch Configuration"
|
||||||
|
|
||||||
|
[section.keyswitch.register.config]
|
||||||
|
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
|
||||||
|
owner="User"
|
||||||
|
read_access="Read"
|
||||||
|
write_access="Write"
|
||||||
|
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:cb9ebedd0987130c4f6e1ef09f279d92f083815c1383da4b257198a33ab4881e
|
oid sha256:4c17c71cedd183daeabf93649b61b294619445b695754606066051b10ecef27c
|
||||||
size 80293531
|
size 83001412
|
||||||
|
|||||||
@@ -165,7 +165,26 @@ impl HpuBackend {
|
|||||||
config.rtl.bpip_timeout,
|
config.rtl.bpip_timeout,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let ks_config_reg = regmap
|
||||||
|
.register()
|
||||||
|
.get("keyswitch::config")
|
||||||
|
.expect("Unknown register, check regmap definition");
|
||||||
|
hpu_hw.write_reg(
|
||||||
|
*ks_config_reg.offset() as u64,
|
||||||
|
ks_config_reg.from_field(
|
||||||
|
[(
|
||||||
|
"mod_switch_mean_comp",
|
||||||
|
config.rtl.mod_switch_mean_comp as u32,
|
||||||
|
)]
|
||||||
|
.into(),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
info!("{params:?}");
|
info!("{params:?}");
|
||||||
|
debug!(
|
||||||
|
"Keyswitch registers {:?}",
|
||||||
|
rtl::runtime::InfoKeyswitch::from_rtl(&mut hpu_hw, ®map)
|
||||||
|
);
|
||||||
debug!(
|
debug!(
|
||||||
"Isc registers {:?}",
|
"Isc registers {:?}",
|
||||||
rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, ®map)
|
rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, ®map)
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ pub struct RtlConfig {
|
|||||||
pub bpip_use_opportunism: bool,
|
pub bpip_use_opportunism: bool,
|
||||||
/// Timeout value to start Bpip even if batch isn't full
|
/// Timeout value to start Bpip even if batch isn't full
|
||||||
pub bpip_timeout: u32,
|
pub bpip_timeout: u32,
|
||||||
|
/// Use modulus switch mean compensation
|
||||||
|
pub mod_switch_mean_comp: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// On-board memory configuration
|
/// On-board memory configuration
|
||||||
|
|||||||
@@ -416,6 +416,21 @@ pub const MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C: HpuPBSParameters = HpuPBSPa
|
|||||||
ciphertext_width: 64,
|
ciphertext_width: 64,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47: HpuPBSParameters = HpuPBSParameters {
|
||||||
|
lwe_dimension: 879,
|
||||||
|
glwe_dimension: 1,
|
||||||
|
polynomial_size: 2048,
|
||||||
|
lwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(3),
|
||||||
|
glwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(17),
|
||||||
|
pbs_base_log: 23,
|
||||||
|
pbs_level: 1,
|
||||||
|
ks_base_log: 2,
|
||||||
|
ks_level: 8,
|
||||||
|
message_width: 2,
|
||||||
|
carry_width: 2,
|
||||||
|
ciphertext_width: 64,
|
||||||
|
};
|
||||||
|
|
||||||
impl FromRtl for HpuPBSParameters {
|
impl FromRtl for HpuPBSParameters {
|
||||||
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
|
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
|
||||||
let pbs_app = regmap
|
let pbs_app = regmap
|
||||||
@@ -456,6 +471,7 @@ impl FromRtl for HpuPBSParameters {
|
|||||||
11 => MSG2_CARRY2_TUNIFORM,
|
11 => MSG2_CARRY2_TUNIFORM,
|
||||||
12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA,
|
12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA,
|
||||||
13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C,
|
13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C,
|
||||||
|
14 => MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47,
|
||||||
_ => panic!("Unknown TfheAppName encoding"),
|
_ => panic!("Unknown TfheAppName encoding"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -932,3 +932,33 @@ impl ErrorHpu {
|
|||||||
self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64);
|
self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct InfoKeyswitch {
|
||||||
|
/// Use modulus switch mean compensation
|
||||||
|
mod_switch_mean_comp: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRtl for InfoKeyswitch {
|
||||||
|
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
|
||||||
|
// Info structure have method to update
|
||||||
|
// Instead of redefine parsing here, use a default construct and update methods
|
||||||
|
let mut infos = Self::default();
|
||||||
|
infos.update(ffi_hw, regmap);
|
||||||
|
infos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InfoKeyswitch {
|
||||||
|
pub fn update_mod_switch_mean_comp(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
|
||||||
|
let reg = regmap
|
||||||
|
.register()
|
||||||
|
.get("keyswitch::config")
|
||||||
|
.expect("Unknown register, check regmap definition");
|
||||||
|
self.mod_switch_mean_comp = ffi_hw.read_reg(*reg.offset() as u64) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
|
||||||
|
self.update_mod_switch_mean_comp(ffi_hw, regmap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
[pbs_params]
|
||||||
|
lwe_dimension=879
|
||||||
|
glwe_dimension=1
|
||||||
|
polynomial_size=2048
|
||||||
|
lwe_noise_distribution={TUniformBound= 3}
|
||||||
|
glwe_noise_distribution={TUniformBound= 17}
|
||||||
|
pbs_base_log= 23
|
||||||
|
pbs_level= 1
|
||||||
|
ks_base_log= 2
|
||||||
|
ks_level= 8
|
||||||
|
message_width= 2
|
||||||
|
carry_width= 2
|
||||||
|
ciphertext_width= 64
|
||||||
|
opportunistic=true
|
||||||
|
|
||||||
|
[ntt_params]
|
||||||
|
core_arch= {GF64=[5,6]}
|
||||||
|
min_pbs_nb= 11
|
||||||
|
batch_pbs_nb= 12
|
||||||
|
total_pbs_nb= 32
|
||||||
|
ct_width= 64
|
||||||
|
radix= 2
|
||||||
|
stg_nb= 11
|
||||||
|
prime_modulus= "GF64"
|
||||||
|
psi= 64
|
||||||
|
delta= 5
|
||||||
|
|
||||||
|
[ks_params]
|
||||||
|
width= 21
|
||||||
|
lbx= 3
|
||||||
|
lby= 64
|
||||||
|
lbz= 3
|
||||||
|
|
||||||
|
[pc_params]
|
||||||
|
ksk_pc= 16
|
||||||
|
ksk_bytes_w= 32
|
||||||
|
bsk_pc= 8
|
||||||
|
bsk_bytes_w= 32
|
||||||
|
pem_pc= 2
|
||||||
|
pem_bytes_w= 32
|
||||||
|
glwe_bytes_w= 32
|
||||||
|
|
||||||
|
[regf_params]
|
||||||
|
reg_nb= 64
|
||||||
|
coef_nb= 32
|
||||||
|
[isc_params]
|
||||||
|
min_iop_size= 4
|
||||||
|
depth= 64
|
||||||
@@ -29,8 +29,6 @@ use modules::{DdrMem, HbmBank, RegisterEvent, RegisterMap, UCore, HBM_BANK_NB};
|
|||||||
use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem;
|
use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem;
|
||||||
use tfhe::tfhe_hpu_backend::prelude::*;
|
use tfhe::tfhe_hpu_backend::prelude::*;
|
||||||
|
|
||||||
use serde_json;
|
|
||||||
|
|
||||||
pub struct HpuSim {
|
pub struct HpuSim {
|
||||||
config: HpuConfig,
|
config: HpuConfig,
|
||||||
params: MockupParameters,
|
params: MockupParameters,
|
||||||
@@ -178,7 +176,7 @@ impl HpuSim {
|
|||||||
}
|
}
|
||||||
RegisterReq::PbsParams => {
|
RegisterReq::PbsParams => {
|
||||||
self.ipc.register_ack(RegisterAck::PbsParams(
|
self.ipc.register_ack(RegisterAck::PbsParams(
|
||||||
self.params.rtl_params.pbs_params.clone(),
|
self.params.rtl_params.pbs_params,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -672,7 +670,7 @@ impl HpuSim {
|
|||||||
// Compute Lut properties
|
// Compute Lut properties
|
||||||
let (modulus_sup, box_size, fn_stride) = {
|
let (modulus_sup, box_size, fn_stride) = {
|
||||||
let pbs_p = &self.params.rtl_params.pbs_params;
|
let pbs_p = &self.params.rtl_params.pbs_params;
|
||||||
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
|
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
|
||||||
let box_size = pbs_p.polynomial_size / modulus_sup;
|
let box_size = pbs_p.polynomial_size / modulus_sup;
|
||||||
// Max valid degree for a ciphertext when using the LUT we generate
|
// Max valid degree for a ciphertext when using the LUT we generate
|
||||||
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
||||||
@@ -725,7 +723,7 @@ impl HpuSim {
|
|||||||
|
|
||||||
let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus);
|
let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus);
|
||||||
|
|
||||||
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, &bsk);
|
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, bsk);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst_rid.0,
|
dst_rid.0,
|
||||||
@@ -736,7 +734,7 @@ impl HpuSim {
|
|||||||
// Compute ManyLut function stride
|
// Compute ManyLut function stride
|
||||||
let fn_stride = {
|
let fn_stride = {
|
||||||
let pbs_p = &self.params.rtl_params.pbs_params;
|
let pbs_p = &self.params.rtl_params.pbs_params;
|
||||||
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
|
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
|
||||||
let box_size = pbs_p.polynomial_size / modulus_sup;
|
let box_size = pbs_p.polynomial_size / modulus_sup;
|
||||||
// Max valid degree for a ciphertext when using the LUT we generate
|
// Max valid degree for a ciphertext when using the LUT we generate
|
||||||
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ impl RegisterMap {
|
|||||||
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
|
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
|
||||||
}
|
}
|
||||||
"info::ntt_modulo" => {
|
"info::ntt_modulo" => {
|
||||||
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus.clone() as u8) as u32
|
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
"info::application" => {
|
"info::application" => {
|
||||||
@@ -147,6 +147,10 @@ impl RegisterMap {
|
|||||||
APPLICATION_NAME_OFS + 11
|
APPLICATION_NAME_OFS + 11
|
||||||
} else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params {
|
} else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params {
|
||||||
APPLICATION_NAME_OFS + 12
|
APPLICATION_NAME_OFS + 12
|
||||||
|
} else if MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C == self.rtl_params.pbs_params {
|
||||||
|
APPLICATION_NAME_OFS + 13
|
||||||
|
} else if MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47 == self.rtl_params.pbs_params {
|
||||||
|
APPLICATION_NAME_OFS + 14
|
||||||
} else {
|
} else {
|
||||||
// Custom simulation parameters set
|
// Custom simulation parameters set
|
||||||
// -> Return 1 without NAME_OFS
|
// -> Return 1 without NAME_OFS
|
||||||
|
|||||||
Reference in New Issue
Block a user