Files
tfhe-rs/tfhe/tests/hpu.rs

985 lines
31 KiB
Rust

#![allow(clippy::unnecessary_cast)]
//! Define a test harness that handles the setup and configuration of the Hpu backend.
//! The test harness takes a list of testcases and runs them.
//! A testcase simply binds an IOp to a closure describing its behavior.
//! WARN: Only one Hpu can be used at a time, thus all tests must be run sequentially.
#[cfg(feature = "hpu")]
mod hpu_test {
use std::str::FromStr;
use rand::rngs::StdRng;
use rand::{Rng, RngCore, SeedableRng};
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
use tfhe::Seed;
pub use tfhe_hpu_backend::prelude::*;
/// Variable to store initialized HpuDevice and associated client key for fast iteration
///
/// The value is created once, on first access, and then shared by every test of this module.
/// The tuple holds, in order:
/// - the `HpuDevice`, wrapped in a `Mutex` that each test locks while it uses the device,
/// - the client key generated together with the server key loaded on the device,
/// - the seed that was used for the key generation (see `HPU_KEY_SEED`).
static HPU_DEVICE_RNG_CKS: std::sync::OnceLock<(
std::sync::Mutex<HpuDevice>,
tfhe::integer::ClientKey,
u128,
)> = std::sync::OnceLock::new();
// // Instantiate a shared rng for cleartext input generation
// let rng: StdRng = SeedableRng::seed_from_u64((seed & u64::MAX as u128) as u64);
/// Read a 128b seed from the environment variable `name`, or draw a random one.
///
/// The variable may be written in hexadecimal (`0x`/`0X` prefix), binary (`0b`/`0B`),
/// octal (`0o`/`0O`) or, without prefix, in decimal.
///
/// # Panics
/// Panics when the variable is readable but its value isn't a valid `u128` in the selected base.
fn get_or_init_seed(name: &str) -> u128 {
    // Recognized (lowercase prefix, uppercase prefix, radix) triplets, probed in this order
    const RADIX_PREFIXES: [(&str, &str, u32); 3] =
        [("0x", "0X", 16), ("0b", "0B", 2), ("0o", "0O", 8)];

    let Ok(raw) = std::env::var(name) else {
        // No usable value in the environment: build the seed from two 64b draws of thread_rng
        let low = u128::from(rand::thread_rng().next_u64());
        let high = u128::from(rand::thread_rng().next_u64());
        return (high << u64::BITS) | low;
    };

    let parsed = RADIX_PREFIXES
        .iter()
        .find_map(|&(lower, upper, radix)| {
            raw.strip_prefix(lower)
                .or_else(|| raw.strip_prefix(upper))
                .map(|digits| u128::from_str_radix(digits, radix))
        })
        // default: base 10
        .unwrap_or_else(|| raw.parse::<u128>());

    parsed.unwrap_or_else(|_| panic!("{name} env variable {raw} couldn't be casted in u128"))
}
/// Build the material shared by all tests: the Hpu device, the client key and the key seed.
///
/// Steps, in order:
/// 1. with feature `hpu-debug`, enable the Hpu io dump when `HPU_IO_DUMP` is set,
/// 2. instantiate the `HpuDevice` from
///    `${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml`,
/// 3. seed the shortint engine with `HPU_KEY_SEED` (random when unset) so that key generation
///    can be reproduced,
/// 4. generate the client key and the compressed server key from the device parameters,
/// 5. load the server key on the device.
///
/// Returns the device wrapped in a `Mutex`, the client key and the seed used for the keys.
///
/// # Panics
/// Panics with "Invalid key" when the device initialization reports an error, and when
/// `HPU_KEY_SEED` is set to a value that isn't a valid `u128`.
fn init_hpu_and_associated_material(
) -> (std::sync::Mutex<HpuDevice>, tfhe::integer::ClientKey, u128) {
    // Hpu io dump for debug -------------------------------------------------
    #[cfg(feature = "hpu-debug")]
    if let Ok(dump_path) = std::env::var("HPU_IO_DUMP") {
        set_hpu_io_dump(&dump_path);
    }

    // Instantiate HpuDevice --------------------------------------------------
    let hpu_device = {
        let config_file = ShellString::new(
            "${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml".to_string(),
        );
        HpuDevice::from_config(&config_file.expand())
    };

    // Check if the user forces a seed for the key generation
    let key_seed = get_or_init_seed("HPU_KEY_SEED");

    // Force the key seeder to easily reproduce a failure:
    // the engine of the current thread is swapped for one built from the deterministic seeder
    let mut key_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(key_seed));
    let shortint_engine =
        tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut key_seeder);
    tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
        std::mem::replace(engine, shortint_engine)
    });

    // Extract pbs_configuration from Hpu and create Client/Server Key
    let cks = tfhe::integer::ClientKey::new(
        tfhe::shortint::parameters::KeySwitch32PBSParameters::from(hpu_device.params()),
    );
    let sks_compressed =
        tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks);

    // Init Hpu device with server key and firmware
    tfhe::integer::hpu::init_device(&hpu_device, sks_compressed).expect("Invalid key");

    (std::sync::Mutex::new(hpu_device), cks, key_seed)
}
/// Number of iterations run by each testcase when the `HPU_TEST_ITER` environment variable
/// isn't set.
const DEFAULT_TEST_ITER: usize = 32;
/// Generate a `#[test]` function that runs a list of testcases on the shared Hpu device.
///
/// `hpu_testbundle!("name"::W => ["tc0", "tc1"])` expands to a test named `hpu_test_name_uW`
/// that calls, in order, the testcase functions `hpu_tc0_uW`, `hpu_tc1_uW`, ...
///
/// Environment variables read by the generated test:
/// - `HPU_TEST_ITER`: number of iterations of each testcase (default `DEFAULT_TEST_ITER`),
/// - `HPU_TEST_SEED`: seed of the cleartext generator and of the shortint engine
///   (random when unset; the value in use is printed in a form that can be pasted in a shell).
///
/// The test panics when the Hpu configuration doesn't support `W` bits integers, or when at
/// least one testcase reports a failure. In both cases the device lock is released before
/// panicking, so that the mutex isn't poisoned for the bundles that run afterwards.
macro_rules! hpu_testbundle {
    ($base_name: literal::$integer_width:tt => [$($testcase: literal),+]) => {
        ::paste::paste! {
            #[test]
            pub fn [<hpu_test_ $base_name:lower _u $integer_width>]() {
                // Register tracing subscriber that use env-filter
                // Discard error ( mainly due to already registered subscriber)
                let _ = tracing_subscriber::fmt()
                    .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
                    .compact()
                    .with_file(false)
                    .with_line_number(false)
                    .without_time()
                    .try_init();

                // Retrieve test iteration from environment ------------------------------
                let hpu_test_iter = match std::env::var("HPU_TEST_ITER") {
                    Ok(var) => usize::from_str(&var).unwrap_or_else(|_| {
                        panic!("HPU_TEST_ITER env variable {var} couldn't be casted in usize")
                    }),
                    _ => DEFAULT_TEST_ITER,
                };

                // Retrieve HpuDevice or init --------------------------------------------
                let (hpu_mutex, cks, key_seed) =
                    HPU_DEVICE_RNG_CKS.get_or_init(init_hpu_and_associated_material);
                let mut hpu_device = hpu_mutex.lock().expect("Error with HpuDevice Mutex");

                // Refuse integer width that the loaded firmware doesn't handle.
                // The guard is dropped before panicking: panicking while holding it would
                // poison the mutex and make every following bundle fail on `lock()`.
                if !hpu_device.config().firmware.integer_w.contains(&($integer_width as usize)) {
                    let unsupported = format!(
                        "Current Hpu configuration doesn't support {}b integer [has {:?}]",
                        $integer_width,
                        hpu_device.config().firmware.integer_w
                    );
                    drop(hpu_device);
                    panic!("{}", unsupported);
                }

                // Instantiate a Rng for cleartext input generation
                // Create a fresh one for each testbundle to be reproducible even if the
                // execution order of the testbundles isn't stable
                let test_seed = get_or_init_seed("HPU_TEST_SEED");

                // Display used seed value in a reusable manner (i.e. valid bash syntax)
                println!("HPU_KEY_SEED={key_seed} #[i.e. 0x{key_seed:x}]");
                println!("HPU_TEST_SEED={test_seed} #[i.e. 0x{test_seed:x}]");

                // Only the 64 lower bits of the seed feed the cleartext generator
                let mut rng: StdRng = SeedableRng::seed_from_u64((test_seed & u64::MAX as u128) as u64);

                // Reseed shortint engine for reproducible noise generation.
                let mut noise_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(test_seed));
                let shortint_engine =
                    tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut noise_seeder);
                tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
                    std::mem::replace(engine, shortint_engine)
                });

                // Run test-case ---------------------------------------------------------
                // Every testcase is executed, even after a failure, to report all of them
                let mut acc_status = true;
                $(
                    {
                        let status = [<hpu_ $testcase _u $integer_width>](hpu_test_iter, &mut hpu_device, &mut rng, &cks);
                        if !status {
                            println!("Error: in testcase {}", stringify!([<hpu_ $testcase _u $integer_width>]));
                        }
                        acc_status &= status
                    }
                )*

                // Release the device before the final assert to keep the mutex usable
                drop(hpu_device);
                assert!(acc_status, "At least one testcase failed in the testbundle");
            }
        }
    };
}
/// Generate, for one IOp, a testcase function per listed integer type.
///
/// `hpu_testcase!("NAME" => [u8, u16] |ct, imm| expr)` expands to the functions `hpu_name_u8`
/// and `hpu_name_u16`, each with the signature `(iter, device, rng, cks) -> bool`.
///
/// `expr` is the cleartext model of the IOp. It is evaluated with `ct` bound to the cleartext
/// sources and `imm` bound to the immediates (both converted to the integer type), and must
/// yield a collection exposing `.iter()` over values that can be cast to the integer type.
///
/// Each generated function runs `iter` iterations. An iteration draws random sources and
/// immediates, executes the IOp on the Hpu, decrypts the results and compares them with the
/// model. The function returns `true` only when every iteration matched; all the iterations
/// are executed even after a mismatch.
///
/// Setting `HPU_TEST_TRIVIAL` to a non-zero value makes the sources trivial ciphertexts
/// instead of real encryptions.
macro_rules! hpu_testcase {
    ($iop: literal => [$($user_type: ty),+] |$ct:ident, $imm: ident| $behav: expr) => {
        ::paste::paste! {
            $(
                #[cfg(feature = "hpu")]
                #[allow(unused)]
                pub fn [<hpu_ $iop:lower _ $user_type>](iter: usize, device: &mut HpuDevice, rng: &mut StdRng, cks: &tfhe::integer::ClientKey) -> bool {
                    use tfhe::integer::hpu::ciphertext::HpuRadixCiphertext;

                    // Check if user ask for test over trivial ciphertext
                    let test_trivial = match std::env::var("HPU_TEST_TRIVIAL") {
                        Ok(var) => {
                            let flag_val = usize::from_str(&var).unwrap_or_else(|_| {
                                panic!("HPU_TEST_TRIVIAL env variable {var} couldn't be casted in usize")
                            });
                            flag_val != 0
                        }
                        _ => false,
                    };
                    // The server key only serves to build trivial ciphertexts:
                    // don't pay for its generation when real encryptions are used
                    let sks = if test_trivial {
                        Some(tfhe::integer::ServerKey::new_radix_server_key(&cks))
                    } else {
                        None
                    };

                    let iop = hpu_asm::AsmIOpcode::from_str($iop).expect("Invalid AsmIOpcode ");
                    let proto = if let Some(format) = iop.format() {
                        format.proto.clone()
                    } else {
                        eprintln!("Hpu testcase only work on specified operations. Check test definition");
                        return false;
                    };

                    let width = $user_type::BITS as usize;
                    let num_block = width / device.params().pbs_params.message_width;

                    (0..iter).map(|_| {
                        // Generate inputs ciphertext
                        let (srcs_clear, srcs_enc): (Vec<_>, Vec<_>) = proto
                            .src
                            .iter()
                            .map(|mode| {
                                // Bitwidth and number of blocks depend on the operand mode
                                let (bw, block) = match mode {
                                    hpu_asm::iop::VarMode::Native => (width, num_block),
                                    hpu_asm::iop::VarMode::Half => (width / 2, num_block / 2),
                                    hpu_asm::iop::VarMode::Bool => (1, 1),
                                };
                                // Largest value that fits in `bw` bits
                                let clear_max = $user_type::MAX >> ($user_type::BITS - (bw as u32));
                                let clear = rng.gen_range(0..=clear_max);
                                let fhe = match sks.as_ref() {
                                    Some(sks) => sks.create_trivial_radix(clear, block),
                                    None => cks.encrypt_radix(clear, block),
                                };
                                let hpu_fhe = HpuRadixCiphertext::from_radix_ciphertext(&fhe, device);
                                (clear, hpu_fhe)
                            })
                            .unzip();

                        // Inclusive range: the all-ones immediate is a valid (and interesting) input
                        let imms = (0..proto.imm)
                            .map(|_| rng.gen_range(0..=$user_type::MAX) as u128)
                            .collect::<Vec<_>>();

                        // execute on Hpu
                        let res_hpu = HpuRadixCiphertext::exec(&proto, iop.opcode(), &srcs_enc, &imms);
                        let res_fhe = res_hpu
                            .iter()
                            .map(|x| x.to_radix_ciphertext()).collect::<Vec<_>>();
                        let res = res_fhe
                            .iter()
                            .map(|x| cks.decrypt_radix(x))
                            .collect::<Vec<$user_type>>();

                        // Expected results, computed by the user-supplied cleartext model
                        let exp_res = {
                            let $ct = &srcs_clear;
                            let $imm = imms.iter().map(|x| *x as $user_type).collect::<Vec<_>>();
                            ($behav.iter().map(|x| *x as $user_type).collect::<Vec<_>>())
                        };
                        println!("{:>8} <{:>8x?}> <{:>8x?}> => {:<8x?} [exp {:<8x?}] {{Delta: {:x?} }}", iop, srcs_clear, imms, res, exp_res, std::iter::zip(res.iter(), exp_res.iter()).map(|(x,y)| x ^y).collect::<Vec<_>>());

                        // `zip` stops at the shortest side: without this check a missing
                        // (or extra) result would be silently accepted
                        let count_ok = res.len() == exp_res.len();
                        if !count_ok {
                            println!("Error: {} returned {} value(s) where {} were expected", iop, res.len(), exp_res.len());
                        }
                        count_ok && std::iter::zip(res.iter(), exp_res.iter()).all(|(x, y)| x == y)
                    // Non short-circuiting fold: every iteration runs even after a failure
                    }).fold(true, |acc, val| acc & val)
                }
            )*
        }
    };
}
// Cleartext models of every supported IOp
// Scalar Alu IOp: one ciphertext operand (`ct[0]`) combined with one immediate (`imm[0]`)
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    [lhs.wrapping_add(scalar)]
});
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    [lhs.wrapping_sub(scalar)]
});
// Reversed subtraction: the immediate is the left operand
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (scalar, rhs) = (imm[0], ct[0]);
    [scalar.wrapping_sub(rhs)]
});
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    [lhs.wrapping_mul(scalar)]
});
// Outputs are [quotient, remainder]; a null divisor is modelled as [0, numerator]
hpu_testcase!("DIVS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (num, den) = (ct[0], imm[0]);
    match den {
        0 => [0, num],
        _ => [num.wrapping_div(den), num % den],
    }
});
// A null immediate can be drawn by the random generator and would make `%` panic.
// NOTE(review): the value expected for a null divisor mirrors the remainder used by the
// DIVS testcase (i.e. the numerator) — confirm against the firmware behavior.
hpu_testcase!("MODS" => [u8, u16, u32, u64, u128]
|ct, imm| if imm[0] == 0 {[ct[0]]} else {[ct[0] % imm[0]]});
// Scalar Alu IOp with overflow flag: outputs are [wrapped result, flag (0 or 1)]
hpu_testcase!("OVF_ADDS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    match lhs.checked_add(scalar) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_add(scalar), 1],
    }
});
hpu_testcase!("OVF_SUBS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    match lhs.checked_sub(scalar) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_sub(scalar), 1],
    }
});
// Reversed subtraction: the immediate is the left operand
hpu_testcase!("OVF_SSUB" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (scalar, rhs) = (imm[0], ct[0]);
    match scalar.checked_sub(rhs) {
        Some(exact) => [exact, 0],
        None => [scalar.wrapping_sub(rhs), 1],
    }
});
hpu_testcase!("OVF_MULS" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, scalar) = (ct[0], imm[0]);
    match lhs.checked_mul(scalar) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_mul(scalar), 1],
    }
});
// Scalar Shift/Rotation IOp: the amount is given by the immediate
hpu_testcase!("SHIFTS_R" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = imm[0] as u32;
    [ct[0].wrapping_shr(amount)]
});
hpu_testcase!("SHIFTS_L" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = imm[0] as u32;
    [ct[0].wrapping_shl(amount)]
});
hpu_testcase!("ROTS_R" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = imm[0] as u32;
    [ct[0].rotate_right(amount)]
});
hpu_testcase!("ROTS_L" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = imm[0] as u32;
    [ct[0].rotate_left(amount)]
});
// Alu IOp between two ciphertext operands (`ct[0]` and `ct[1]`)
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs.wrapping_add(rhs)]
});
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs.wrapping_sub(rhs)]
});
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs.wrapping_mul(rhs)]
});
// Outputs are [quotient, remainder]; a null divisor is modelled as [0, numerator]
hpu_testcase!("DIV" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (num, den) = (ct[0], ct[1]);
    match den {
        0 => [0, num],
        _ => [num.wrapping_div(den), num % den],
    }
});
// A null divisor can be drawn by the random generator and would make `%` panic.
// NOTE(review): the value expected for a null divisor mirrors the remainder used by the
// DIV testcase (i.e. the numerator) — confirm against the firmware behavior.
hpu_testcase!("MOD" => [u8, u16, u32, u64, u128]
|ct, imm| if ct[1] == 0 {[ct[0]]} else {[ct[0] % ct[1]]});
// Alu IOp with overflow flag: outputs are [wrapped result, flag (0 or 1)]
hpu_testcase!("OVF_ADD" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    match lhs.checked_add(rhs) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_add(rhs), 1],
    }
});
hpu_testcase!("OVF_SUB" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    match lhs.checked_sub(rhs) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_sub(rhs), 1],
    }
});
hpu_testcase!("OVF_MUL" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    match lhs.checked_mul(rhs) {
        Some(exact) => [exact, 0],
        None => [lhs.wrapping_mul(rhs), 1],
    }
});
// Shift/Rotation IOp: the amount is given by the second ciphertext operand
hpu_testcase!("SHIFT_R" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = ct[1] as u32;
    [ct[0].wrapping_shr(amount)]
});
hpu_testcase!("SHIFT_L" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = ct[1] as u32;
    [ct[0].wrapping_shl(amount)]
});
hpu_testcase!("ROT_R" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = ct[1] as u32;
    [ct[0].rotate_right(amount)]
});
hpu_testcase!("ROT_L" => [u8, u16, u32, u64, u128] |ct, imm| {
    let amount = ct[1] as u32;
    [ct[0].rotate_left(amount)]
});
// Bitwise IOp between two ciphertext operands
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs & rhs]
});
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs | rhs]
});
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs ^ rhs]
});
// Comparison IOp: the boolean outcome is the single expected output
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs > rhs]
});
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs >= rhs]
});
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs < rhs]
});
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs <= rhs]
});
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs == rhs]
});
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (lhs, rhs) = (ct[0], ct[1]);
    [lhs != rhs]
});
// Ternary IOp: the last ciphertext operand is the condition (null means false)
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (value, cond) = (ct[0], ct[1]);
    [if cond == 0 { 0 } else { value }]
});
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (on_true, on_false, cond) = (ct[0], ct[1], ct[2]);
    [if cond == 0 { on_false } else { on_true }]
});
// ERC 20 fund transfer
// Sources are [from, to, amount] and outputs are the updated [from, to] balances.
// The transfer only happens when the `from` balance covers the amount.
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128] |ct, imm| {
    let (from, to, amount) = (ct[0], ct[1], ct[2]);
    // TODO enhance this to prevent overflow
    match from.checked_sub(amount) {
        Some(remaining) => vec![remaining, to.wrapping_add(amount)],
        None => vec![from, to],
    }
});
// Bit count IOp: every model works on the single ciphertext operand
hpu_testcase!("COUNT0" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.count_zeros()]
});
hpu_testcase!("COUNT1" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.count_ones()]
});
// NOTE(review): `ilog2` panics on a null value, which the random inputs can produce —
// confirm the result expected from the Hpu for a null operand and guard this case.
hpu_testcase!("ILOG2" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.ilog2()]
});
hpu_testcase!("LEAD0" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.leading_zeros()]
});
hpu_testcase!("LEAD1" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.leading_ones()]
});
hpu_testcase!("TRAIL0" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.trailing_zeros()]
});
hpu_testcase!("TRAIL1" => [u8, u16, u32, u64, u128] |ct, imm| {
    let value = ct[0];
    [value.trailing_ones()]
});
// Test bundles: each declaration below generates one `#[test]` for a given integer width
// 8bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::8 => ["adds", "subs", "ssub", "muls", "divs", "mods"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alus"::8 => ["ovf_adds", "ovf_subs", "ovf_ssub", "ovf_muls"]);
// NB: Scalar Rot/Shift not supported yet
// #[cfg(feature = "hpu")]
// hpu_testbundle!("rots"::8 => ["rots_r", "rots_l"]);
// #[cfg(feature = "hpu")]
// hpu_testbundle!("shifts"::8 => ["shifts_r", "shifts_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::8 => ["add", "sub", "mul", "div", "mod"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alu"::8 => ["ovf_add", "ovf_sub", "ovf_mul"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("rot"::8 => ["rot_r", "rot_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("shift"::8 => ["shift_r", "shift_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::8 => ["bw_and", "bw_or", "bw_xor"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::8 => ["cmp_gt", "cmp_gte", "cmp_lt", "cmp_lte", "cmp_eq", "cmp_neq"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::8 => ["if_then_zero", "if_then_else"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::8 => ["erc_20"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::8 => ["count0", "count1", "ilog2", "lead0", "lead1", "trail0", "trail1"]);
// 16bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::16 => ["adds", "subs", "ssub", "muls", "divs", "mods"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alus"::16 => ["ovf_adds", "ovf_subs", "ovf_ssub", "ovf_muls"]);
// NB: Scalar Rot/Shift not supported yet
// #[cfg(feature = "hpu")]
// hpu_testbundle!("rots"::16 => ["rots_r", "rots_l"]);
// #[cfg(feature = "hpu")]
// hpu_testbundle!("shifts"::16 => ["shifts_r", "shifts_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::16 => ["add", "sub", "mul", "div", "mod"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alu"::16 => ["ovf_add", "ovf_sub", "ovf_mul"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("rot"::16 => ["rot_r", "rot_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("shift"::16 => ["shift_r", "shift_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::16 => ["bw_and", "bw_or", "bw_xor"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::16 => ["cmp_gt", "cmp_gte", "cmp_lt", "cmp_lte", "cmp_eq", "cmp_neq"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::16 => ["if_then_zero", "if_then_else"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::16 => ["erc_20"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::16 => ["count0", "count1", "ilog2", "lead0", "lead1", "trail0", "trail1"]);
// 32bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::32 => ["adds", "subs", "ssub", "muls", "divs", "mods"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alus"::32 => ["ovf_adds", "ovf_subs", "ovf_ssub", "ovf_muls"]);
// NB: Scalar Rot/Shift not supported yet
// #[cfg(feature = "hpu")]
// hpu_testbundle!("rots"::32 => ["rots_r", "rots_l"]);
// #[cfg(feature = "hpu")]
// hpu_testbundle!("shifts"::32 => ["shifts_r", "shifts_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::32 => ["add", "sub", "mul", "div", "mod"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alu"::32 => ["ovf_add", "ovf_sub", "ovf_mul"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("rot"::32 => ["rot_r", "rot_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("shift"::32 => ["shift_r", "shift_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::32 => ["bw_and", "bw_or", "bw_xor"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::32 => ["cmp_gt", "cmp_gte", "cmp_lt", "cmp_lte", "cmp_eq", "cmp_neq"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::32 => ["if_then_zero", "if_then_else"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::32 => ["erc_20"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::32 => ["count0", "count1", "ilog2", "lead0", "lead1", "trail0", "trail1"]);
// 64bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::64 => ["adds", "subs", "ssub", "muls", "divs", "mods"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alus"::64 => ["ovf_adds", "ovf_subs", "ovf_ssub", "ovf_muls"]);
// NB: Scalar Rot/Shift not supported yet
// #[cfg(feature = "hpu")]
// hpu_testbundle!("rots"::64 => ["rots_r", "rots_l"]);
// #[cfg(feature = "hpu")]
// hpu_testbundle!("shifts"::64 => ["shifts_r", "shifts_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::64 => ["add", "sub", "mul", "div", "mod"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alu"::64 => ["ovf_add", "ovf_sub", "ovf_mul"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("rot"::64 => ["rot_r", "rot_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("shift"::64 => ["shift_r", "shift_l"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::64 => ["bw_and", "bw_or", "bw_xor"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::64 => ["cmp_gt", "cmp_gte", "cmp_lt", "cmp_lte", "cmp_eq", "cmp_neq"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::64 => ["if_then_zero", "if_then_else"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::64 => ["erc_20"]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::64 => ["count0", "count1", "ilog2", "lead0", "lead1", "trail0", "trail1"]);
// 128bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::128 => [
"adds",
"subs",
"ssub",
"muls",
"divs",
"mods"
]);
// Gated like every other bundle of this file
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alus"::128 => [
"ovf_adds",
"ovf_subs",
"ovf_ssub",
"ovf_muls"
]);
// NB: Scalar Rot/Shift not supported yet
// #[cfg(feature = "hpu")]
// hpu_testbundle!("rots"::128 => [
// "rots_r",
// "rots_l"
// ]);
// #[cfg(feature = "hpu")]
// hpu_testbundle!("shifts"::128 => [
// "shifts_r",
// "shifts_l"
// ]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::128 => [
"add",
"sub",
"mul",
"div",
"mod"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ovf_alu"::128 => [
"ovf_add",
"ovf_sub",
"ovf_mul"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("rot"::128 => [
"rot_r",
"rot_l"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("shift"::128 => [
"shift_r",
"shift_l"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::128 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::128 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::128 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::128 => [
"erc_20"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cntbit"::128 => [
"count0",
"count1",
"ilog2",
"lead0",
"lead1",
"trail0",
"trail1"
]);
/// Simple test dedicated to check entities conversion from/to Cpu
///
/// The keyswitching key and the bootstrapping key are converted to their Hpu representation
/// and back. The test fails when the loopback changes the number of coefficients or the value
/// of any of them; every mismatching coefficient is printed.
#[cfg(feature = "hpu")]
#[test]
fn hpu_key_loopback() {
    use tfhe::core_crypto::prelude::*;
    use tfhe::*;
    use tfhe_hpu_backend::prelude::*;

    // Retrieve HpuDevice or init --------------------------------------------
    // Use the hpu_device stored in the static variable to automatically serialize tests.
    // The lock is only held while the parameters are copied.
    let (hpu_params, cks, key_seed) = {
        let (hpu_mutex, cks, key_seed) =
            HPU_DEVICE_RNG_CKS.get_or_init(init_hpu_and_associated_material);
        let hpu_device = hpu_mutex.lock().expect("Error with HpuDevice Mutex");
        (hpu_device.params().clone(), cks, key_seed)
    };
    println!("HPU_KEY_SEED={key_seed} #[i.e. 0x{key_seed:x}]");

    // Generate Keys ---------------------------------------------------------
    let sks_compressed =
        tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(cks)
            .into_raw_parts();

    // Unwrap compressed key ---------------------------------------------------
    let ap_key = match sks_compressed.compressed_ap_server_key {
        tfhe::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey::Standard(_) => {
            panic!("Hpu not support Standard keys. Required a KeySwitch32 keys")
        }
        tfhe::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey::KeySwitch32(keys) => keys,
    };

    // KSK Loopback conversion and check -------------------------------------
    // Extract and convert ksk
    let cpu_ksk_orig = ap_key
        .key_switching_key()
        .clone()
        .decompress_into_lwe_keyswitch_key();
    let hpu_ksk =
        HpuLweKeyswitchKeyOwned::create_from(cpu_ksk_orig.as_view(), hpu_params.clone());
    let cpu_ksk_lb = LweKeyswitchKeyOwned::<u32>::from(hpu_ksk.as_view());

    // NB: Some hw modifications such as bit shrinki couldn't be reversed
    // cpu_ksk_orig.as_mut().iter_mut().for_each(|coef| {
    // let ks_p = hpu_params.ks_params;
    // // Apply Hw rounding
    // // Extract info bits and rounding if required
    // let coef_info = *coef >> (u32::BITS - ks_p.width as u32);
    // let coef_rounding = if (ks_p.width as u32) < u32::BITS {
    // (*coef >> (u32::BITS - (ks_p.width + 1) as u32)) & 0x1
    // } else {
    // 0
    // };
    // *coef = (coef_info + coef_rounding) << (u32::BITS - ks_p.width as u32);
    // });

    // `zip` stops at the shortest side: check the sizes first so that a truncated key
    // can't be reported as a successful loopback
    assert_eq!(
        cpu_ksk_orig.as_ref().len(),
        cpu_ksk_lb.as_ref().len(),
        "Ksk loopback changed the number of coefficients"
    );
    let ksk_mismatch: usize =
        std::iter::zip(cpu_ksk_orig.as_ref().iter(), cpu_ksk_lb.as_ref().iter())
            .enumerate()
            .filter(|(i, (x, y))| {
                let differ = x != y;
                if differ {
                    println!("Ksk mismatch @{i}:: {x:x} != {y:x}");
                }
                differ
            })
            .count();

    // BSK Loopback conversion and check -------------------------------------
    // Extract and convert bsk
    let cpu_bsk_orig = match ap_key.bootstrapping_key() {
        tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::Classic {
            bsk: seeded_bsk,
            ..
        } => seeded_bsk.clone().decompress_into_lwe_bootstrap_key(),
        tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::MultiBit { .. } => {
            panic!("Hpu currently not support multibit. Required a Classic BSK")
        }
    };

    // Reference: the bootstrapping key converted in the Ntt domain on the Cpu side
    let cpu_bsk_ntt = {
        // Convert the LweBootstrapKey in Ntt domain
        let mut ntt_bsk = NttLweBootstrapKeyOwned::<u64>::new(
            0_u64,
            cpu_bsk_orig.input_lwe_dimension(),
            cpu_bsk_orig.glwe_size(),
            cpu_bsk_orig.polynomial_size(),
            cpu_bsk_orig.decomposition_base_log(),
            cpu_bsk_orig.decomposition_level_count(),
            CiphertextModulus::new(u64::from(&hpu_params.ntt_params.prime_modulus) as u128),
        );
        // Conversion to ntt domain
        par_convert_standard_lwe_bootstrap_key_to_ntt64(
            &cpu_bsk_orig,
            &mut ntt_bsk,
            NttLweBootstrapKeyOption::Raw,
        );
        ntt_bsk
    };
    let hpu_bsk =
        HpuLweBootstrapKeyOwned::create_from(cpu_bsk_orig.as_view(), hpu_params.clone());
    let cpu_bsk_lb = NttLweBootstrapKeyOwned::from(hpu_bsk.as_view());

    assert_eq!(
        cpu_bsk_ntt.as_view().into_container().len(),
        cpu_bsk_lb.as_view().into_container().len(),
        "Bsk loopback changed the number of coefficients"
    );
    let bsk_mismatch: usize = std::iter::zip(
        cpu_bsk_ntt.as_view().into_container().iter(),
        cpu_bsk_lb.as_view().into_container().iter(),
    )
    .enumerate()
    .filter(|(i, (x, y))| {
        let differ = x != y;
        if differ {
            println!("@{i}:: {x:x} != {y:x}");
        }
        differ
    })
    .count();

    println!("Ksk loopback with {ksk_mismatch} errors");
    println!("Bsk loopback with {bsk_mismatch} errors");
    assert_eq!(ksk_mismatch + bsk_mismatch, 0);
}
}