mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(shortint): add an atomic counter to keep track of the number of PBSes
This commit is contained in:
@@ -88,6 +88,8 @@ integer = ["shortint"]
|
||||
internal-keycache = ["dep:lazy_static", "dep:fs2"]
|
||||
gpu = ["tfhe-cuda-backend"]
|
||||
|
||||
pbs-stats = []
|
||||
|
||||
# Experimental section
|
||||
experimental = []
|
||||
experimental-force_fft_algo_dif4 = []
|
||||
@@ -289,5 +291,9 @@ required-features = ["boolean"]
|
||||
name = "sha256"
|
||||
required-features = ["integer"]
|
||||
|
||||
[[example]]
|
||||
name = "pbs_count"
|
||||
required-features = ["integer", "pbs-stats"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "staticlib", "cdylib"]
|
||||
|
||||
29
tfhe/examples/pbs_count.rs
Normal file
29
tfhe/examples/pbs_count.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::*;
|
||||
|
||||
pub fn main() {
|
||||
let config = ConfigBuilder::default().build();
|
||||
|
||||
let (cks, sks) = generate_keys(config);
|
||||
|
||||
let a = FheUint32::encrypt(42u32, &cks);
|
||||
let b = FheUint32::encrypt(69u32, &cks);
|
||||
|
||||
set_server_key(sks);
|
||||
|
||||
let c = &a * &b;
|
||||
let mul_32_count = get_pbs_count();
|
||||
|
||||
reset_pbs_count();
|
||||
let d = &a & &b;
|
||||
let and_32_count = get_pbs_count();
|
||||
|
||||
println!("mul_32_count: {mul_32_count}");
|
||||
println!("and_32_count: {and_32_count}");
|
||||
|
||||
let c_dec: u32 = c.decrypt(&cks);
|
||||
let d_dec: u32 = d.decrypt(&cks);
|
||||
|
||||
assert_eq!(42 * 69, c_dec);
|
||||
assert_eq!(42 & 69, d_dec);
|
||||
}
|
||||
@@ -92,6 +92,9 @@ pub mod integer;
|
||||
/// cbindgen:ignore
|
||||
pub mod shortint;
|
||||
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
pub use shortint::server_key::pbs_stats::*;
|
||||
|
||||
#[cfg(feature = "__wasm_api")]
|
||||
/// cbindgen:ignore
|
||||
mod js_on_wasm_api;
|
||||
|
||||
@@ -46,6 +46,23 @@ use crate::shortint::PBSOrder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
pub mod pbs_stats {
|
||||
use std::sync::atomic::AtomicU64;
|
||||
pub use std::sync::atomic::Ordering;
|
||||
pub static PBS_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
pub fn get_pbs_count() -> u64 {
|
||||
PBS_COUNT.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn reset_pbs_count() {
|
||||
PBS_COUNT.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
pub use pbs_stats::*;
|
||||
|
||||
/// Error returned when the carry buffer is full.
|
||||
#[derive(Debug)]
|
||||
pub enum CheckError {
|
||||
@@ -1172,6 +1189,9 @@ impl ServerKey {
|
||||
ct: &mut Ciphertext,
|
||||
acc: &LookupTableOwned,
|
||||
) {
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if ct.is_trivial() {
|
||||
self.trivial_pbs_assign(ct, acc);
|
||||
return;
|
||||
@@ -1248,6 +1268,9 @@ impl ServerKey {
|
||||
ct: &mut Ciphertext,
|
||||
acc: &LookupTableOwned,
|
||||
) {
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if ct.is_trivial() {
|
||||
self.trivial_pbs_assign(ct, acc);
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user