fix(gpu-wrappers): fix KS/BS wrappers for GPU (memory management).

This commit is contained in:
Antoniu Pop
2023-01-17 22:29:12 +00:00
committed by Antoniu Pop
parent ddf905b4db
commit 0329d4fc2d

View File

@@ -149,7 +149,6 @@ void memref_batched_keyswitch_lwe_cuda_u64(
// free memory that we allocated on gpu
cuda_drop(ct0_gpu, gpu_idx);
cuda_drop(out_gpu, gpu_idx);
cuda_drop(ksk_gpu, gpu_idx);
cuda_destroy_stream(stream, gpu_idx);
}
@@ -197,10 +196,7 @@ void memref_batched_bootstrap_lwe_cuda_u64(
// Move the glwe accumulator to the GPU
void *glwe_ct_gpu =
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_size, gpu_idx, stream);
// Free the glwe accumulator (on CPU)
free(glwe_ct);
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_len, gpu_idx, stream);
// Move test vector indexes to the GPU, the test vector indexes is set of 0
uint32_t num_test_vectors = 1, lwe_idx = 0,
@@ -220,6 +216,8 @@ void memref_batched_bootstrap_lwe_cuda_u64(
memcpy_async_to_cpu(out_aligned, out_offset, out_batch_size, out_gpu, gpu_idx,
stream);
cuda_synchronize_device(gpu_idx);
// Free the glwe accumulator (on CPU)
free(glwe_ct);
// free memory that we allocated on gpu
cuda_drop(ct0_gpu, gpu_idx);
cuda_drop(out_gpu, gpu_idx);