fix(gpu): avoid out of memory when benchmarking throughput

This commit is contained in:
Guillermo Oyarzun
2025-09-08 15:19:40 +02:00
parent c4feabbfa3
commit 022cb3b18a
2 changed files with 53 additions and 23 deletions

View File

@@ -421,23 +421,32 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
let block_multiplicator = (ref_block_count as f64 / num_block as f64).ceil().min(1.0);
// Some operations with a high serial workload (e.g. division) would yield an operation
// loading value so low that the number of elements in the end wouldn't be meaningful.
let minimum_loading = if num_block < 64 { 0.2 } else { 0.01 };
let minimum_loading = if num_block < 64 { 1.0 } else { 0.015 };
#[cfg(feature = "gpu")]
{
let num_sms_per_gpu = get_number_of_sms();
let total_num_sm = num_sms_per_gpu * get_number_of_gpus();
let total_blocks_per_sm = 4u32; // Assume each SM can handle 4 blocks concurrently
let total_num_sm = total_blocks_per_sm * total_num_sm;
let total_blocks_per_sm = 4u64; // Assume each SM can handle 4 blocks concurrently
let min_num_waves = 4u64; //Enforce at least 4 waves in the GPU
let elements_per_wave = total_num_sm as u64 / (num_block as u64);
let block_factor = ((2.0f64 * num_block as f64) / 4.0f64).ceil() as u64;
let elements_per_wave = total_blocks_per_sm * total_num_sm as u64 / block_factor;
// We need to enable the new load for pbs benches and for sizes larger than 16 blocks in
// demanding operations for the rest of operations we maintain a minimum of 200
// elements
let min_elements = if op_pbs_count == 1
|| (op_pbs_count > (num_block * num_block) as u64 && num_block >= 16)
{
elements_per_wave * min_num_waves
} else {
200u64
};
let operation_loading = ((total_num_sm as u64 / op_pbs_count) as f64).max(minimum_loading);
let elements = (total_num_sm as f64 * block_multiplicator * operation_loading) as u64;
elements.min(elements_per_wave * min_num_waves) // This threshold is useful for operation
// with both a small number of
// block and low PBs count.
elements.min(min_elements) // This threshold is useful for operation
// with both a small number of
// block and low PBs count.
}
#[cfg(feature = "hpu")]
{