feat(shortint): add an atomic counter to keep track of the number of PBSes

This commit is contained in:
Arthur Meyre
2024-02-02 17:00:25 +01:00
parent 3ff5d551a9
commit 52f3babde5
4 changed files with 61 additions and 0 deletions

View File

@@ -88,6 +88,8 @@ integer = ["shortint"]
internal-keycache = ["dep:lazy_static", "dep:fs2"]
gpu = ["tfhe-cuda-backend"]
pbs-stats = []
# Experimental section
experimental = []
experimental-force_fft_algo_dif4 = []
@@ -289,5 +291,9 @@ required-features = ["boolean"]
name = "sha256"
required-features = ["integer"]
[[example]]
name = "pbs_count"
required-features = ["integer", "pbs-stats"]
[lib]
crate-type = ["lib", "staticlib", "cdylib"]

View File

@@ -0,0 +1,29 @@
use tfhe::prelude::*;
use tfhe::*;
pub fn main() {
let config = ConfigBuilder::default().build();
let (cks, sks) = generate_keys(config);
let a = FheUint32::encrypt(42u32, &cks);
let b = FheUint32::encrypt(69u32, &cks);
set_server_key(sks);
let c = &a * &b;
let mul_32_count = get_pbs_count();
reset_pbs_count();
let d = &a & &b;
let and_32_count = get_pbs_count();
println!("mul_32_count: {mul_32_count}");
println!("and_32_count: {and_32_count}");
let c_dec: u32 = c.decrypt(&cks);
let d_dec: u32 = d.decrypt(&cks);
assert_eq!(42 * 69, c_dec);
assert_eq!(42 & 69, d_dec);
}

View File

@@ -92,6 +92,9 @@ pub mod integer;
/// cbindgen:ignore
pub mod shortint;
#[cfg(feature = "pbs-stats")]
pub use shortint::server_key::pbs_stats::*;
#[cfg(feature = "__wasm_api")]
/// cbindgen:ignore
mod js_on_wasm_api;

View File

@@ -46,6 +46,23 @@ use crate::shortint::PBSOrder;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display, Formatter};
#[cfg(feature = "pbs-stats")]
pub mod pbs_stats {
use std::sync::atomic::AtomicU64;
pub use std::sync::atomic::Ordering;
pub static PBS_COUNT: AtomicU64 = AtomicU64::new(0);
pub fn get_pbs_count() -> u64 {
PBS_COUNT.load(Ordering::Relaxed)
}
pub fn reset_pbs_count() {
PBS_COUNT.store(0, Ordering::Relaxed);
}
}
#[cfg(feature = "pbs-stats")]
pub use pbs_stats::*;
/// Error returned when the carry buffer is full.
#[derive(Debug)]
pub enum CheckError {
@@ -1172,6 +1189,9 @@ impl ServerKey {
ct: &mut Ciphertext,
acc: &LookupTableOwned,
) {
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
if ct.is_trivial() {
self.trivial_pbs_assign(ct, acc);
return;
@@ -1248,6 +1268,9 @@ impl ServerKey {
ct: &mut Ciphertext,
acc: &LookupTableOwned,
) {
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
if ct.is_trivial() {
self.trivial_pbs_assign(ct, acc);
return;