chore(gpu): bench KS latency batches

This commit is contained in:
Andrei Stoian
2025-11-06 18:39:20 +01:00
committed by Andrei Stoian
parent d6a0a366b9
commit e2063c8ef4
24 changed files with 1239 additions and 269 deletions

View File

@@ -386,7 +386,7 @@ mod cuda {
.keyswitch_key(ksk_big_to_small)
.build();
let bench_id;
let mut bench_id;
match get_bench_type() {
BenchmarkType::Latency => {
@@ -423,120 +423,163 @@ mod cuda {
&mut output_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
black_box(&mut ct_gpu);
})
});
}
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
BenchmarkType::Throughput => {
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);
let gpu_count = get_number_of_gpus() as usize;
bench_id = format!("{bench_name}::throughput::{name}");
let blocks: usize = 1;
let elements = throughput_num_threads(blocks, 1);
let elements_per_stream = elements as usize / gpu_count;
bench_group.throughput(Throughput::Elements(elements));
bench_group.sample_size(50);
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams_core();
let plaintext_list = PlaintextList::new(
Scalar::ZERO,
PlaintextCount(elements_per_stream),
for uses_gemm_ks in [false, true] {
for uses_simple_indices in [false, true] {
let indices_str = if uses_simple_indices {
"simple"
} else {
"complex"
};
let gemm_str = if uses_gemm_ks { "gemm" } else { "classical" };
bench_id = format!(
"{bench_name}::throughput::{gemm_str}::{indices_str}_indices::{name}",
);
let input_cts = (0..gpu_count)
.map(|i| {
let mut input_ct_list = LweCiphertextList::new(
let blocks: usize = 256;
let elements = gpu_count * blocks;
let elements_per_stream = elements / gpu_count;
bench_group.throughput(Throughput::Elements(elements as u64));
bench_group.sample_size(50);
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
let local_streams = cuda_local_streams_core();
let plaintext_list = PlaintextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
PlaintextCount(elements_per_stream),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);
let input_ks_list = LweCiphertextList::from_container(
input_ct_list.into_container(),
big_lwe_sk.lwe_dimension().to_lwe_size(),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&input_ks_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();
let output_cts = (0..gpu_count)
.map(|i| {
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();
let input_cts = (0..gpu_count)
.map(|i| {
let mut input_ct_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);
let input_ks_list = LweCiphertextList::from_container(
input_ct_list.into_container(),
big_lwe_sk.lwe_dimension().to_lwe_size(),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&input_ks_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();
let h_indexes = (0..(elements / gpu_count as u64))
.map(CastFrom::cast_from)
.collect::<Vec<_>>();
let cuda_indexes_vec = (0..gpu_count)
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
.collect::<Vec<_>>();
local_streams.iter().for_each(|stream| stream.synchronize());
let output_cts = (0..gpu_count)
.map(|i| {
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream),
params.ciphertext_modulus.unwrap(),
);
CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
&local_streams[i],
)
})
.collect::<Vec<_>>();
(input_cts, output_cts, cuda_indexes_vec, local_streams)
};
let indexes_range: Vec<u64> = if uses_simple_indices {
(0..(elements / gpu_count) as u64).collect()
} else {
(0..(elements / gpu_count) as u64).rev().collect()
};
let h_indexes = indexes_range
.iter()
.map(|v| CastFrom::cast_from(*v))
.collect::<Vec<_>>();
b.iter_batched(
setup_encrypted_values,
|(input_cts, mut output_cts, cuda_indexes_vec, local_streams)| {
(0..gpu_count)
.into_par_iter()
.zip(input_cts.par_iter())
.zip(output_cts.par_iter_mut())
.zip(local_streams.par_iter())
.for_each(|(((i, input_ct), output_ct), local_stream)| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys_vec[i].ksk.as_ref().unwrap(),
input_ct,
output_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
local_stream,
);
})
},
criterion::BatchSize::SmallInput,
)
});
let cuda_indexes_vec = (0..gpu_count)
.map(|i| CudaIndexes::new(&h_indexes, &local_streams[i], 0))
.collect::<Vec<_>>();
local_streams.iter().for_each(|stream| stream.synchronize());
(input_cts, output_cts, cuda_indexes_vec, local_streams)
};
b.iter_batched(
setup_encrypted_values,
|(
input_cts,
mut output_cts,
cuda_indexes_vec,
local_streams,
)| {
(0..gpu_count)
.into_par_iter()
.zip(input_cts.par_iter())
.zip(output_cts.par_iter_mut())
.zip(local_streams.par_iter())
.for_each(
|(((i, input_ct), output_ct), local_stream)| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys_vec[i].ksk.as_ref().unwrap(),
input_ct,
output_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
uses_simple_indices,
local_stream,
uses_gemm_ks,
);
},
)
},
criterion::BatchSize::SmallInput,
)
});
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}
}
};
let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&bench_id,
*params,
name,
"ks",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}

View File

@@ -630,7 +630,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -782,7 +784,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,
@@ -937,7 +941,9 @@ mod cuda {
&mut output_ks_ct_gpu,
&cuda_indexes.d_input,
&cuda_indexes.d_output,
true,
&streams,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
&output_ks_ct_gpu,
@@ -1088,7 +1094,9 @@ mod cuda {
output_ks_ct,
&cuda_indexes_vec[i].d_input,
&cuda_indexes_vec[i].d_output,
true,
local_stream,
false,
);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
output_ks_ct,