mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): avoid out of memory when benchmarking throughput
This commit is contained in:
@@ -421,23 +421,32 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
|
||||
let block_multiplicator = (ref_block_count as f64 / num_block as f64).ceil().min(1.0);
|
||||
// Some operations with a high serial workload (e.g. division) would yield an operation
|
||||
// loading value so low that the number of elements in the end wouldn't be meaningful.
|
||||
let minimum_loading = if num_block < 64 { 0.2 } else { 0.01 };
|
||||
let minimum_loading = if num_block < 64 { 1.0 } else { 0.015 };
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
{
|
||||
let num_sms_per_gpu = get_number_of_sms();
|
||||
let total_num_sm = num_sms_per_gpu * get_number_of_gpus();
|
||||
|
||||
let total_blocks_per_sm = 4u32; // Assume each SM can handle 4 blocks concurrently
|
||||
let total_num_sm = total_blocks_per_sm * total_num_sm;
|
||||
let total_blocks_per_sm = 4u64; // Assume each SM can handle 4 blocks concurrently
|
||||
let min_num_waves = 4u64; //Enforce at least 4 waves in the GPU
|
||||
let elements_per_wave = total_num_sm as u64 / (num_block as u64);
|
||||
|
||||
let block_factor = ((2.0f64 * num_block as f64) / 4.0f64).ceil() as u64;
|
||||
let elements_per_wave = total_blocks_per_sm * total_num_sm as u64 / block_factor;
|
||||
// We need to enable the new load for pbs benches and for sizes larger than 16 blocks in
|
||||
// demanding operations for the rest of operations we maintain a minimum of 200
|
||||
// elements
|
||||
let min_elements = if op_pbs_count == 1
|
||||
|| (op_pbs_count > (num_block * num_block) as u64 && num_block >= 16)
|
||||
{
|
||||
elements_per_wave * min_num_waves
|
||||
} else {
|
||||
200u64
|
||||
};
|
||||
let operation_loading = ((total_num_sm as u64 / op_pbs_count) as f64).max(minimum_loading);
|
||||
let elements = (total_num_sm as f64 * block_multiplicator * operation_loading) as u64;
|
||||
elements.min(elements_per_wave * min_num_waves) // This threshold is useful for operation
|
||||
// with both a small number of
|
||||
// block and low PBs count.
|
||||
elements.min(min_elements) // This threshold is useful for operation
|
||||
// with both a small number of
|
||||
// block and low PBs count.
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user