Compare commits

...

2 Commits

View File

@@ -38,7 +38,7 @@ int32_t cuda_setup_multi_gpu() {
int get_active_gpu_count(int num_inputs, int gpu_count) {
int active_gpu_count = gpu_count;
if (gpu_count > num_inputs) {
active_gpu_count = num_inputs;
active_gpu_count = 1;
}
return active_gpu_count;
}
@@ -56,8 +56,8 @@ int get_num_inputs_on_gpu(int total_num_inputs, int gpu_index, int gpu_count) {
// If there are fewer inputs than GPUs, not all GPUs are active and GPU 0
// handles everything
if (gpu_count > total_num_inputs) {
if (gpu_index < total_num_inputs) {
num_inputs = 1;
if (gpu_index == 0) {
num_inputs = total_num_inputs;
}
} else {
// If there are more inputs than GPUs, all GPUs are active and compute over