mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
fix(hpu): First round of review by IceTDrinker
Still WIP, all remarks not taken into account yet
This commit is contained in:
@@ -15,6 +15,7 @@ members = [
|
||||
"utils/param_dedup",
|
||||
"tests",
|
||||
"mockups/tfhe-hpu-mockup",
|
||||
"tests",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
|
||||
3
Makefile
3
Makefile
@@ -56,8 +56,7 @@ TFHECUDA_SRC=backends/tfhe-cuda-backend/cuda
|
||||
TFHECUDA_BUILD=$(TFHECUDA_SRC)/build
|
||||
|
||||
# tfhe-hpu-backend
|
||||
CUR_SCRIPT_DIR=$(shell cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
|
||||
HPU_BACKEND_DIR=$(CUR_SCRIPT_DIR)/backends/tfhe-hpu-backend
|
||||
HPU_BACKEND_DIR=backends/tfhe-hpu-backend
|
||||
HPU_CONFIG=v80
|
||||
V80_PCIE_DEV=$(shell lspci -d 10ee:50b5 | sed -e "s/\(..\).*/\1/")
|
||||
|
||||
|
||||
@@ -3,11 +3,15 @@ name = "tfhe-hpu-backend"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Zama Hardware team"]
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "HPU implementation on FPGA of TFHE-rs primitives."
|
||||
homepage = "https://www.zama.ai/"
|
||||
documentation = "https://docs.zama.ai/tfhe-rs"
|
||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||
readme = "README.md"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography", "hardware", "fpga"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
hw-xrt = []
|
||||
hw-v80 = []
|
||||
io-dump = ["num-traits"]
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
[package]
|
||||
name = "hpu-sim"
|
||||
name = "tfhe-hpu-mockup"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Zama Hardware team"]
|
||||
license = "BSD-3-Clause-Clear"
|
||||
description = "Simulation model of HPU hardware."
|
||||
readme = "README.md"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -81,7 +81,7 @@ Other optional configuration knobs are available:
|
||||
On top of that `tfhe-hpu-mockup` could generate a detailed set of trace points at runtime to help during the debug/exploration phase (e.g. When writing new Hpu firmware).
|
||||
Those trace points rely on `tokio-tracing` and could be activated on a path::verbosity based through the `RUST_LOG` environment variable.
|
||||
For example the following value will enable the info trace for all the design and the debug one for the ucore submodule:
|
||||
`RUST_LOG=info,hpu_sim::modules::ucore=debug`.
|
||||
`RUST_LOG=info,tfhe_hpu_mockup::modules::ucore=debug`.
|
||||
|
||||
> NB: With the mockup estimated IOp performances must be read from the mockup log, not from the user application report.
|
||||
> Indeed, the user application reports the execution time of the mockup binary not the expected performance on real Hpu hardware.
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
use std::fs::OpenOptions;
|
||||
use std::path::Path;
|
||||
|
||||
use hpu_sim::{HpuSim, MockupOptions, MockupParameters};
|
||||
use tfhe::tfhe_hpu_backend::prelude::*;
|
||||
use tfhe_hpu_mockup::{HpuSim, MockupOptions, MockupParameters};
|
||||
|
||||
/// Define CLI arguments
|
||||
use clap::Parser;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/bash
|
||||
#! /usr/bin/env/ bash
|
||||
|
||||
# Find current script directory. This should be PROJECT_DIR
|
||||
CUR_SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
|
||||
|
||||
@@ -1,103 +1,3 @@
|
||||
#[cfg(feature = "hpu")]
|
||||
fn main() {
|
||||
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
|
||||
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{set_server_key, ClientKey, CompressedServerKey, Config, FheUint8, *};
|
||||
use tfhe_hpu_backend::prelude::*;
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
pub use clap::Parser;
|
||||
/// Define CLI arguments
|
||||
#[derive(clap::Parser, Debug, Clone, serde::Serialize)]
|
||||
#[clap(long_about = "HPU example that shows the use of the HighLevelAPI.")]
|
||||
pub struct Args {
|
||||
// Fpga configuration ------------------------------------------------------
|
||||
/// Toml top-level configuration file
|
||||
#[clap(
|
||||
long,
|
||||
value_parser,
|
||||
default_value = "${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml"
|
||||
)]
|
||||
pub config: ShellString,
|
||||
|
||||
// Exec configuration ----------------------------------------------------
|
||||
/// Seed used for some rngs
|
||||
#[clap(long, value_parser)]
|
||||
pub seed: Option<u128>,
|
||||
}
|
||||
let args = Args::parse();
|
||||
println!("User Options: {args:?}");
|
||||
|
||||
// Register tracing subscriber that use env-filter
|
||||
// Select verbosity with env_var: e.g. `RUST_LOG=Alu=trace`
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.compact()
|
||||
// Display source code file paths
|
||||
.with_file(false)
|
||||
// Display source code line numbers
|
||||
.with_line_number(false)
|
||||
.without_time()
|
||||
// Build & register the subscriber
|
||||
.init();
|
||||
|
||||
// Seeder for args randomization ------------------------------------------
|
||||
let mut rng: StdRng = if let Some(seed) = args.seed {
|
||||
SeedableRng::seed_from_u64((seed & u64::MAX as u128) as u64)
|
||||
} else {
|
||||
SeedableRng::from_entropy()
|
||||
};
|
||||
|
||||
// Instantiate HpuDevice --------------------------------------------------
|
||||
let hpu_device = HpuDevice::from_config(&args.config.expand());
|
||||
|
||||
// Generate keys ----------------------------------------------------------
|
||||
let config = Config::from_hpu_device(&hpu_device);
|
||||
|
||||
// Force key seeder if seed specified by user
|
||||
if let Some(seed) = args.seed {
|
||||
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
|
||||
let shortint_engine = tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder);
|
||||
tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
|
||||
std::mem::replace(engine, shortint_engine)
|
||||
});
|
||||
}
|
||||
|
||||
let cks = ClientKey::generate(config);
|
||||
let csks = CompressedServerKey::new(&cks);
|
||||
|
||||
set_server_key((hpu_device, csks));
|
||||
|
||||
// Show 8bit capabilities --------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint8, u8);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 16bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint16, u16);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 32bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint32, u32);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 64bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint64, u64);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
#[macro_export]
|
||||
macro_rules! impl_hlapi_showcase {
|
||||
($fhe_type: ty, $user_type: ty) => {
|
||||
::paste::paste! {
|
||||
@@ -199,5 +99,99 @@ macro_rules! impl_hlapi_showcase {
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "hpu"))]
|
||||
fn main() {}
|
||||
fn main() {
|
||||
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
|
||||
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{set_server_key, ClientKey, CompressedServerKey, Config, FheUint8, *};
|
||||
use tfhe_hpu_backend::prelude::*;
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
pub use clap::Parser;
|
||||
/// Define CLI arguments
|
||||
#[derive(clap::Parser, Debug, Clone, serde::Serialize)]
|
||||
#[clap(long_about = "HPU example that shows the use of the HighLevelAPI.")]
|
||||
pub struct Args {
|
||||
// Fpga configuration ------------------------------------------------------
|
||||
/// Toml top-level configuration file
|
||||
#[clap(
|
||||
long,
|
||||
value_parser,
|
||||
default_value = "${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml"
|
||||
)]
|
||||
pub config: ShellString,
|
||||
|
||||
// Exec configuration ----------------------------------------------------
|
||||
/// Seed used for some rngs
|
||||
#[clap(long, value_parser)]
|
||||
pub seed: Option<u128>,
|
||||
}
|
||||
let args = Args::parse();
|
||||
println!("User Options: {args:?}");
|
||||
|
||||
// Register tracing subscriber that use env-filter
|
||||
// Select verbosity with env_var: e.g. `RUST_LOG=Alu=trace`
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.compact()
|
||||
// Display source code file paths
|
||||
.with_file(false)
|
||||
// Display source code line numbers
|
||||
.with_line_number(false)
|
||||
.without_time()
|
||||
// Build & register the subscriber
|
||||
.init();
|
||||
|
||||
// Seeder for args randomization ------------------------------------------
|
||||
let mut rng: StdRng = if let Some(seed) = args.seed {
|
||||
SeedableRng::seed_from_u64((seed & u64::MAX as u128) as u64)
|
||||
} else {
|
||||
SeedableRng::from_entropy()
|
||||
};
|
||||
|
||||
// Instantiate HpuDevice --------------------------------------------------
|
||||
let hpu_device = HpuDevice::from_config(&args.config.expand());
|
||||
|
||||
// Generate keys ----------------------------------------------------------
|
||||
let config = Config::from_hpu_device(&hpu_device);
|
||||
|
||||
// Force key seeder if seed specified by user
|
||||
if let Some(seed) = args.seed {
|
||||
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
|
||||
let shortint_engine = tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder);
|
||||
tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
|
||||
std::mem::replace(engine, shortint_engine)
|
||||
});
|
||||
}
|
||||
|
||||
let cks = ClientKey::generate(config);
|
||||
let csks = CompressedServerKey::new(&cks);
|
||||
|
||||
set_server_key((hpu_device, csks));
|
||||
|
||||
// Show 8bit capabilities --------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint8, u8);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 16bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint16, u16);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 32bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint32, u32);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
|
||||
// Show 64bit capabilities -------------------------------------------------
|
||||
{
|
||||
impl_hlapi_showcase!(FheUint64, u64);
|
||||
hlapi_showcase(&cks, &mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,8 +176,7 @@ where
|
||||
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
|
||||
res &= mod_mask;
|
||||
// Control bit about whether we should balance the state
|
||||
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random ==
|
||||
// 1)
|
||||
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
|
||||
let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
|
||||
// Balance depending on the control bit
|
||||
res.wrapping_sub(need_balance << rep_bit_count)
|
||||
|
||||
@@ -28,7 +28,6 @@ pub fn msb2lsb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut
|
||||
}
|
||||
/// This function change information position in container
|
||||
/// Move information bits from LSB to MSB
|
||||
#[allow(unused)]
|
||||
pub fn lsb2msb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
|
||||
let ct_width = params.ntt_params.ct_width as usize;
|
||||
let storage_width = Scalar::BITS;
|
||||
@@ -36,29 +35,3 @@ pub fn lsb2msb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut
|
||||
*val <<= storage_width - ct_width;
|
||||
}
|
||||
}
|
||||
|
||||
/// This function switches modulus for a slice of coefficients
|
||||
/// From: user domain (i.e. pow2 modulus)
|
||||
/// To: ntt domain ( i.e. prime modulus)
|
||||
/// Switching are done inplace
|
||||
pub fn user2ntt_modswitch<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
|
||||
let user_width = params.ntt_params.ct_width as usize;
|
||||
let mod_p_u128 = u64::from(¶ms.ntt_params.prime_modulus) as u128;
|
||||
for val in data.iter_mut() {
|
||||
let val_u128: u128 = val.cast_into();
|
||||
*val = Scalar::cast_from(((val_u128 * mod_p_u128) + (1 << (user_width - 1))) >> user_width);
|
||||
}
|
||||
}
|
||||
|
||||
/// This function switches modulus for a slice of coefficients
|
||||
/// From: ntt domain ( i.e. prime modulus)
|
||||
/// To: user domain (i.e. pow2 modulus)
|
||||
/// Switching are done inplace
|
||||
pub fn ntt2user_modswitch<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
|
||||
let user_width = params.ntt_params.ct_width as usize;
|
||||
let mod_p_u128 = u64::from(¶ms.ntt_params.prime_modulus) as u128;
|
||||
for val in data.iter_mut() {
|
||||
let val_u128: u128 = val.cast_into();
|
||||
*val = Scalar::cast_from((((val_u128) << user_width) | ((mod_p_u128) >> 1)) / mod_p_u128);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,27 +6,27 @@
|
||||
//! Both Ntt architecture used reverse order as input
|
||||
//! However, Wmm use an intermediate Network required by the BSK shuffling.
|
||||
|
||||
use crate::core_crypto::prelude::UnsignedInteger;
|
||||
use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, UnsignedInteger};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RadixBasis {
|
||||
radix_lg: usize,
|
||||
digits_nb: usize,
|
||||
radix_lg: DecompositionBaseLog,
|
||||
digits_nb: DecompositionLevelCount,
|
||||
}
|
||||
|
||||
impl RadixBasis {
|
||||
pub fn new(radix: usize, digits_nb: usize) -> Self {
|
||||
let radix_lg = radix.ilog2() as usize;
|
||||
Self {
|
||||
radix_lg,
|
||||
digits_nb,
|
||||
radix_lg: DecompositionBaseLog(radix_lg),
|
||||
digits_nb: DecompositionLevelCount(digits_nb),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn radix_lg(&self) -> usize {
|
||||
pub fn radix_lg(&self) -> DecompositionBaseLog {
|
||||
self.radix_lg
|
||||
}
|
||||
pub fn digits_nb(&self) -> usize {
|
||||
pub fn digits_nb(&self) -> DecompositionLevelCount {
|
||||
self.digits_nb
|
||||
}
|
||||
|
||||
@@ -35,13 +35,13 @@ impl RadixBasis {
|
||||
/// * Nat_order from 0..rank
|
||||
/// * Rev_order from rank..digits
|
||||
pub fn idx_pdrev(&self, digits: usize, rank: usize, nat_val: usize) -> usize {
|
||||
let mask = (1 << (digits - rank) * self.radix_lg) - 1;
|
||||
let to_be_reversed = (nat_val >> (rank * self.radix_lg)) & mask;
|
||||
let reversed = Self::new(1 << self.radix_lg, digits - rank).idx_rev(to_be_reversed);
|
||||
let mask = (1 << (digits - rank) * self.radix_lg.0) - 1;
|
||||
let to_be_reversed = (nat_val >> (rank * self.radix_lg.0)) & mask;
|
||||
let reversed = Self::new(1 << self.radix_lg.0, digits - rank).idx_rev(to_be_reversed);
|
||||
|
||||
let to_be_zeroed = nat_val & (mask << (rank * self.radix_lg));
|
||||
let to_be_zeroed = nat_val & (mask << (rank * self.radix_lg.0));
|
||||
let mut result = nat_val & !to_be_zeroed;
|
||||
result |= reversed << (rank * self.radix_lg);
|
||||
result |= reversed << (rank * self.radix_lg.0);
|
||||
|
||||
result
|
||||
}
|
||||
@@ -54,11 +54,11 @@ impl RadixBasis {
|
||||
|
||||
/// Convert an index expressed in Natural Order into `reverse` Order
|
||||
pub fn idx_rev(&self, mut nat_val: usize) -> usize {
|
||||
let mask = (1 << self.radix_lg) - 1;
|
||||
let mask = (1 << self.radix_lg.0) - 1;
|
||||
let mut result = 0;
|
||||
for i in (0..self.digits_nb).rev() {
|
||||
result |= (nat_val & mask) << (i * self.radix_lg);
|
||||
nat_val >>= self.radix_lg;
|
||||
for i in (0..self.digits_nb.0).rev() {
|
||||
result |= (nat_val & mask) << (i * self.radix_lg.0);
|
||||
nat_val >>= self.radix_lg.0;
|
||||
}
|
||||
|
||||
result
|
||||
@@ -80,7 +80,7 @@ where
|
||||
assert_eq!(src.len(), dst.len(), "Poly src/ dst length mismtach");
|
||||
assert_eq!(
|
||||
src.len(),
|
||||
((1 << rb_conv.radix_lg()) as usize).pow(rb_conv.digits_nb() as u32),
|
||||
((1 << rb_conv.radix_lg().0) as usize).pow(rb_conv.digits_nb().0 as u32),
|
||||
"Poly length mismtach with RadixBasis configuration"
|
||||
);
|
||||
|
||||
@@ -92,7 +92,7 @@ where
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PcgNetwork {
|
||||
stg_nb: usize,
|
||||
stage_nb: usize,
|
||||
rb_conv: RadixBasis,
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ impl PcgNetwork {
|
||||
/// Create network instance from NttParameters
|
||||
pub fn new(radix: usize, stg_nb: usize) -> Self {
|
||||
Self {
|
||||
stg_nb,
|
||||
stage_nb: stg_nb,
|
||||
rb_conv: RadixBasis::new(radix, stg_nb),
|
||||
}
|
||||
}
|
||||
@@ -108,9 +108,11 @@ impl PcgNetwork {
|
||||
/// For a given position idx (in 0..N-1), at processing step delta_idx,
|
||||
/// find the corresponding position idx (consider the input of the node)
|
||||
pub fn get_pos_id(&mut self, delta_idx: usize, pos_idx: usize) -> usize {
|
||||
let node_idx = pos_idx / (1 << self.rb_conv.radix_lg());
|
||||
let rmn_idx = pos_idx % (1 << self.rb_conv.radix_lg());
|
||||
let pdrev_idx = self.rb_conv.idx_pdrev(self.stg_nb - 1, delta_idx, node_idx);
|
||||
pdrev_idx * (1 << self.rb_conv.radix_lg()) + rmn_idx
|
||||
let node_idx = pos_idx / (1 << self.rb_conv.radix_lg().0);
|
||||
let rmn_idx = pos_idx % (1 << self.rb_conv.radix_lg().0);
|
||||
let pdrev_idx = self
|
||||
.rb_conv
|
||||
.idx_pdrev(self.stage_nb - 1, delta_idx, node_idx);
|
||||
pdrev_idx * (1 << self.rb_conv.radix_lg().0) + rmn_idx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ impl<Scalar: UnsignedInteger> FromWith<LweCiphertextView<'_, Scalar>, HpuParamet
|
||||
let pbs_p = ¶ms.pbs_params;
|
||||
let poly_size = pbs_p.polynomial_size;
|
||||
|
||||
// NB: Glwe polynomial must be in reversed order
|
||||
// NB: lwe mask is view as polynomial and must be in reversed order
|
||||
// Allocate translation buffer and reversed vector here
|
||||
let rb_conv = order::RadixBasis::new(ntt_p.radix, ntt_p.stg_nb);
|
||||
let lwe_len = hpu_lwe.len();
|
||||
|
||||
@@ -4,31 +4,27 @@
|
||||
//! WARN: Only one Hpu could be use at a time, thus all test must be run sequentially
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
use std::str::FromStr;
|
||||
mod hpu_test {
|
||||
use std::str::FromStr;
|
||||
|
||||
pub use rand::Rng;
|
||||
pub use rand::Rng;
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
pub use tfhe_hpu_backend::prelude::*;
|
||||
pub use tfhe_hpu_backend::prelude::*;
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
/// Variable to store initialized HpuDevice and associated client key for fast iteration
|
||||
static HPU_DEVICE_CKS: std::sync::OnceLock<(
|
||||
std::sync::Mutex<HpuDevice>,
|
||||
tfhe::integer::ClientKey,
|
||||
)> = std::sync::OnceLock::new();
|
||||
/// Variable to store initialized HpuDevice and associated client key for fast iteration
|
||||
static HPU_DEVICE_CKS: std::sync::OnceLock<(
|
||||
std::sync::Mutex<HpuDevice>,
|
||||
tfhe::integer::ClientKey,
|
||||
)> = std::sync::OnceLock::new();
|
||||
|
||||
// NB: Currently u55c didn't check for workq overflow.
|
||||
// -> Use default value < queue depth to circumvent this limitation
|
||||
// NB': This is only for u55c, on V80 user could set HPU_TEST_ITER to whatever value he want
|
||||
#[cfg(feature = "hpu")]
|
||||
const DEFAULT_TEST_ITER: usize = 32;
|
||||
// NB: Currently u55c didn't check for workq overflow.
|
||||
// -> Use default value < queue depth to circumvent this limitation
|
||||
// NB': This is only for u55c, on V80 user could set HPU_TEST_ITER to whatever value he want
|
||||
const DEFAULT_TEST_ITER: usize = 32;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! hpu_testbundle {
|
||||
macro_rules! hpu_testbundle {
|
||||
($base_name: literal::$integer_width:tt => [$($testcase: literal),+]) => {
|
||||
::paste::paste! {
|
||||
#[cfg(feature = "hpu")]
|
||||
#[test]
|
||||
pub fn [<hpu_test_ $base_name:lower _u $integer_width>]() {
|
||||
// Register tracing subscriber that use env-filter
|
||||
@@ -90,7 +86,7 @@ macro_rules! hpu_testbundle {
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! hpu_testcase {
|
||||
macro_rules! hpu_testcase {
|
||||
($iop: literal => [$($user_type: ty),+] |$ct:ident, $imm: ident| $behav: expr) => {
|
||||
::paste::paste! {
|
||||
$(
|
||||
@@ -157,412 +153,417 @@ macro_rules! hpu_testcase {
|
||||
};
|
||||
}
|
||||
|
||||
// Define testcase implementation for all supported IOp
|
||||
// Alu IOp with Ct x Imm
|
||||
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128]
|
||||
// Define testcase implementation for all supported IOp
|
||||
// Alu IOp with Ct x Imm
|
||||
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_add(imm[0])]);
|
||||
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_sub(imm[0])]);
|
||||
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![imm[0].wrapping_sub(ct[0])]);
|
||||
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_mul(imm[0])]);
|
||||
|
||||
// Alu IOp with Ct x Ct
|
||||
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128]
|
||||
// Alu IOp with Ct x Ct
|
||||
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_add(ct[1])]);
|
||||
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_sub(ct[1])]);
|
||||
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0].wrapping_mul(ct[1])]);
|
||||
|
||||
// Bitwise IOp
|
||||
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128]
|
||||
// Bitwise IOp
|
||||
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] & ct[1]]);
|
||||
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] | ct[1]]);
|
||||
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] ^ ct[1]]);
|
||||
|
||||
// Comparison IOp
|
||||
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128]
|
||||
// Comparison IOp
|
||||
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] > ct[1]]);
|
||||
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] >= ct[1]]);
|
||||
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] < ct[1]]);
|
||||
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] <= ct[1]]);
|
||||
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] == ct[1]]);
|
||||
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![ct[0] != ct[1]]);
|
||||
|
||||
// Ternary IOp
|
||||
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128]
|
||||
// Ternary IOp
|
||||
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![if ct[1] != 0 {ct[0]} else { 0}]);
|
||||
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128]
|
||||
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| vec![if ct[2] != 0 {ct[0]} else { ct[1]}]);
|
||||
|
||||
// ERC 20 found xfer
|
||||
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| {
|
||||
let from = ct[0];
|
||||
let to = ct[1];
|
||||
let amount = ct[2];
|
||||
// TODO enhance this to prevent overflow
|
||||
if from >= amount {
|
||||
vec![from - amount, to.wrapping_add(amount)]
|
||||
} else {
|
||||
vec![from, to]
|
||||
}
|
||||
});
|
||||
|
||||
// Define a set of test bundle for various size
|
||||
// 8bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alus"::8 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alu"::8 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("bitwise"::8 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("cmp"::8 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("ternary"::8 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("algo"::8 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 16bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alus"::16 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alu"::16 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("bitwise"::16 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("cmp"::16 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("ternary"::16 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("algo"::16 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 32bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alus"::32 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alu"::32 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("bitwise"::32 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("cmp"::32 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("ternary"::32 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("algo"::32 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 64bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alus"::64 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alu"::64 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("bitwise"::64 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("cmp"::64 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("ternary"::64 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("algo"::64 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 128bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alus"::128 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("alu"::128 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("bitwise"::128 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("cmp"::128 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("ternary"::128 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
crate::hpu_testbundle!("algo"::128 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
/// Simple test dedicated to check entities convertion from/to Cpu
|
||||
#[cfg(feature = "hpu")]
|
||||
#[test]
|
||||
fn hpu_key_loopback() {
|
||||
use tfhe::core_crypto::hpu::from_with::FromWith;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::*;
|
||||
use tfhe_hpu_backend::prelude::*;
|
||||
|
||||
// Retrieved HpuDevice or init ---------------------------------------------
|
||||
// Used hpu_device backed in static variable to automatically serialize tests
|
||||
let (hpu_mutex, cks) = HPU_DEVICE_CKS.get_or_init(|| {
|
||||
// Instantiate HpuDevice --------------------------------------------------
|
||||
let hpu_device = {
|
||||
let config_file = ShellString::new(
|
||||
"${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml".to_string(),
|
||||
);
|
||||
HpuDevice::from_config(&config_file.expand())
|
||||
};
|
||||
|
||||
// Extract pbs_configuration from Hpu and create Client/Server Key
|
||||
let cks = tfhe::integer::ClientKey::new(tfhe::shortint::ClassicPBSParameters::from(
|
||||
hpu_device.params(),
|
||||
));
|
||||
let sks_compressed =
|
||||
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
|
||||
// Init Hpu device with server key and firmware
|
||||
tfhe::integer::hpu::init_device(&hpu_device, sks_compressed.into()).expect("Invalid key");
|
||||
(std::sync::Mutex::new(hpu_device), cks)
|
||||
});
|
||||
let hpu_params = hpu_mutex
|
||||
.lock()
|
||||
.expect("Error with HpuDevice Mutex")
|
||||
.params()
|
||||
.clone();
|
||||
|
||||
// Generate Keys ---------------------------------------------------------
|
||||
let sks_compressed =
|
||||
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks).into_raw_parts();
|
||||
|
||||
// KSK Loopback conversion and check -------------------------------------
|
||||
// Extract and convert ksk
|
||||
let mut cpu_ksk_orig = sks_compressed
|
||||
.key_switching_key
|
||||
.decompress_into_lwe_keyswitch_key();
|
||||
let hpu_ksk = HpuLweKeyswitchKeyOwned::from_with(cpu_ksk_orig.as_view(), hpu_params.clone());
|
||||
let cpu_ksk_lb = LweKeyswitchKeyOwned::from(hpu_ksk.as_view());
|
||||
|
||||
// NB: Some hw modifications such as bit shrinki couldn't be reversed
|
||||
cpu_ksk_orig.as_mut().iter_mut().for_each(|coef| {
|
||||
let ks_p = hpu_params.ks_params;
|
||||
// Apply Hw rounding
|
||||
// Extract info bits and rounding if required
|
||||
let coef_info = *coef >> (u64::BITS - ks_p.width as u32);
|
||||
let coef_rounding = if (ks_p.width as u32) < u64::BITS {
|
||||
(*coef >> (u64::BITS - (ks_p.width + 1) as u32)) & 0x1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
*coef = (coef_info + coef_rounding) << (u64::BITS - ks_p.width as u32);
|
||||
});
|
||||
|
||||
let ksk_mismatch: usize =
|
||||
std::iter::zip(cpu_ksk_orig.as_ref().iter(), cpu_ksk_lb.as_ref().iter())
|
||||
.enumerate()
|
||||
.map(|(i, (x, y))| {
|
||||
if x != y {
|
||||
println!("Ksk mismatch @{i}:: {x:x} != {y:x}");
|
||||
1
|
||||
// ERC 20 found xfer
|
||||
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128]
|
||||
|ct, imm| {
|
||||
let from = ct[0];
|
||||
let to = ct[1];
|
||||
let amount = ct[2];
|
||||
// TODO enhance this to prevent overflow
|
||||
if from >= amount {
|
||||
vec![from - amount, to.wrapping_add(amount)]
|
||||
} else {
|
||||
0
|
||||
vec![from, to]
|
||||
}
|
||||
})
|
||||
.sum();
|
||||
});
|
||||
|
||||
// BSK Loopback conversion and check -------------------------------------
|
||||
// Extract and convert ksk
|
||||
let cpu_bsk_orig = match sks_compressed.bootstrapping_key {
|
||||
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk: seeded_bsk,
|
||||
..
|
||||
} => seeded_bsk.decompress_into_lwe_bootstrap_key(),
|
||||
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::MultiBit { .. } => {
|
||||
panic!("Hpu currently not support multibit. Required a Classic BSK")
|
||||
}
|
||||
};
|
||||
let cpu_bsk_ntt = {
|
||||
// Convert the LweBootstrapKey in Ntt domain
|
||||
let mut ntt_bsk = NttLweBootstrapKeyOwned::<u64>::new(
|
||||
0_u64,
|
||||
cpu_bsk_orig.input_lwe_dimension(),
|
||||
cpu_bsk_orig.glwe_size(),
|
||||
cpu_bsk_orig.polynomial_size(),
|
||||
cpu_bsk_orig.decomposition_base_log(),
|
||||
cpu_bsk_orig.decomposition_level_count(),
|
||||
CiphertextModulus::new(u64::from(&hpu_params.ntt_params.prime_modulus) as u128),
|
||||
);
|
||||
// Define a set of test bundle for various size
|
||||
// 8bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::8 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
// Conversion to ntt domain
|
||||
par_convert_standard_lwe_bootstrap_key_to_ntt64(&cpu_bsk_orig, &mut ntt_bsk);
|
||||
ntt_bsk
|
||||
};
|
||||
let hpu_bsk = HpuLweBootstrapKeyOwned::from_with(cpu_bsk_orig.as_view(), hpu_params.clone());
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alu"::8 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
let cpu_bsk_lb = NttLweBootstrapKeyOwned::from(hpu_bsk.as_view());
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("bitwise"::8 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
let bsk_mismatch: usize = std::iter::zip(
|
||||
cpu_bsk_ntt.as_view().into_container().iter(),
|
||||
cpu_bsk_lb.as_view().into_container().iter(),
|
||||
)
|
||||
.enumerate()
|
||||
.map(|(i, (x, y))| {
|
||||
if x != y {
|
||||
println!("@{i}:: {x:x} != {y:x}");
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.sum();
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cmp"::8 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
println!("Ksk loopback with {ksk_mismatch} errors");
|
||||
println!("Bsk loopback with {bsk_mismatch} errors");
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("ternary"::8 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
assert_eq!(ksk_mismatch + bsk_mismatch, 0);
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("algo"::8 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 16bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::16 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alu"::16 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("bitwise"::16 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cmp"::16 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("ternary"::16 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("algo"::16 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 32bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::32 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alu"::32 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("bitwise"::32 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cmp"::32 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("ternary"::32 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("algo"::32 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 64bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::64 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alu"::64 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("bitwise"::64 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cmp"::64 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("ternary"::64 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("algo"::64 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
// 128bit ciphertext -----------------------------------------
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alus"::128 => [
|
||||
"adds",
|
||||
"subs",
|
||||
"ssub",
|
||||
"muls"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("alu"::128 => [
|
||||
"add",
|
||||
"sub",
|
||||
"mul"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("bitwise"::128 => [
|
||||
"bw_and",
|
||||
"bw_or",
|
||||
"bw_xor"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("cmp"::128 => [
|
||||
"cmp_gt",
|
||||
"cmp_gte",
|
||||
"cmp_lt",
|
||||
"cmp_lte",
|
||||
"cmp_eq",
|
||||
"cmp_neq"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("ternary"::128 => [
|
||||
"if_then_zero",
|
||||
"if_then_else"
|
||||
]);
|
||||
|
||||
#[cfg(feature = "hpu")]
|
||||
hpu_testbundle!("algo"::128 => [
|
||||
"erc_20"
|
||||
]);
|
||||
|
||||
/// Simple test dedicated to check entities convertion from/to Cpu
|
||||
#[cfg(feature = "hpu")]
|
||||
#[test]
|
||||
fn hpu_key_loopback() {
|
||||
use tfhe::core_crypto::hpu::from_with::FromWith;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::*;
|
||||
use tfhe_hpu_backend::prelude::*;
|
||||
|
||||
// Retrieved HpuDevice or init ---------------------------------------------
|
||||
// Used hpu_device backed in static variable to automatically serialize tests
|
||||
let (hpu_mutex, cks) = HPU_DEVICE_CKS.get_or_init(|| {
|
||||
// Instantiate HpuDevice --------------------------------------------------
|
||||
let hpu_device = {
|
||||
let config_file = ShellString::new(
|
||||
"${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml".to_string(),
|
||||
);
|
||||
HpuDevice::from_config(&config_file.expand())
|
||||
};
|
||||
|
||||
// Extract pbs_configuration from Hpu and create Client/Server Key
|
||||
let cks = tfhe::integer::ClientKey::new(tfhe::shortint::ClassicPBSParameters::from(
|
||||
hpu_device.params(),
|
||||
));
|
||||
let sks_compressed =
|
||||
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks);
|
||||
|
||||
// Init Hpu device with server key and firmware
|
||||
tfhe::integer::hpu::init_device(&hpu_device, sks_compressed.into())
|
||||
.expect("Invalid key");
|
||||
(std::sync::Mutex::new(hpu_device), cks)
|
||||
});
|
||||
let hpu_params = hpu_mutex
|
||||
.lock()
|
||||
.expect("Error with HpuDevice Mutex")
|
||||
.params()
|
||||
.clone();
|
||||
|
||||
// Generate Keys ---------------------------------------------------------
|
||||
let sks_compressed =
|
||||
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks)
|
||||
.into_raw_parts();
|
||||
|
||||
// KSK Loopback conversion and check -------------------------------------
|
||||
// Extract and convert ksk
|
||||
let mut cpu_ksk_orig = sks_compressed
|
||||
.key_switching_key
|
||||
.decompress_into_lwe_keyswitch_key();
|
||||
let hpu_ksk =
|
||||
HpuLweKeyswitchKeyOwned::from_with(cpu_ksk_orig.as_view(), hpu_params.clone());
|
||||
let cpu_ksk_lb = LweKeyswitchKeyOwned::from(hpu_ksk.as_view());
|
||||
|
||||
// NB: Some hw modifications such as bit shrinki couldn't be reversed
|
||||
cpu_ksk_orig.as_mut().iter_mut().for_each(|coef| {
|
||||
let ks_p = hpu_params.ks_params;
|
||||
// Apply Hw rounding
|
||||
// Extract info bits and rounding if required
|
||||
let coef_info = *coef >> (u64::BITS - ks_p.width as u32);
|
||||
let coef_rounding = if (ks_p.width as u32) < u64::BITS {
|
||||
(*coef >> (u64::BITS - (ks_p.width + 1) as u32)) & 0x1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
*coef = (coef_info + coef_rounding) << (u64::BITS - ks_p.width as u32);
|
||||
});
|
||||
|
||||
let ksk_mismatch: usize =
|
||||
std::iter::zip(cpu_ksk_orig.as_ref().iter(), cpu_ksk_lb.as_ref().iter())
|
||||
.enumerate()
|
||||
.map(|(i, (x, y))| {
|
||||
if x != y {
|
||||
println!("Ksk mismatch @{i}:: {x:x} != {y:x}");
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.sum();
|
||||
|
||||
// BSK Loopback conversion and check -------------------------------------
|
||||
// Extract and convert ksk
|
||||
let cpu_bsk_orig = match sks_compressed.bootstrapping_key {
|
||||
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk: seeded_bsk,
|
||||
..
|
||||
} => seeded_bsk.decompress_into_lwe_bootstrap_key(),
|
||||
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::MultiBit { .. } => {
|
||||
panic!("Hpu currently not support multibit. Required a Classic BSK")
|
||||
}
|
||||
};
|
||||
let cpu_bsk_ntt = {
|
||||
// Convert the LweBootstrapKey in Ntt domain
|
||||
let mut ntt_bsk = NttLweBootstrapKeyOwned::<u64>::new(
|
||||
0_u64,
|
||||
cpu_bsk_orig.input_lwe_dimension(),
|
||||
cpu_bsk_orig.glwe_size(),
|
||||
cpu_bsk_orig.polynomial_size(),
|
||||
cpu_bsk_orig.decomposition_base_log(),
|
||||
cpu_bsk_orig.decomposition_level_count(),
|
||||
CiphertextModulus::new(u64::from(&hpu_params.ntt_params.prime_modulus) as u128),
|
||||
);
|
||||
|
||||
// Conversion to ntt domain
|
||||
par_convert_standard_lwe_bootstrap_key_to_ntt64(&cpu_bsk_orig, &mut ntt_bsk);
|
||||
ntt_bsk
|
||||
};
|
||||
let hpu_bsk =
|
||||
HpuLweBootstrapKeyOwned::from_with(cpu_bsk_orig.as_view(), hpu_params.clone());
|
||||
|
||||
let cpu_bsk_lb = NttLweBootstrapKeyOwned::from(hpu_bsk.as_view());
|
||||
|
||||
let bsk_mismatch: usize = std::iter::zip(
|
||||
cpu_bsk_ntt.as_view().into_container().iter(),
|
||||
cpu_bsk_lb.as_view().into_container().iter(),
|
||||
)
|
||||
.enumerate()
|
||||
.map(|(i, (x, y))| {
|
||||
if x != y {
|
||||
println!("@{i}:: {x:x} != {y:x}");
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.sum();
|
||||
|
||||
println!("Ksk loopback with {ksk_mismatch} errors");
|
||||
println!("Bsk loopback with {bsk_mismatch} errors");
|
||||
|
||||
assert_eq!(ksk_mismatch + bsk_mismatch, 0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user