feat(compiler): add mapped version of batched bootstrap wrappers for CPU and GPU.

Author: Antoniu Pop
Committed by: Antoniu Pop
Date: 2023-04-11 14:00:49 +01:00
Parent: 7407948b18
Commit: 3a679a6f0a
2 changed files with 144 additions and 0 deletions
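
In short: the existing memref_batched_bootstrap_lwe_* wrappers apply a single table lookup (TLU) to every ciphertext in a batch, while the new mapped variants take a batch of TLUs and apply the i-th TLU to the i-th ciphertext (the GPU wrapper also accepts a single TLU broadcast to all samples). Below is a minimal sketch of the intended semantics only, using hypothetical Ciphertext, Lut, and pbs placeholders rather than the runtime's real memref-based types:

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for an LWE ciphertext and a table lookup (TLU);
// the real wrappers operate on flat uint64_t memrefs instead.
struct Ciphertext {};
struct Lut {};

// Stand-in for one programmable bootstrap of `ct` under `lut`.
Ciphertext pbs(const Ciphertext &ct, const Lut & /*lut*/) { return ct; }

// Existing batched wrapper: one LUT shared by every ciphertext in the batch.
std::vector<Ciphertext> batched(const std::vector<Ciphertext> &cts,
                                const Lut &lut) {
  std::vector<Ciphertext> out;
  for (const auto &ct : cts)
    out.push_back(pbs(ct, lut));
  return out;
}

// New mapped variant: one LUT per ciphertext, applied element-wise.
std::vector<Ciphertext> batched_mapped(const std::vector<Ciphertext> &cts,
                                       const std::vector<Lut> &luts) {
  assert(cts.size() == luts.size() &&
         "Number of LUTs does not match batch size");
  std::vector<Ciphertext> out;
  for (std::size_t i = 0; i < cts.size(); ++i)
    out.push_back(pbs(cts[i], luts[i]));
  return out;
}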


@@ -169,6 +169,18 @@ void memref_batched_bootstrap_lwe_u64(
    uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context);
void memref_batched_mapped_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size0,
    uint64_t tlu_size1, uint64_t tlu_stride0, uint64_t tlu_stride1,
    uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
    uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context);
void *memref_bootstrap_async_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
@@ -271,6 +283,17 @@ void memref_batched_bootstrap_lwe_cuda_u64(
    uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context);
void memref_batched_mapped_bootstrap_lwe_cuda_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size0,
    uint64_t tlu_size1, uint64_t tlu_stride0, uint64_t tlu_stride1,
    uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
    uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context);
// Tracing ////////////////////////////////////////////////////////////////////
void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size,

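As background on the flattened signatures above (general MLIR lowering behavior, not specific to this commit): each rank-2 memref argument is expanded into an allocated pointer, an aligned pointer, an offset, one size per dimension, and one stride per dimension, which is why out, ct0, and tlu each contribute seven scalar parameters. A sketch of the corresponding descriptor layout:

#include <cstdint>

// Sketch (not from this commit) of a rank-2 memref descriptor as produced by
// MLIR's standard calling convention; the wrappers above receive these fields
// unpacked as individual scalar arguments.
struct MemRef2D_u64 {
  uint64_t *allocated; // pointer returned by the allocator, used for freeing
  uint64_t *aligned;   // pointer actually used for element access
  uint64_t offset;     // offset of element (0, 0) from `aligned`, in elements
  uint64_t sizes[2];   // extent of each dimension, in elements
  uint64_t strides[2]; // distance between consecutive elements, per dimension
};

// Element (i, j) of such a memref is read as:
//   aligned[offset + i * strides[0] + j * strides[1]]
inline uint64_t load(const MemRef2D_u64 &m, uint64_t i, uint64_t j) {
  return m.aligned[m.offset + i * m.strides[0] + j * m.strides[1]];
}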

@@ -233,6 +233,104 @@ void memref_batched_bootstrap_lwe_cuda_u64(
  cuda_destroy_stream((cudaStream_t *)stream, gpu_idx);
}
void memref_batched_mapped_bootstrap_lwe_cuda_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size0,
    uint64_t tlu_size1, uint64_t tlu_stride0, uint64_t tlu_stride1,
    uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
    uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context) {
  assert(bsk_index == 0 && "multiple bsk is not yet implemented on GPU");
  assert(out_size0 == ct0_size0);
  assert(out_size1 == glwe_dim * poly_size + 1);
  assert((out_size0 == tlu_size0 || tlu_size0 == 1) &&
         "Number of LUTs does not match batch size");
  // TODO: Multi GPU
  uint32_t gpu_idx = 0;
  uint32_t num_samples = out_size0;
  uint32_t num_lut_vectors = tlu_size0;
  uint64_t ct0_batch_size = ct0_size0 * ct0_size1;
  uint64_t out_batch_size = out_size0 * out_size1;
  int8_t *pbs_buffer = nullptr;
  // Create the cuda stream
  // TODO: Should be created by the compiler codegen
  void *stream = cuda_create_stream(gpu_idx);
  // Get the pointer to the bootstrapping key on the GPU
  void *fbsk_gpu = memcpy_async_bsk_to_gpu(context, input_lwe_dim, poly_size,
                                           level, glwe_dim, gpu_idx, stream);
  // Move the input batch of ciphertexts to the GPU and allocate the output
  // batch on the GPU
  // TODO: The allocation should be done by the compiler codegen
  void *ct0_gpu = alloc_and_memcpy_async_to_gpu(
      ct0_aligned, ct0_offset, ct0_batch_size, gpu_idx, (cudaStream_t *)stream);
  void *out_gpu = cuda_malloc_async(out_batch_size * sizeof(uint64_t),
                                    (cudaStream_t *)stream, gpu_idx);
  // Construct the GLWE accumulators (on CPU)
  // TODO: Should be done outside of the bootstrap call, at compile time if
  // possible. Refactor in progress
  uint64_t glwe_ct_size = poly_size * (glwe_dim + 1) * num_lut_vectors;
  uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
  auto tlu = tlu_aligned + tlu_offset;
  // Trivial GLWE encryption of each LUT: glwe_dim * poly_size zero mask
  // coefficients followed by the poly_size LUT values
  size_t pos = 0, postlu = 0;
  for (size_t l = 0; l < num_lut_vectors; ++l) {
    for (size_t i = 0; i < poly_size * glwe_dim; i++) {
      glwe_ct[pos++] = 0;
    }
    for (size_t i = 0; i < poly_size; i++) {
      glwe_ct[pos++] = tlu[postlu++];
    }
  }
  // Move the GLWE accumulators to the GPU
  void *glwe_ct_gpu = alloc_and_memcpy_async_to_gpu(
      glwe_ct, 0, glwe_ct_size, gpu_idx, (cudaStream_t *)stream);
  // Move the test vector indexes to the GPU: with a single LUT every sample
  // uses index 0, otherwise sample i uses LUT i
  uint32_t lwe_idx = 0,
           test_vector_idxes_size = num_samples * sizeof(uint64_t);
  uint64_t *test_vector_idxes = (uint64_t *)malloc(test_vector_idxes_size);
  if (num_lut_vectors == 1) {
    memset((void *)test_vector_idxes, 0, test_vector_idxes_size);
  } else {
    assert(num_lut_vectors == num_samples);
    for (size_t i = 0; i < num_lut_vectors; ++i)
      test_vector_idxes[i] = i;
  }
  void *test_vector_idxes_gpu = cuda_malloc_async(
      test_vector_idxes_size, (cudaStream_t *)stream, gpu_idx);
  cuda_memcpy_async_to_gpu(test_vector_idxes_gpu, (void *)test_vector_idxes,
                           test_vector_idxes_size, (cudaStream_t *)stream,
                           gpu_idx);
  // Allocate the PBS scratch buffer on the GPU
  scratch_cuda_bootstrap_amortized_64(
      stream, gpu_idx, &pbs_buffer, glwe_dim, poly_size, num_samples,
      cuda_get_max_shared_memory(gpu_idx), true);
  // Run the bootstrap kernel on the GPU
  cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
      stream, gpu_idx, out_gpu, glwe_ct_gpu, test_vector_idxes_gpu, ct0_gpu,
      fbsk_gpu, pbs_buffer, input_lwe_dim, glwe_dim, poly_size, base_log, level,
      num_samples, num_lut_vectors, lwe_idx,
      cuda_get_max_shared_memory(gpu_idx));
  cleanup_cuda_bootstrap_amortized(stream, gpu_idx, &pbs_buffer);
  // Copy the output batch of ciphertexts back to the CPU
  memcpy_async_to_cpu(out_aligned, out_offset, out_batch_size, out_gpu, gpu_idx,
                      stream);
  // Free the memory we allocated on the GPU
  cuda_drop_async(ct0_gpu, (cudaStream_t *)stream, gpu_idx);
  cuda_drop_async(out_gpu, (cudaStream_t *)stream, gpu_idx);
  cuda_drop_async(glwe_ct_gpu, (cudaStream_t *)stream, gpu_idx);
  cuda_drop_async(test_vector_idxes_gpu, (cudaStream_t *)stream, gpu_idx);
  cudaStreamSynchronize(*(cudaStream_t *)stream);
  // Free the CPU-side buffers (GLWE accumulators and test vector indexes)
  free(glwe_ct);
  free(test_vector_idxes);
  cuda_destroy_stream((cudaStream_t *)stream, gpu_idx);
}
#endif
void memref_encode_plaintext_with_crt(
@@ -696,6 +794,29 @@ void memref_batched_bootstrap_lwe_u64(
  }
}
void memref_batched_mapped_bootstrap_lwe_u64(
    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
    uint64_t out_size0, uint64_t out_size1, uint64_t out_stride0,
    uint64_t out_stride1, uint64_t *ct0_allocated, uint64_t *ct0_aligned,
    uint64_t ct0_offset, uint64_t ct0_size0, uint64_t ct0_size1,
    uint64_t ct0_stride0, uint64_t ct0_stride1, uint64_t *tlu_allocated,
    uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size0,
    uint64_t tlu_size1, uint64_t tlu_stride0, uint64_t tlu_stride1,
    uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
    uint32_t base_log, uint32_t glwe_dim, uint32_t bsk_index,
    mlir::concretelang::RuntimeContext *context) {
  assert(out_size0 == tlu_size0 && "Number of LUTs does not match batch size");
  // Apply the i-th TLU to the i-th ciphertext of the batch by delegating each
  // sample to the scalar bootstrap wrapper
  for (size_t i = 0; i < out_size0; i++) {
    memref_bootstrap_lwe_u64(
        out_allocated + i * out_size1, out_aligned + i * out_size1, out_offset,
        out_size1, out_stride1, ct0_allocated, ct0_aligned + i * ct0_size1,
        ct0_offset, ct0_size1, ct0_stride1, tlu_allocated,
        tlu_aligned + i * tlu_size1, tlu_offset, tlu_size1, tlu_stride1,
        input_lwe_dim, poly_size, level, base_log, glwe_dim, bsk_index,
        context);
  }
}
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
  return concretelang::clientlib::crt::encode(plaintext, modulus, product);
}