fix(gpu): change mininum number of elements in benches

This commit is contained in:
Guillermo Oyarzun
2025-09-01 18:40:34 +02:00
committed by Agnès Leroy
parent b42ba79145
commit c2e816a86c
5 changed files with 28 additions and 8 deletions

View File

@@ -3,7 +3,7 @@ use std::path::PathBuf;
use std::sync::OnceLock;
use std::{env, fs};
#[cfg(feature = "gpu")]
use tfhe::core_crypto::gpu::get_number_of_gpus;
use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms};
use tfhe::core_crypto::prelude::*;
#[cfg(feature = "boolean")]
@@ -417,10 +417,6 @@ pub fn get_bench_type() -> &'static BenchmarkType {
BENCH_TYPE.get_or_init(|| BenchmarkType::from_env().unwrap())
}
/// Number of streaming multiprocessors (SM) available on Nvidia H100 GPU
#[cfg(feature = "gpu")]
const H100_PCIE_SM_COUNT: u32 = 114;
/// Generate a number of threads to use to saturate current machine for throughput measurements.
pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
let ref_block_count = 32; // Represent a ciphertext of 64 bits for 2_2 parameters set
@@ -431,11 +427,19 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
#[cfg(feature = "gpu")]
{
let total_num_sm = H100_PCIE_SM_COUNT * get_number_of_gpus();
let num_sms_per_gpu = get_number_of_sms();
let total_num_sm = num_sms_per_gpu * get_number_of_gpus();
let total_blocks_per_sm = 4u32; // Assume each SM can handle 4 blocks concurrently
let total_num_sm = total_blocks_per_sm * total_num_sm;
let min_num_waves = 4u64; //Enforce at least 4 waves in the GPU
let elements_per_wave = total_num_sm as u64 / (num_block as u64);
let operation_loading = ((total_num_sm as u64 / op_pbs_count) as f64).max(minimum_loading);
let elements = (total_num_sm as f64 * block_multiplicator * operation_loading) as u64;
elements.min(200) // This threshold is useful for operation with both a small number of
// block and low PBs count.
elements.min(elements_per_wave * min_num_waves) // This threshold is useful for operation
// with both a small number of
// block and low PBs count.
}
#[cfg(feature = "hpu")]
{