mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
90 lines
3.4 KiB
C++
90 lines
3.4 KiB
C++
#pragma once
|
|
|
|
#include "integer_utilities.h"
|
|
#include "keyswitch/ks_enums.h"
|
|
#include "zk/expand.cuh"
|
|
#include "zk/zk_utilities.h"
|
|
|
|
template <typename Torus> struct int_rerand_mem {
|
|
int_radix_params params;
|
|
|
|
Torus *tmp_zero_lwes;
|
|
Torus *tmp_ksed_zero_lwes;
|
|
Torus *lwe_trivial_indexes;
|
|
uint32_t num_lwes;
|
|
|
|
bool gpu_memory_allocated;
|
|
|
|
std::vector<ks_mem<Torus> *>
|
|
ks_tmp_buf_vec; // not allocated, ReRand not using GEMM KS for now
|
|
// kept empty to pass to the KS function indicating GEMM KS disabled
|
|
|
|
expand_job<Torus> *d_expand_jobs;
|
|
expand_job<Torus> *h_expand_jobs;
|
|
|
|
int_rerand_mem(CudaStreams streams, int_radix_params params,
|
|
const uint32_t num_lwes, const bool allocate_gpu_memory,
|
|
uint64_t &size_tracker)
|
|
: params(params), num_lwes(num_lwes),
|
|
gpu_memory_allocated(allocate_gpu_memory) {
|
|
|
|
tmp_zero_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
|
|
num_lwes * (params.big_lwe_dimension + 1) * sizeof(Torus),
|
|
streams.stream(0), streams.gpu_index(0), size_tracker,
|
|
allocate_gpu_memory);
|
|
|
|
tmp_ksed_zero_lwes = (Torus *)cuda_malloc_with_size_tracking_async(
|
|
num_lwes * (params.small_lwe_dimension + 1) * sizeof(Torus),
|
|
streams.stream(0), streams.gpu_index(0), size_tracker,
|
|
allocate_gpu_memory);
|
|
|
|
d_expand_jobs =
|
|
static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
|
|
num_lwes * sizeof(expand_job<Torus>), streams.stream(0),
|
|
streams.gpu_index(0), size_tracker, allocate_gpu_memory));
|
|
|
|
h_expand_jobs = static_cast<expand_job<Torus> *>(
|
|
malloc(num_lwes * sizeof(expand_job<Torus>)));
|
|
|
|
auto h_lwe_trivial_indexes =
|
|
static_cast<Torus *>(malloc(num_lwes * sizeof(Torus)));
|
|
for (auto i = 0; i < num_lwes; ++i) {
|
|
h_lwe_trivial_indexes[i] = i;
|
|
}
|
|
lwe_trivial_indexes = (Torus *)cuda_malloc_with_size_tracking_async(
|
|
num_lwes * sizeof(Torus), streams.stream(0), streams.gpu_index(0),
|
|
size_tracker, allocate_gpu_memory);
|
|
cuda_memcpy_async_to_gpu(lwe_trivial_indexes, h_lwe_trivial_indexes,
|
|
num_lwes * sizeof(Torus), streams.stream(0),
|
|
streams.gpu_index(0));
|
|
|
|
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
|
|
|
free(h_lwe_trivial_indexes);
|
|
}
|
|
|
|
void release(CudaStreams streams) {
|
|
cuda_drop_with_size_tracking_async(tmp_zero_lwes, streams.stream(0),
|
|
streams.gpu_index(0),
|
|
gpu_memory_allocated);
|
|
cuda_drop_with_size_tracking_async(tmp_ksed_zero_lwes, streams.stream(0),
|
|
streams.gpu_index(0),
|
|
gpu_memory_allocated);
|
|
cuda_drop_with_size_tracking_async(lwe_trivial_indexes, streams.stream(0),
|
|
streams.gpu_index(0),
|
|
gpu_memory_allocated);
|
|
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
|
|
streams.gpu_index(0),
|
|
gpu_memory_allocated);
|
|
|
|
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
|
|
cleanup_cuda_keyswitch(streams.stream(i), streams.gpu_index(i),
|
|
ks_tmp_buf_vec[i], gpu_memory_allocated);
|
|
}
|
|
ks_tmp_buf_vec.clear();
|
|
|
|
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
|
|
free(h_expand_jobs);
|
|
}
|
|
};
|