mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
2 Commits
pa/test/zk
...
pa/feat/cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f9e94d9b6b | ||
|
|
d66e36b529 |
@@ -480,20 +480,30 @@ __host__ void host_programmable_bootstrap(
|
||||
double2 *global_join_buffer = pbs_buffer->global_join_buffer;
|
||||
int8_t *d_mem = pbs_buffer->d_mem;
|
||||
|
||||
bool graphCreated = false;
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t instance;
|
||||
for (int i = 0; i < lwe_dimension; i++) {
|
||||
execute_step_one<Torus, params>(
|
||||
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
lwe_input_indexes, bootstrapping_key, global_accumulator,
|
||||
global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
|
||||
partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
|
||||
execute_step_two<Torus, params>(
|
||||
stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
|
||||
lut_vector_indexes, bootstrapping_key, global_accumulator,
|
||||
global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
|
||||
partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
|
||||
num_many_lut, lut_stride);
|
||||
if (!graphCreated) {
|
||||
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
|
||||
execute_step_one<Torus, params>(
|
||||
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
lwe_input_indexes, bootstrapping_key, global_accumulator,
|
||||
global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
|
||||
partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
|
||||
execute_step_two<Torus, params>(
|
||||
stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
|
||||
lut_vector_indexes, bootstrapping_key, global_accumulator,
|
||||
global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
|
||||
partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
|
||||
num_many_lut, lut_stride);
|
||||
cudaStreamEndCapture(stream, &graph);
|
||||
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
|
||||
graphCreated = true;
|
||||
}
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -649,29 +649,41 @@ __host__ void host_multi_bit_programmable_bootstrap(
|
||||
|
||||
auto lwe_chunk_size = buffer->lwe_chunk_size;
|
||||
|
||||
bool graphCreated = false;
|
||||
cudaGraph_t graph;
|
||||
cudaGraphExec_t instance;
|
||||
|
||||
for (uint32_t lwe_offset = 0; lwe_offset < (lwe_dimension / grouping_factor);
|
||||
lwe_offset += lwe_chunk_size) {
|
||||
|
||||
// Compute a keybundle
|
||||
execute_compute_keybundle<Torus, params>(
|
||||
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
|
||||
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
|
||||
grouping_factor, level_count, lwe_offset);
|
||||
// Accumulate
|
||||
uint32_t chunk_size = std::min(
|
||||
lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
|
||||
for (uint32_t j = 0; j < chunk_size; j++) {
|
||||
execute_step_one<Torus, params>(
|
||||
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
lwe_input_indexes, buffer, num_samples, lwe_dimension, glwe_dimension,
|
||||
polynomial_size, base_log, level_count, j, lwe_offset);
|
||||
if (!graphCreated) {
|
||||
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
|
||||
// Compute a keybundle
|
||||
execute_compute_keybundle<Torus, params>(
|
||||
stream, gpu_index, lwe_array_in, lwe_input_indexes, bootstrapping_key,
|
||||
buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
|
||||
grouping_factor, level_count, lwe_offset);
|
||||
// Accumulate
|
||||
uint32_t chunk_size = std::min(
|
||||
lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
|
||||
for (uint32_t j = 0; j < chunk_size; j++) {
|
||||
execute_step_one<Torus, params>(
|
||||
stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
lwe_input_indexes, buffer, num_samples, lwe_dimension,
|
||||
glwe_dimension, polynomial_size, base_log, level_count, j,
|
||||
lwe_offset);
|
||||
|
||||
execute_step_two<Torus, params>(
|
||||
stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
|
||||
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
|
||||
grouping_factor, level_count, j, lwe_offset, num_many_lut,
|
||||
lut_stride);
|
||||
execute_step_two<Torus, params>(
|
||||
stream, gpu_index, lwe_array_out, lwe_output_indexes, buffer,
|
||||
num_samples, lwe_dimension, glwe_dimension, polynomial_size,
|
||||
grouping_factor, level_count, j, lwe_offset, num_many_lut,
|
||||
lut_stride);
|
||||
}
|
||||
cudaStreamEndCapture(stream, &graph);
|
||||
cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
|
||||
graphCreated = true;
|
||||
}
|
||||
cudaGraphLaunch(instance, stream);
|
||||
}
|
||||
}
|
||||
#endif // MULTIBIT_PBS_H
|
||||
|
||||
Reference in New Issue
Block a user