fix(hpu): First round of review by IceTDrinker

Still WIP: not all review remarks have been taken into account yet.
This commit is contained in:
Baptiste Roux
2025-04-18 17:30:31 +02:00
parent ed0c15c60d
commit 00b7d2042b
13 changed files with 534 additions and 559 deletions

View File

@@ -15,6 +15,7 @@ members = [
"utils/param_dedup",
"tests",
"mockups/tfhe-hpu-mockup",
"tests",
]
exclude = [

View File

@@ -56,8 +56,7 @@ TFHECUDA_SRC=backends/tfhe-cuda-backend/cuda
TFHECUDA_BUILD=$(TFHECUDA_SRC)/build
# tfhe-hpu-backend
CUR_SCRIPT_DIR=$(shell cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
HPU_BACKEND_DIR=$(CUR_SCRIPT_DIR)/backends/tfhe-hpu-backend
HPU_BACKEND_DIR=backends/tfhe-hpu-backend
HPU_CONFIG=v80
V80_PCIE_DEV=$(shell lspci -d 10ee:50b5 | sed -e "s/\(..\).*/\1/")

View File

@@ -3,11 +3,15 @@ name = "tfhe-hpu-backend"
version = "0.1.0"
edition = "2021"
authors = ["Zama Hardware team"]
license = "BSD-3-Clause-Clear"
description = "HPU implementation on FPGA of TFHE-rs primitives."
homepage = "https://www.zama.ai/"
documentation = "https://docs.zama.ai/tfhe-rs"
repository = "https://github.com/zama-ai/tfhe-rs"
readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography", "hardware", "fpga"]
[features]
default = []
hw-xrt = []
hw-v80 = []
io-dump = ["num-traits"]

View File

@@ -1,8 +1,10 @@
[package]
name = "hpu-sim"
name = "tfhe-hpu-mockup"
version = "0.1.0"
edition = "2021"
authors = ["Zama Hardware team"]
license = "BSD-3-Clause-Clear"
description = "Simulation model of HPU hardware."
readme = "README.md"
[features]

View File

@@ -81,7 +81,7 @@ Other optional configuration knobs are available:
On top of that, `tfhe-hpu-mockup` can generate a detailed set of trace points at runtime to help during the debug/exploration phase (e.g. when writing new Hpu firmware).
Those trace points rely on `tokio-tracing` and can be enabled on a per-path verbosity basis through the `RUST_LOG` environment variable.
For example the following value will enable the info trace for all the design and the debug one for the ucore submodule:
`RUST_LOG=info,hpu_sim::modules::ucore=debug`.
`RUST_LOG=info,tfhe_hpu_mockup::modules::ucore=debug`.
> NB: With the mockup, the estimated IOp performance must be read from the mockup log, not from the user application report.
> Indeed, the user application reports the execution time of the mockup binary, not the expected performance on real Hpu hardware.

View File

@@ -8,8 +8,8 @@
use std::fs::OpenOptions;
use std::path::Path;
use hpu_sim::{HpuSim, MockupOptions, MockupParameters};
use tfhe::tfhe_hpu_backend::prelude::*;
use tfhe_hpu_mockup::{HpuSim, MockupOptions, MockupParameters};
/// Define CLI arguments
use clap::Parser;

View File

@@ -1,4 +1,4 @@
#! /usr/bin/bash
#! /usr/bin/env bash
# Find current script directory. This should be PROJECT_DIR
CUR_SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)

View File

@@ -1,103 +1,3 @@
/// Entry point of the HPU HighLevelAPI example (compiled only with the `hpu` feature).
///
/// Steps, in order: parse the CLI options, install the tracing subscriber, seed the
/// argument rng, instantiate the `HpuDevice` from its configuration file, generate the
/// client key and the compressed server key, register the server key together with the
/// device, then run the showcase for 8, 16, 32 and 64-bit ciphertexts.
#[cfg(feature = "hpu")]
fn main() {
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, Config, FheUint8, *};
use tfhe_hpu_backend::prelude::*;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
pub use clap::Parser;
// NOTE: the `///` comments on `Args` and its fields are consumed by the clap derive
// (they end up in the generated help), so they are part of the program output.
/// Define CLI arguments
#[derive(clap::Parser, Debug, Clone, serde::Serialize)]
#[clap(long_about = "HPU example that shows the use of the HighLevelAPI.")]
pub struct Args {
// Fpga configuration ------------------------------------------------------
/// Toml top-level configuration file
#[clap(
long,
value_parser,
default_value = "${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml"
)]
pub config: ShellString,
// Exec configuration ----------------------------------------------------
/// Seed used for some rngs
#[clap(long, value_parser)]
pub seed: Option<u128>,
}
let args = Args::parse();
println!("User Options: {args:?}");
// Register a tracing subscriber whose filter is read from the environment
// Select verbosity with env_var: e.g. `RUST_LOG=Alu=trace`
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.compact()
// Do not display source code file paths
.with_file(false)
// Do not display source code line numbers
.with_line_number(false)
.without_time()
// Build & register the subscriber
.init();
// Seeder for args randomization ------------------------------------------
// Only the low 64 bits of the user seed feed this rng (`seed_from_u64`);
// without a user seed the rng is initialized from entropy.
let mut rng: StdRng = if let Some(seed) = args.seed {
SeedableRng::seed_from_u64((seed & u64::MAX as u128) as u64)
} else {
SeedableRng::from_entropy()
};
// Instantiate HpuDevice --------------------------------------------------
// The config path goes through `ShellString::expand()` before being handed to the device.
let hpu_device = HpuDevice::from_config(&args.config.expand());
// Generate keys ----------------------------------------------------------
let config = Config::from_hpu_device(&hpu_device);
// Force key seeder if seed specified by user
// The thread-local ShortintEngine is swapped for one built from a DeterministicSeeder
// fed with the full 128-bit seed; the previous engine handed back by `mem::replace`
// is not used afterwards.
if let Some(seed) = args.seed {
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
let shortint_engine = tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder);
tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
std::mem::replace(engine, shortint_engine)
});
}
let cks = ClientKey::generate(config);
let csks = CompressedServerKey::new(&cks);
// The device is moved into `set_server_key` together with the compressed server key.
set_server_key((hpu_device, csks));
// Each showcase below lives in its own block.
// NOTE(review): `impl_hlapi_showcase!` presumably defines the `hlapi_showcase` function
// called right after it for the given (Fhe type, clear type) pair — confirm against the macro.
// Show 8bit capabilities --------------------------------------------------
{
impl_hlapi_showcase!(FheUint8, u8);
hlapi_showcase(&cks, &mut rng);
}
// Show 16bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint16, u16);
hlapi_showcase(&cks, &mut rng);
}
// Show 32bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint32, u32);
hlapi_showcase(&cks, &mut rng);
}
// Show 64bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint64, u64);
hlapi_showcase(&cks, &mut rng);
}
}
#[cfg(feature = "hpu")]
#[macro_export]
macro_rules! impl_hlapi_showcase {
($fhe_type: ty, $user_type: ty) => {
::paste::paste! {
@@ -199,5 +99,99 @@ macro_rules! impl_hlapi_showcase {
};
}
#[cfg(not(feature = "hpu"))]
fn main() {}
/// Entry point of the HPU HighLevelAPI example.
///
/// Steps, in order: parse the CLI options, install the tracing subscriber, seed the
/// argument rng, instantiate the `HpuDevice` from its configuration file, generate the
/// client key and the compressed server key, register the server key together with the
/// device, then run the showcase for 8, 16, 32 and 64-bit ciphertexts.
fn main() {
use tfhe::core_crypto::commons::generators::DeterministicSeeder;
use tfhe::core_crypto::prelude::DefaultRandomGenerator;
use tfhe::prelude::*;
use tfhe::{set_server_key, ClientKey, CompressedServerKey, Config, FheUint8, *};
use tfhe_hpu_backend::prelude::*;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
pub use clap::Parser;
// NOTE: the `///` comments on `Args` and its fields are consumed by the clap derive
// (they end up in the generated help), so they are part of the program output.
/// Define CLI arguments
#[derive(clap::Parser, Debug, Clone, serde::Serialize)]
#[clap(long_about = "HPU example that shows the use of the HighLevelAPI.")]
pub struct Args {
// Fpga configuration ------------------------------------------------------
/// Toml top-level configuration file
#[clap(
long,
value_parser,
default_value = "${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml"
)]
pub config: ShellString,
// Exec configuration ----------------------------------------------------
/// Seed used for some rngs
#[clap(long, value_parser)]
pub seed: Option<u128>,
}
let args = Args::parse();
println!("User Options: {args:?}");
// Register a tracing subscriber whose filter is read from the environment
// Select verbosity with env_var: e.g. `RUST_LOG=Alu=trace`
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.compact()
// Do not display source code file paths
.with_file(false)
// Do not display source code line numbers
.with_line_number(false)
.without_time()
// Build & register the subscriber
.init();
// Seeder for args randomization ------------------------------------------
// Only the low 64 bits of the user seed feed this rng (`seed_from_u64`);
// without a user seed the rng is initialized from entropy.
let mut rng: StdRng = if let Some(seed) = args.seed {
SeedableRng::seed_from_u64((seed & u64::MAX as u128) as u64)
} else {
SeedableRng::from_entropy()
};
// Instantiate HpuDevice --------------------------------------------------
// The config path goes through `ShellString::expand()` before being handed to the device.
let hpu_device = HpuDevice::from_config(&args.config.expand());
// Generate keys ----------------------------------------------------------
let config = Config::from_hpu_device(&hpu_device);
// Force key seeder if seed specified by user
// The thread-local ShortintEngine is swapped for one built from a DeterministicSeeder
// fed with the full 128-bit seed; the previous engine handed back by `mem::replace`
// is not used afterwards.
if let Some(seed) = args.seed {
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(Seed(seed));
let shortint_engine = tfhe::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder);
tfhe::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
std::mem::replace(engine, shortint_engine)
});
}
let cks = ClientKey::generate(config);
let csks = CompressedServerKey::new(&cks);
// The device is moved into `set_server_key` together with the compressed server key.
set_server_key((hpu_device, csks));
// Each showcase below lives in its own block.
// NOTE(review): `impl_hlapi_showcase!` presumably defines the `hlapi_showcase` function
// called right after it for the given (Fhe type, clear type) pair — confirm against the macro.
// Show 8bit capabilities --------------------------------------------------
{
impl_hlapi_showcase!(FheUint8, u8);
hlapi_showcase(&cks, &mut rng);
}
// Show 16bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint16, u16);
hlapi_showcase(&cks, &mut rng);
}
// Show 32bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint32, u32);
hlapi_showcase(&cks, &mut rng);
}
// Show 64bit capabilities -------------------------------------------------
{
impl_hlapi_showcase!(FheUint64, u64);
hlapi_showcase(&cks, &mut rng);
}
}

View File

@@ -176,8 +176,7 @@ where
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
res &= mod_mask;
// Control bit about whether we should balance the state
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random ==
// 1)
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
// Balance depending on the control bit
res.wrapping_sub(need_balance << rep_bit_count)

View File

@@ -28,7 +28,6 @@ pub fn msb2lsb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut
}
/// This function changes the position of the information bits within the container:
/// it moves them from LSB to MSB
#[allow(unused)]
pub fn lsb2msb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
let ct_width = params.ntt_params.ct_width as usize;
let storage_width = Scalar::BITS;
@@ -36,29 +35,3 @@ pub fn lsb2msb_align<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut
*val <<= storage_width - ct_width;
}
}
/// Switch a slice of coefficients from the user domain (power-of-two modulus
/// `2^ct_width`) to the ntt domain (prime modulus `p`).
///
/// Every coefficient `x` is overwritten in place with
/// `(x * p + 2^(ct_width - 1)) >> ct_width`, computed on `u128`, i.e. the scaled value
/// rounded to nearest.
///
/// * `params` - provides `ntt_params.ct_width` and `ntt_params.prime_modulus`
/// * `data` - coefficients to convert, modified in place
pub fn user2ntt_modswitch<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
    let shift = params.ntt_params.ct_width as usize;
    let prime = u64::from(&params.ntt_params.prime_modulus) as u128;
    data.iter_mut().for_each(|coef| {
        let wide: u128 = coef.cast_into();
        let scaled = wide * prime;
        // Half of the power-of-two divisor, added before the shift to round to nearest
        let rounding: u128 = 1 << (shift - 1);
        *coef = Scalar::cast_from((scaled + rounding) >> shift);
    });
}
/// Switch a slice of coefficients from the ntt domain (prime modulus `p`) back to the
/// user domain (power-of-two modulus `2^ct_width`).
///
/// Every coefficient `x` is overwritten in place with
/// `((x << ct_width) + (p >> 1)) / p`, computed on `u128`, i.e. `x * 2^ct_width / p`
/// rounded to nearest.
///
/// * `params` - provides `ntt_params.ct_width` and `ntt_params.prime_modulus`
/// * `data` - coefficients to convert, modified in place
pub fn ntt2user_modswitch<Scalar: UnsignedInteger>(params: &HpuParameters, data: &mut [Scalar]) {
    let user_width = params.ntt_params.ct_width as usize;
    let mod_p_u128 = u64::from(&params.ntt_params.prime_modulus) as u128;
    // Rounding term: half of the divisor
    let half_mod_p = mod_p_u128 >> 1;
    for val in data.iter_mut() {
        let val_u128: u128 = val.cast_into();
        // The rounding term must be added, not OR-ed: `|` only coincides with `+` while
        // `p >> 1` fits entirely in the `ct_width` low bits cleared by the shift.
        // With `+` the rounding stays correct for any ct_width / prime combination.
        *val = Scalar::cast_from(((val_u128 << user_width) + half_mod_p) / mod_p_u128);
    }
}

View File

@@ -6,27 +6,27 @@
//! Both Ntt architectures use reverse order as input.
//! However, Wmm uses an intermediate network required by the BSK shuffling.
use crate::core_crypto::prelude::UnsignedInteger;
use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, UnsignedInteger};
#[derive(Debug, Clone)]
pub struct RadixBasis {
radix_lg: usize,
digits_nb: usize,
radix_lg: DecompositionBaseLog,
digits_nb: DecompositionLevelCount,
}
impl RadixBasis {
pub fn new(radix: usize, digits_nb: usize) -> Self {
let radix_lg = radix.ilog2() as usize;
Self {
radix_lg,
digits_nb,
radix_lg: DecompositionBaseLog(radix_lg),
digits_nb: DecompositionLevelCount(digits_nb),
}
}
pub fn radix_lg(&self) -> usize {
pub fn radix_lg(&self) -> DecompositionBaseLog {
self.radix_lg
}
pub fn digits_nb(&self) -> usize {
pub fn digits_nb(&self) -> DecompositionLevelCount {
self.digits_nb
}
@@ -35,13 +35,13 @@ impl RadixBasis {
/// * Nat_order from 0..rank
/// * Rev_order from rank..digits
pub fn idx_pdrev(&self, digits: usize, rank: usize, nat_val: usize) -> usize {
let mask = (1 << (digits - rank) * self.radix_lg) - 1;
let to_be_reversed = (nat_val >> (rank * self.radix_lg)) & mask;
let reversed = Self::new(1 << self.radix_lg, digits - rank).idx_rev(to_be_reversed);
let mask = (1 << (digits - rank) * self.radix_lg.0) - 1;
let to_be_reversed = (nat_val >> (rank * self.radix_lg.0)) & mask;
let reversed = Self::new(1 << self.radix_lg.0, digits - rank).idx_rev(to_be_reversed);
let to_be_zeroed = nat_val & (mask << (rank * self.radix_lg));
let to_be_zeroed = nat_val & (mask << (rank * self.radix_lg.0));
let mut result = nat_val & !to_be_zeroed;
result |= reversed << (rank * self.radix_lg);
result |= reversed << (rank * self.radix_lg.0);
result
}
@@ -54,11 +54,11 @@ impl RadixBasis {
/// Convert an index expressed in Natural Order into `reverse` Order
pub fn idx_rev(&self, mut nat_val: usize) -> usize {
let mask = (1 << self.radix_lg) - 1;
let mask = (1 << self.radix_lg.0) - 1;
let mut result = 0;
for i in (0..self.digits_nb).rev() {
result |= (nat_val & mask) << (i * self.radix_lg);
nat_val >>= self.radix_lg;
for i in (0..self.digits_nb.0).rev() {
result |= (nat_val & mask) << (i * self.radix_lg.0);
nat_val >>= self.radix_lg.0;
}
result
@@ -80,7 +80,7 @@ where
assert_eq!(src.len(), dst.len(), "Poly src/ dst length mismtach");
assert_eq!(
src.len(),
((1 << rb_conv.radix_lg()) as usize).pow(rb_conv.digits_nb() as u32),
((1 << rb_conv.radix_lg().0) as usize).pow(rb_conv.digits_nb().0 as u32),
"Poly length mismtach with RadixBasis configuration"
);
@@ -92,7 +92,7 @@ where
#[derive(Debug, Clone)]
pub struct PcgNetwork {
stg_nb: usize,
stage_nb: usize,
rb_conv: RadixBasis,
}
@@ -100,7 +100,7 @@ impl PcgNetwork {
/// Create network instance from NttParameters
pub fn new(radix: usize, stg_nb: usize) -> Self {
Self {
stg_nb,
stage_nb: stg_nb,
rb_conv: RadixBasis::new(radix, stg_nb),
}
}
@@ -108,9 +108,11 @@ impl PcgNetwork {
/// For a given position idx (in 0..N-1), at processing step delta_idx,
/// find the corresponding position idx (consider the input of the node)
pub fn get_pos_id(&mut self, delta_idx: usize, pos_idx: usize) -> usize {
let node_idx = pos_idx / (1 << self.rb_conv.radix_lg());
let rmn_idx = pos_idx % (1 << self.rb_conv.radix_lg());
let pdrev_idx = self.rb_conv.idx_pdrev(self.stg_nb - 1, delta_idx, node_idx);
pdrev_idx * (1 << self.rb_conv.radix_lg()) + rmn_idx
let node_idx = pos_idx / (1 << self.rb_conv.radix_lg().0);
let rmn_idx = pos_idx % (1 << self.rb_conv.radix_lg().0);
let pdrev_idx = self
.rb_conv
.idx_pdrev(self.stage_nb - 1, delta_idx, node_idx);
pdrev_idx * (1 << self.rb_conv.radix_lg().0) + rmn_idx
}
}

View File

@@ -21,7 +21,7 @@ impl<Scalar: UnsignedInteger> FromWith<LweCiphertextView<'_, Scalar>, HpuParamet
let pbs_p = &params.pbs_params;
let poly_size = pbs_p.polynomial_size;
// NB: Glwe polynomial must be in reversed order
// NB: lwe mask is viewed as a polynomial and must be in reversed order
// Allocate translation buffer and reversed vector here
let rb_conv = order::RadixBasis::new(ntt_p.radix, ntt_p.stg_nb);
let lwe_len = hpu_lwe.len();

View File

@@ -4,31 +4,27 @@
//! WARN: Only one Hpu can be used at a time, thus all tests must be run sequentially
#[cfg(feature = "hpu")]
use std::str::FromStr;
mod hpu_test {
use std::str::FromStr;
pub use rand::Rng;
pub use rand::Rng;
#[cfg(feature = "hpu")]
pub use tfhe_hpu_backend::prelude::*;
pub use tfhe_hpu_backend::prelude::*;
#[cfg(feature = "hpu")]
/// Variable to store initialized HpuDevice and associated client key for fast iteration
static HPU_DEVICE_CKS: std::sync::OnceLock<(
std::sync::Mutex<HpuDevice>,
tfhe::integer::ClientKey,
)> = std::sync::OnceLock::new();
/// Variable to store initialized HpuDevice and associated client key for fast iteration
static HPU_DEVICE_CKS: std::sync::OnceLock<(
std::sync::Mutex<HpuDevice>,
tfhe::integer::ClientKey,
)> = std::sync::OnceLock::new();
// NB: Currently u55c didn't check for workq overflow.
// -> Use default value < queue depth to circumvent this limitation
// NB': This is only for u55c, on V80 user could set HPU_TEST_ITER to whatever value he want
#[cfg(feature = "hpu")]
const DEFAULT_TEST_ITER: usize = 32;
// NB: Currently u55c doesn't check for workq overflow.
// -> Use a default value < queue depth to circumvent this limitation
// NB': This is only for u55c; on V80 the user can set HPU_TEST_ITER to whatever value they want
const DEFAULT_TEST_ITER: usize = 32;
#[macro_export]
macro_rules! hpu_testbundle {
macro_rules! hpu_testbundle {
($base_name: literal::$integer_width:tt => [$($testcase: literal),+]) => {
::paste::paste! {
#[cfg(feature = "hpu")]
#[test]
pub fn [<hpu_test_ $base_name:lower _u $integer_width>]() {
// Register tracing subscriber that use env-filter
@@ -90,7 +86,7 @@ macro_rules! hpu_testbundle {
};
}
macro_rules! hpu_testcase {
macro_rules! hpu_testcase {
($iop: literal => [$($user_type: ty),+] |$ct:ident, $imm: ident| $behav: expr) => {
::paste::paste! {
$(
@@ -157,412 +153,417 @@ macro_rules! hpu_testcase {
};
}
// Define testcase implementation for all supported IOp
// Alu IOp with Ct x Imm
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128]
// Define testcase implementation for all supported IOp
// Alu IOp with Ct x Imm
hpu_testcase!("ADDS" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_add(imm[0])]);
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128]
hpu_testcase!("SUBS" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_sub(imm[0])]);
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128]
hpu_testcase!("SSUB" => [u8, u16, u32, u64, u128]
|ct, imm| vec![imm[0].wrapping_sub(ct[0])]);
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128]
hpu_testcase!("MULS" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_mul(imm[0])]);
// Alu IOp with Ct x Ct
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128]
// Alu IOp with Ct x Ct
hpu_testcase!("ADD" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_add(ct[1])]);
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128]
hpu_testcase!("SUB" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_sub(ct[1])]);
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128]
hpu_testcase!("MUL" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0].wrapping_mul(ct[1])]);
// Bitwise IOp
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128]
// Bitwise IOp
hpu_testcase!("BW_AND" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] & ct[1]]);
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128]
hpu_testcase!("BW_OR" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] | ct[1]]);
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128]
hpu_testcase!("BW_XOR" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] ^ ct[1]]);
// Comparison IOp
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128]
// Comparison IOp
hpu_testcase!("CMP_GT" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] > ct[1]]);
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128]
hpu_testcase!("CMP_GTE" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] >= ct[1]]);
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128]
hpu_testcase!("CMP_LT" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] < ct[1]]);
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128]
hpu_testcase!("CMP_LTE" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] <= ct[1]]);
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128]
hpu_testcase!("CMP_EQ" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] == ct[1]]);
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128]
hpu_testcase!("CMP_NEQ" => [u8, u16, u32, u64, u128]
|ct, imm| vec![ct[0] != ct[1]]);
// Ternary IOp
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128]
// Ternary IOp
hpu_testcase!("IF_THEN_ZERO" => [u8, u16, u32, u64, u128]
|ct, imm| vec![if ct[1] != 0 {ct[0]} else { 0}]);
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128]
hpu_testcase!("IF_THEN_ELSE" => [u8, u16, u32, u64, u128]
|ct, imm| vec![if ct[2] != 0 {ct[0]} else { ct[1]}]);
// ERC 20 found xfer
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128]
|ct, imm| {
let from = ct[0];
let to = ct[1];
let amount = ct[2];
// TODO enhance this to prevent overflow
if from >= amount {
vec![from - amount, to.wrapping_add(amount)]
} else {
vec![from, to]
}
});
// Define a set of test bundle for various size
// 8bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::8 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alu"::8 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("bitwise"::8 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("cmp"::8 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::8 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::8 => [
"erc_20"
]);
// 16bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::16 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alu"::16 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("bitwise"::16 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("cmp"::16 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::16 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::16 => [
"erc_20"
]);
// 32bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::32 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alu"::32 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("bitwise"::32 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("cmp"::32 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::32 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::32 => [
"erc_20"
]);
// 64bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::64 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alu"::64 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("bitwise"::64 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("cmp"::64 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::64 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::64 => [
"erc_20"
]);
// 128bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alus"::128 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("alu"::128 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("bitwise"::128 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("cmp"::128 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("ternary"::128 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
crate::hpu_testbundle!("algo"::128 => [
"erc_20"
]);
/// Simple test dedicated to check entities convertion from/to Cpu
#[cfg(feature = "hpu")]
#[test]
fn hpu_key_loopback() {
use tfhe::core_crypto::hpu::from_with::FromWith;
use tfhe::core_crypto::prelude::*;
use tfhe::*;
use tfhe_hpu_backend::prelude::*;
// Retrieved HpuDevice or init ---------------------------------------------
// Used hpu_device backed in static variable to automatically serialize tests
let (hpu_mutex, cks) = HPU_DEVICE_CKS.get_or_init(|| {
// Instantiate HpuDevice --------------------------------------------------
let hpu_device = {
let config_file = ShellString::new(
"${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml".to_string(),
);
HpuDevice::from_config(&config_file.expand())
};
// Extract pbs_configuration from Hpu and create Client/Server Key
let cks = tfhe::integer::ClientKey::new(tfhe::shortint::ClassicPBSParameters::from(
hpu_device.params(),
));
let sks_compressed =
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks);
// Init Hpu device with server key and firmware
tfhe::integer::hpu::init_device(&hpu_device, sks_compressed.into()).expect("Invalid key");
(std::sync::Mutex::new(hpu_device), cks)
});
let hpu_params = hpu_mutex
.lock()
.expect("Error with HpuDevice Mutex")
.params()
.clone();
// Generate Keys ---------------------------------------------------------
let sks_compressed =
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks).into_raw_parts();
// KSK Loopback conversion and check -------------------------------------
// Extract and convert ksk
let mut cpu_ksk_orig = sks_compressed
.key_switching_key
.decompress_into_lwe_keyswitch_key();
let hpu_ksk = HpuLweKeyswitchKeyOwned::from_with(cpu_ksk_orig.as_view(), hpu_params.clone());
let cpu_ksk_lb = LweKeyswitchKeyOwned::from(hpu_ksk.as_view());
// NB: Some hw modifications such as bit shrinki couldn't be reversed
cpu_ksk_orig.as_mut().iter_mut().for_each(|coef| {
let ks_p = hpu_params.ks_params;
// Apply Hw rounding
// Extract info bits and rounding if required
let coef_info = *coef >> (u64::BITS - ks_p.width as u32);
let coef_rounding = if (ks_p.width as u32) < u64::BITS {
(*coef >> (u64::BITS - (ks_p.width + 1) as u32)) & 0x1
} else {
0
};
*coef = (coef_info + coef_rounding) << (u64::BITS - ks_p.width as u32);
});
let ksk_mismatch: usize =
std::iter::zip(cpu_ksk_orig.as_ref().iter(), cpu_ksk_lb.as_ref().iter())
.enumerate()
.map(|(i, (x, y))| {
if x != y {
println!("Ksk mismatch @{i}:: {x:x} != {y:x}");
1
// ERC 20 funds transfer
hpu_testcase!("ERC_20" => [u8, u16, u32, u64, u128]
|ct, imm| {
let from = ct[0];
let to = ct[1];
let amount = ct[2];
// TODO enhance this to prevent overflow
if from >= amount {
vec![from - amount, to.wrapping_add(amount)]
} else {
0
vec![from, to]
}
})
.sum();
});
// BSK Loopback conversion and check -------------------------------------
// Extract and convert ksk
let cpu_bsk_orig = match sks_compressed.bootstrapping_key {
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::Classic {
bsk: seeded_bsk,
..
} => seeded_bsk.decompress_into_lwe_bootstrap_key(),
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::MultiBit { .. } => {
panic!("Hpu currently not support multibit. Required a Classic BSK")
}
};
let cpu_bsk_ntt = {
// Convert the LweBootstrapKey in Ntt domain
let mut ntt_bsk = NttLweBootstrapKeyOwned::<u64>::new(
0_u64,
cpu_bsk_orig.input_lwe_dimension(),
cpu_bsk_orig.glwe_size(),
cpu_bsk_orig.polynomial_size(),
cpu_bsk_orig.decomposition_base_log(),
cpu_bsk_orig.decomposition_level_count(),
CiphertextModulus::new(u64::from(&hpu_params.ntt_params.prime_modulus) as u128),
);
// Define a set of test bundle for various size
// 8bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::8 => [
"adds",
"subs",
"ssub",
"muls"
]);
// Conversion to ntt domain
par_convert_standard_lwe_bootstrap_key_to_ntt64(&cpu_bsk_orig, &mut ntt_bsk);
ntt_bsk
};
let hpu_bsk = HpuLweBootstrapKeyOwned::from_with(cpu_bsk_orig.as_view(), hpu_params.clone());
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::8 => [
"add",
"sub",
"mul"
]);
let cpu_bsk_lb = NttLweBootstrapKeyOwned::from(hpu_bsk.as_view());
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::8 => [
"bw_and",
"bw_or",
"bw_xor"
]);
let bsk_mismatch: usize = std::iter::zip(
cpu_bsk_ntt.as_view().into_container().iter(),
cpu_bsk_lb.as_view().into_container().iter(),
)
.enumerate()
.map(|(i, (x, y))| {
if x != y {
println!("@{i}:: {x:x} != {y:x}");
1
} else {
0
}
})
.sum();
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::8 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
println!("Ksk loopback with {ksk_mismatch} errors");
println!("Bsk loopback with {bsk_mismatch} errors");
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::8 => [
"if_then_zero",
"if_then_else"
]);
assert_eq!(ksk_mismatch + bsk_mismatch, 0);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::8 => [
"erc_20"
]);
// 16bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::16 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::16 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::16 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::16 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::16 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::16 => [
"erc_20"
]);
// 32bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::32 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::32 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::32 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::32 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::32 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::32 => [
"erc_20"
]);
// 64bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::64 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::64 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::64 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::64 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::64 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::64 => [
"erc_20"
]);
// 128bit ciphertext -----------------------------------------
#[cfg(feature = "hpu")]
hpu_testbundle!("alus"::128 => [
"adds",
"subs",
"ssub",
"muls"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("alu"::128 => [
"add",
"sub",
"mul"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("bitwise"::128 => [
"bw_and",
"bw_or",
"bw_xor"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("cmp"::128 => [
"cmp_gt",
"cmp_gte",
"cmp_lt",
"cmp_lte",
"cmp_eq",
"cmp_neq"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("ternary"::128 => [
"if_then_zero",
"if_then_else"
]);
#[cfg(feature = "hpu")]
hpu_testbundle!("algo"::128 => [
"erc_20"
]);
/// Loopback test checking the conversion of key entities from Cpu to Hpu and back.
///
/// The keyswitching key and the bootstrapping key are converted to their Hpu
/// representation and back to Cpu; the result is compared coefficient by coefficient
/// against a Cpu-side reference. The test fails if any coefficient differs.
#[cfg(feature = "hpu")]
#[test]
fn hpu_key_loopback() {
use tfhe::core_crypto::hpu::from_with::FromWith;
use tfhe::core_crypto::prelude::*;
use tfhe::*;
use tfhe_hpu_backend::prelude::*;
// Retrieve HpuDevice or init ----------------------------------------------
// The hpu_device is kept in a static variable behind a Mutex to automatically serialize tests
let (hpu_mutex, cks) = HPU_DEVICE_CKS.get_or_init(|| {
// Instantiate HpuDevice --------------------------------------------------
let hpu_device = {
let config_file = ShellString::new(
"${HPU_BACKEND_DIR}/config_store/${HPU_CONFIG}/hpu_config.toml".to_string(),
);
HpuDevice::from_config(&config_file.expand())
};
// Extract pbs_configuration from Hpu and create Client/Server Key
let cks = tfhe::integer::ClientKey::new(tfhe::shortint::ClassicPBSParameters::from(
hpu_device.params(),
));
let sks_compressed =
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks);
// Init Hpu device with server key and firmware
tfhe::integer::hpu::init_device(&hpu_device, sks_compressed.into())
.expect("Invalid key");
(std::sync::Mutex::new(hpu_device), cks)
});
// Only a copy of the parameters is needed here: the mutex guard is a temporary of this
// statement, so the lock is not held during the rest of the test.
let hpu_params = hpu_mutex
.lock()
.expect("Error with HpuDevice Mutex")
.params()
.clone();
// Generate Keys ---------------------------------------------------------
// A fresh compressed server key is derived from the shared client key and split into raw parts
let sks_compressed =
tfhe::integer::CompressedServerKey::new_radix_compressed_server_key(&cks)
.into_raw_parts();
// KSK Loopback conversion and check -------------------------------------
// Extract and convert ksk
let mut cpu_ksk_orig = sks_compressed
.key_switching_key
.decompress_into_lwe_keyswitch_key();
let hpu_ksk =
HpuLweKeyswitchKeyOwned::from_with(cpu_ksk_orig.as_view(), hpu_params.clone());
let cpu_ksk_lb = LweKeyswitchKeyOwned::from(hpu_ksk.as_view());
// NB: Some hw modifications such as bit shrinking cannot be reversed
// -> apply the same rounding on the Cpu reference before comparing it with the loopback
cpu_ksk_orig.as_mut().iter_mut().for_each(|coef| {
let ks_p = hpu_params.ks_params;
// Apply Hw rounding
// Extract info bits (the `ks_p.width` MSBs) and the rounding bit just below them, if any
let coef_info = *coef >> (u64::BITS - ks_p.width as u32);
let coef_rounding = if (ks_p.width as u32) < u64::BITS {
(*coef >> (u64::BITS - (ks_p.width + 1) as u32)) & 0x1
} else {
0
};
// Put the rounded value back in MSB position
*coef = (coef_info + coef_rounding) << (u64::BITS - ks_p.width as u32);
});
// Count (and print) every coefficient that differs between reference and loopback
let ksk_mismatch: usize =
std::iter::zip(cpu_ksk_orig.as_ref().iter(), cpu_ksk_lb.as_ref().iter())
.enumerate()
.map(|(i, (x, y))| {
if x != y {
println!("Ksk mismatch @{i}:: {x:x} != {y:x}");
1
} else {
0
}
})
.sum();
// BSK Loopback conversion and check -------------------------------------
// Extract and convert bsk (only the Classic flavor is handled, MultiBit panics)
let cpu_bsk_orig = match sks_compressed.bootstrapping_key {
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::Classic {
bsk: seeded_bsk,
..
} => seeded_bsk.decompress_into_lwe_bootstrap_key(),
tfhe::shortint::server_key::ShortintCompressedBootstrappingKey::MultiBit { .. } => {
panic!("Hpu currently not support multibit. Required a Classic BSK")
}
};
// Cpu-side reference: the same bsk converted to the Ntt domain over the Hpu prime modulus
let cpu_bsk_ntt = {
// Convert the LweBootstrapKey in Ntt domain
let mut ntt_bsk = NttLweBootstrapKeyOwned::<u64>::new(
0_u64,
cpu_bsk_orig.input_lwe_dimension(),
cpu_bsk_orig.glwe_size(),
cpu_bsk_orig.polynomial_size(),
cpu_bsk_orig.decomposition_base_log(),
cpu_bsk_orig.decomposition_level_count(),
CiphertextModulus::new(u64::from(&hpu_params.ntt_params.prime_modulus) as u128),
);
// Conversion to ntt domain
par_convert_standard_lwe_bootstrap_key_to_ntt64(&cpu_bsk_orig, &mut ntt_bsk);
ntt_bsk
};
let hpu_bsk =
HpuLweBootstrapKeyOwned::from_with(cpu_bsk_orig.as_view(), hpu_params.clone());
let cpu_bsk_lb = NttLweBootstrapKeyOwned::from(hpu_bsk.as_view());
// Count (and print) every coefficient that differs between reference and loopback
let bsk_mismatch: usize = std::iter::zip(
cpu_bsk_ntt.as_view().into_container().iter(),
cpu_bsk_lb.as_view().into_container().iter(),
)
.enumerate()
.map(|(i, (x, y))| {
if x != y {
println!("@{i}:: {x:x} != {y:x}");
1
} else {
0
}
})
.sum();
println!("Ksk loopback with {ksk_mismatch} errors");
println!("Bsk loopback with {bsk_mismatch} errors");
// Any mismatch on either key fails the test
assert_eq!(ksk_mismatch + bsk_mismatch, 0);
}
}