mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 14:23:53 -05:00
feat(hpu): Adding support for modulus switch mean compensation
Including the pfail 2e-128 parameter set. Note: The HPU mockup still does not support mean compensation.
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
bpip_use = true
|
||||
bpip_use_opportunism = true
|
||||
bpip_timeout = 100_000
|
||||
mod_switch_mean_comp = true
|
||||
|
||||
[board]
|
||||
ct_mem = 32768
|
||||
|
||||
@@ -49,7 +49,8 @@ offset= 0x10
|
||||
owner="Parameter"
|
||||
read_access="Read"
|
||||
write_access="None"
|
||||
default={Param="VERSION"}
|
||||
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
|
||||
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
|
||||
|
||||
[section.info.register.ntt_architecture]
|
||||
description="NTT architecture"
|
||||
@@ -254,3 +255,15 @@ description="BPIP configuration"
|
||||
read_access="Read"
|
||||
write_access="Write"
|
||||
default={Cst=0xffffffff}
|
||||
|
||||
# =====================================================================================================================
|
||||
[section.keyswitch]
|
||||
offset= 0x3000
|
||||
description="Keyswitch Configuration"
|
||||
|
||||
[section.keyswitch.register.config]
|
||||
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
|
||||
owner="User"
|
||||
read_access="Read"
|
||||
write_access="Write"
|
||||
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
bpip_use = true
|
||||
bpip_use_opportunism = true
|
||||
bpip_timeout = 100_000
|
||||
mod_switch_mean_comp = true
|
||||
|
||||
[board]
|
||||
ct_mem = 4096
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
bpip_use = true
|
||||
bpip_use_opportunism = true
|
||||
bpip_timeout = 100_000
|
||||
mod_switch_mean_comp = true
|
||||
|
||||
[board]
|
||||
ct_mem = 32768
|
||||
|
||||
@@ -49,7 +49,8 @@ offset= 0x10
|
||||
owner="Parameter"
|
||||
read_access="Read"
|
||||
write_access="None"
|
||||
default={Param="VERSION"}
|
||||
field.major={size_b=4, default={Param="VERSION_MAJOR"}, description="RTL major version"}
|
||||
field.minor={size_b=4, default={Param="VERSION_MINOR"}, description="RTL minor version"}
|
||||
|
||||
[section.info.register.ntt_architecture]
|
||||
description="NTT architecture"
|
||||
@@ -254,3 +255,15 @@ description="BPIP configuration"
|
||||
read_access="Read"
|
||||
write_access="Write"
|
||||
default={Cst=0xffffffff}
|
||||
|
||||
# =====================================================================================================================
|
||||
[section.keyswitch]
|
||||
offset= 0x3000
|
||||
description="Keyswitch Configuration"
|
||||
|
||||
[section.keyswitch.register.config]
|
||||
description="(1) Use use modulus switching mean compensation. (default), (0) Don't use modulus switching mean compensation."
|
||||
owner="User"
|
||||
read_access="Read"
|
||||
write_access="Write"
|
||||
field.mod_switch_mean_comp = { size_b=1, offset_b=0 , default={Cst=1}, description="Controls whether to use modulus switch mean compensation, aka. Mayeul's Trick."}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cb9ebedd0987130c4f6e1ef09f279d92f083815c1383da4b257198a33ab4881e
|
||||
size 80293531
|
||||
oid sha256:4c17c71cedd183daeabf93649b61b294619445b695754606066051b10ecef27c
|
||||
size 83001412
|
||||
|
||||
@@ -165,7 +165,26 @@ impl HpuBackend {
|
||||
config.rtl.bpip_timeout,
|
||||
);
|
||||
|
||||
let ks_config_reg = regmap
|
||||
.register()
|
||||
.get("keyswitch::config")
|
||||
.expect("Unknown register, check regmap definition");
|
||||
hpu_hw.write_reg(
|
||||
*ks_config_reg.offset() as u64,
|
||||
ks_config_reg.from_field(
|
||||
[(
|
||||
"mod_switch_mean_comp",
|
||||
config.rtl.mod_switch_mean_comp as u32,
|
||||
)]
|
||||
.into(),
|
||||
),
|
||||
);
|
||||
|
||||
info!("{params:?}");
|
||||
debug!(
|
||||
"Keyswitch registers {:?}",
|
||||
rtl::runtime::InfoKeyswitch::from_rtl(&mut hpu_hw, ®map)
|
||||
);
|
||||
debug!(
|
||||
"Isc registers {:?}",
|
||||
rtl::runtime::InfoIsc::from_rtl(&mut hpu_hw, ®map)
|
||||
|
||||
@@ -77,6 +77,8 @@ pub struct RtlConfig {
|
||||
pub bpip_use_opportunism: bool,
|
||||
/// Timeout value to start Bpip even if batch isn't full
|
||||
pub bpip_timeout: u32,
|
||||
/// Use modulus switch mean compensation
|
||||
pub mod_switch_mean_comp: bool,
|
||||
}
|
||||
|
||||
/// On-board memory configuration
|
||||
|
||||
@@ -416,6 +416,21 @@ pub const MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C: HpuPBSParameters = HpuPBSPa
|
||||
ciphertext_width: 64,
|
||||
};
|
||||
|
||||
pub const MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47: HpuPBSParameters = HpuPBSParameters {
|
||||
lwe_dimension: 879,
|
||||
glwe_dimension: 1,
|
||||
polynomial_size: 2048,
|
||||
lwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(3),
|
||||
glwe_noise_distribution: HpuNoiseDistributionInput::TUniformBound(17),
|
||||
pbs_base_log: 23,
|
||||
pbs_level: 1,
|
||||
ks_base_log: 2,
|
||||
ks_level: 8,
|
||||
message_width: 2,
|
||||
carry_width: 2,
|
||||
ciphertext_width: 64,
|
||||
};
|
||||
|
||||
impl FromRtl for HpuPBSParameters {
|
||||
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
|
||||
let pbs_app = regmap
|
||||
@@ -456,6 +471,7 @@ impl FromRtl for HpuPBSParameters {
|
||||
11 => MSG2_CARRY2_TUNIFORM,
|
||||
12 => MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA,
|
||||
13 => MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C,
|
||||
14 => MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47,
|
||||
_ => panic!("Unknown TfheAppName encoding"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -932,3 +932,33 @@ impl ErrorHpu {
|
||||
self.error_3in3 = ffi_hw.read_reg(*reg.offset() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct InfoKeyswitch {
|
||||
/// Use modulus switch mean compensation
|
||||
mod_switch_mean_comp: bool,
|
||||
}
|
||||
|
||||
impl FromRtl for InfoKeyswitch {
|
||||
fn from_rtl(ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) -> Self {
|
||||
// Info structure have method to update
|
||||
// Instead of redefine parsing here, use a default construct and update methods
|
||||
let mut infos = Self::default();
|
||||
infos.update(ffi_hw, regmap);
|
||||
infos
|
||||
}
|
||||
}
|
||||
|
||||
impl InfoKeyswitch {
|
||||
pub fn update_mod_switch_mean_comp(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
|
||||
let reg = regmap
|
||||
.register()
|
||||
.get("keyswitch::config")
|
||||
.expect("Unknown register, check regmap definition");
|
||||
self.mod_switch_mean_comp = ffi_hw.read_reg(*reg.offset() as u64) != 0;
|
||||
}
|
||||
|
||||
pub fn update(&mut self, ffi_hw: &mut ffi::HpuHw, regmap: &FlatRegmap) {
|
||||
self.update_mod_switch_mean_comp(ffi_hw, regmap);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
[pbs_params]
|
||||
lwe_dimension=879
|
||||
glwe_dimension=1
|
||||
polynomial_size=2048
|
||||
lwe_noise_distribution={TUniformBound= 3}
|
||||
glwe_noise_distribution={TUniformBound= 17}
|
||||
pbs_base_log= 23
|
||||
pbs_level= 1
|
||||
ks_base_log= 2
|
||||
ks_level= 8
|
||||
message_width= 2
|
||||
carry_width= 2
|
||||
ciphertext_width= 64
|
||||
opportunistic=true
|
||||
|
||||
[ntt_params]
|
||||
core_arch= {GF64=[5,6]}
|
||||
min_pbs_nb= 11
|
||||
batch_pbs_nb= 12
|
||||
total_pbs_nb= 32
|
||||
ct_width= 64
|
||||
radix= 2
|
||||
stg_nb= 11
|
||||
prime_modulus= "GF64"
|
||||
psi= 64
|
||||
delta= 5
|
||||
|
||||
[ks_params]
|
||||
width= 21
|
||||
lbx= 3
|
||||
lby= 64
|
||||
lbz= 3
|
||||
|
||||
[pc_params]
|
||||
ksk_pc= 16
|
||||
ksk_bytes_w= 32
|
||||
bsk_pc= 8
|
||||
bsk_bytes_w= 32
|
||||
pem_pc= 2
|
||||
pem_bytes_w= 32
|
||||
glwe_bytes_w= 32
|
||||
|
||||
[regf_params]
|
||||
reg_nb= 64
|
||||
coef_nb= 32
|
||||
[isc_params]
|
||||
min_iop_size= 4
|
||||
depth= 64
|
||||
@@ -29,8 +29,6 @@ use modules::{DdrMem, HbmBank, RegisterEvent, RegisterMap, UCore, HBM_BANK_NB};
|
||||
use tfhe::tfhe_hpu_backend::interface::io_dump::HexMem;
|
||||
use tfhe::tfhe_hpu_backend::prelude::*;
|
||||
|
||||
use serde_json;
|
||||
|
||||
pub struct HpuSim {
|
||||
config: HpuConfig,
|
||||
params: MockupParameters,
|
||||
@@ -178,7 +176,7 @@ impl HpuSim {
|
||||
}
|
||||
RegisterReq::PbsParams => {
|
||||
self.ipc.register_ack(RegisterAck::PbsParams(
|
||||
self.params.rtl_params.pbs_params.clone(),
|
||||
self.params.rtl_params.pbs_params,
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -672,7 +670,7 @@ impl HpuSim {
|
||||
// Compute Lut properties
|
||||
let (modulus_sup, box_size, fn_stride) = {
|
||||
let pbs_p = &self.params.rtl_params.pbs_params;
|
||||
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
|
||||
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
|
||||
let box_size = pbs_p.polynomial_size / modulus_sup;
|
||||
// Max valid degree for a ciphertext when using the LUT we generate
|
||||
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
||||
@@ -725,7 +723,7 @@ impl HpuSim {
|
||||
|
||||
let bfr_after_ms = lwe_ciphertext_modulus_switch(bfr_after_ks.as_view(), log_modulus);
|
||||
|
||||
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, &bsk);
|
||||
blind_rotate_ntt64_bnf_assign(&bfr_after_ms, &mut tfhe_lut, bsk);
|
||||
|
||||
assert_eq!(
|
||||
dst_rid.0,
|
||||
@@ -736,7 +734,7 @@ impl HpuSim {
|
||||
// Compute ManyLut function stride
|
||||
let fn_stride = {
|
||||
let pbs_p = &self.params.rtl_params.pbs_params;
|
||||
let modulus_sup = 1_usize << pbs_p.message_width + pbs_p.carry_width;
|
||||
let modulus_sup = 1_usize << (pbs_p.message_width + pbs_p.carry_width);
|
||||
let box_size = pbs_p.polynomial_size / modulus_sup;
|
||||
// Max valid degree for a ciphertext when using the LUT we generate
|
||||
// If MaxDegree == 1, we can have two input values 0 and 1, so we need MaxDegree + 1
|
||||
|
||||
@@ -127,7 +127,7 @@ impl RegisterMap {
|
||||
(ntt_p.batch_pbs_nb + (ntt_p.total_pbs_nb << 8)) as u32
|
||||
}
|
||||
"info::ntt_modulo" => {
|
||||
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus.clone() as u8) as u32
|
||||
MOD_NTT_NAME_OFS + (self.rtl_params.ntt_params.prime_modulus as u8) as u32
|
||||
}
|
||||
|
||||
"info::application" => {
|
||||
@@ -147,6 +147,10 @@ impl RegisterMap {
|
||||
APPLICATION_NAME_OFS + 11
|
||||
} else if MSG2_CARRY2_PFAIL64_132B_GAUSSIAN_1F72DBA == self.rtl_params.pbs_params {
|
||||
APPLICATION_NAME_OFS + 12
|
||||
} else if MSG2_CARRY2_PFAIL64_132B_TUNIFORM_7E47D8C == self.rtl_params.pbs_params {
|
||||
APPLICATION_NAME_OFS + 13
|
||||
} else if MSG2_CARRY2_PFAIL128_132B_TUNIFORM_144A47 == self.rtl_params.pbs_params {
|
||||
APPLICATION_NAME_OFS + 14
|
||||
} else {
|
||||
// Custom simulation parameters set
|
||||
// -> Return 1 without NAME_OFS
|
||||
|
||||
Reference in New Issue
Block a user