Mirror of https://github.com/zama-ai/tfhe-rs.git (synced 2026-01-09 14:47:56 -05:00)
fix(gpu): fix perf regression introduced in cf3f25efdd
This commit is contained in:
@@ -6,6 +6,8 @@ void cuda_negate_ciphertext_64(CudaStreamsFFI streams,
|
||||
uint32_t message_modulus, uint32_t carry_modulus,
|
||||
uint32_t num_radix_blocks) {
|
||||
|
||||
host_negation<uint64_t>(CudaStreams(streams), lwe_array_out, lwe_array_in,
|
||||
auto cuda_streams = CudaStreams(streams);
|
||||
host_negation<uint64_t>(cuda_streams, lwe_array_out, lwe_array_in,
|
||||
message_modulus, carry_modulus, num_radix_blocks);
|
||||
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ void cuda_scalar_addition_ciphertext_64_inplace(
|
||||
void const *scalar_input, void const *h_scalar_input, uint32_t num_scalars,
|
||||
uint32_t message_modulus, uint32_t carry_modulus) {
|
||||
|
||||
auto cuda_streams = CudaStreams(streams);
|
||||
host_scalar_addition_inplace<uint64_t>(
|
||||
CudaStreams(streams), lwe_array,
|
||||
static_cast<const uint64_t *>(scalar_input),
|
||||
cuda_streams, lwe_array, static_cast<const uint64_t *>(scalar_input),
|
||||
static_cast<const uint64_t *>(h_scalar_input), num_scalars,
|
||||
message_modulus, carry_modulus);
|
||||
cuda_synchronize_stream(cuda_streams.stream(0), cuda_streams.gpu_index(0));
|
||||
}
|
||||
|
||||
@@ -53,7 +53,6 @@ __host__ void host_scalar_addition_inplace(
|
||||
for (uint i = 0; i < num_scalars; i++) {
|
||||
lwe_array->degrees[i] = lwe_array->degrees[i] + h_scalar_input[i];
|
||||
}
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
@@ -94,7 +93,6 @@ __host__ void host_add_scalar_one_inplace(CudaStreams streams,
|
||||
for (uint i = 0; i < lwe_array->num_radix_blocks; i++) {
|
||||
lwe_array->degrees[i] = lwe_array->degrees[i] + 1;
|
||||
}
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
@@ -136,6 +134,5 @@ __host__ void host_scalar_subtraction_inplace(
|
||||
input_lwe_ciphertext_count,
|
||||
lwe_dimension, delta);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,7 @@ void cuda_add_lwe_ciphertext_vector_32(void *stream, uint32_t gpu_index,
|
||||
PANIC("Cuda error: input and output num radix blocks must be the same")
|
||||
host_addition<uint32_t>(static_cast<cudaStream_t>(stream), gpu_index, output,
|
||||
input_1, input_2, output->num_radix_blocks, 0, 0);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -49,6 +50,7 @@ void cuda_add_lwe_ciphertext_vector_64(void *stream, uint32_t gpu_index,
|
||||
PANIC("Cuda error: input and output num radix blocks must be the same")
|
||||
host_addition<uint64_t>(static_cast<cudaStream_t>(stream), gpu_index, output,
|
||||
input_1, input_2, output->num_radix_blocks, 0, 0);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -67,6 +69,7 @@ void cuda_add_lwe_ciphertext_vector_plaintext_vector_32(
|
||||
static_cast<const uint32_t *>(lwe_array_in),
|
||||
static_cast<const uint32_t *>(plaintext_array_in), input_lwe_dimension,
|
||||
input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
/*
|
||||
* Perform the addition of a u64 input LWE ciphertext vector with a u64 input
|
||||
@@ -108,6 +111,7 @@ void cuda_add_lwe_ciphertext_vector_plaintext_vector_64(
|
||||
static_cast<const uint64_t *>(lwe_array_in),
|
||||
static_cast<const uint64_t *>(plaintext_array_in), input_lwe_dimension,
|
||||
input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -146,4 +150,5 @@ void cuda_add_lwe_ciphertext_vector_plaintext_64(
|
||||
static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<const uint64_t *>(lwe_array_in), plaintext_in,
|
||||
input_lwe_dimension, input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
@@ -62,7 +62,6 @@ __host__ void host_addition_plaintext(cudaStream_t stream, uint32_t gpu_index,
|
||||
plaintext_addition<T><<<grid, thds, 0, stream>>>(
|
||||
output, lwe_input, plaintext_input, lwe_dimension, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -84,7 +83,6 @@ __host__ void host_addition_plaintext_scalar(
|
||||
plaintext_addition_scalar<T><<<grid, thds, 0, stream>>>(
|
||||
output, lwe_input, plaintext_input, lwe_dimension, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -138,7 +136,6 @@ host_addition(cudaStream_t stream, uint32_t gpu_index,
|
||||
input_1->noise_levels[i] + input_2->noise_levels[i];
|
||||
CHECK_NOISE_LEVEL(output->noise_levels[i], message_modulus, carry_modulus);
|
||||
}
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -233,7 +230,6 @@ __host__ void host_subtraction(cudaStream_t stream, uint32_t gpu_index,
|
||||
subtraction<T>
|
||||
<<<grid, thds, 0, stream>>>(output, input_1, input_2, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -273,7 +269,6 @@ __host__ void host_subtraction_plaintext(cudaStream_t stream,
|
||||
radix_body_subtraction_inplace<T><<<grid, thds, 0, stream>>>(
|
||||
output, plaintext_input, input_lwe_dimension, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -17,6 +17,7 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_32(
|
||||
static_cast<const uint32_t *>(lwe_array_in),
|
||||
static_cast<const uint32_t *>(cleartext_array_in), input_lwe_dimension,
|
||||
input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
/*
|
||||
* Perform the multiplication of a u64 input LWE ciphertext vector with a u64
|
||||
@@ -58,6 +59,7 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_64(
|
||||
static_cast<const uint64_t *>(lwe_array_in),
|
||||
static_cast<const uint64_t *>(cleartext_array_in), input_lwe_dimension,
|
||||
input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
void scratch_wrapping_polynomial_mul_one_to_many_64(void *stream,
|
||||
|
||||
@@ -50,7 +50,6 @@ __host__ void host_cleartext_vec_multiplication(
|
||||
cleartext_vec_multiplication<T><<<grid, thds, 0, stream>>>(
|
||||
output, lwe_input, cleartext_input, input_lwe_dimension, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -13,6 +13,7 @@ void cuda_negate_lwe_ciphertext_vector_32(
|
||||
static_cast<uint32_t *>(lwe_array_out),
|
||||
static_cast<const uint32_t *>(lwe_array_in),
|
||||
input_lwe_dimension, input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -46,4 +47,5 @@ void cuda_negate_lwe_ciphertext_vector_64(
|
||||
static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<const uint64_t *>(lwe_array_in),
|
||||
input_lwe_dimension, input_lwe_ciphertext_count);
|
||||
cuda_synchronize_stream(static_cast<cudaStream_t>(stream), gpu_index);
|
||||
}
|
||||
|
||||
@@ -39,7 +39,6 @@ __host__ void host_negation(cudaStream_t stream, uint32_t gpu_index, T *output,
|
||||
|
||||
negation<T><<<grid, thds, 0, stream>>>(output, input, num_entries);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
}
|
||||
|
||||
#endif // CUDA_NEGATE_H
|
||||
|
||||
Reference in New Issue
Block a user