mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
feat(gpu): use mempools to optimize mem reuse
This commit is contained in:
@@ -1,15 +1,74 @@
|
||||
#include "device.h"
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mutex>
|
||||
|
||||
uint32_t cuda_get_device() {
|
||||
int device;
|
||||
check_cuda_error(cudaGetDevice(&device));
|
||||
return static_cast<uint32_t>(device);
|
||||
}
|
||||
std::mutex pool_mutex;
|
||||
bool mem_pools_enabled = false;
|
||||
|
||||
void cuda_setup_mempool(uint32_t caller_gpu_index) {
|
||||
if (!mem_pools_enabled) {
|
||||
pool_mutex.lock();
|
||||
if (mem_pools_enabled)
|
||||
return; // If mem pools are already enabled, we don't need to do anything
|
||||
|
||||
// We do it only once for all GPUs
|
||||
mem_pools_enabled = true;
|
||||
uint32_t num_gpus = cuda_get_number_of_gpus();
|
||||
for (uint32_t gpu_index = 0; gpu_index < num_gpus; gpu_index++) {
|
||||
cuda_set_device(gpu_index);
|
||||
|
||||
size_t total_mem, free_mem;
|
||||
check_cuda_error(cudaMemGetInfo(&free_mem, &total_mem));
|
||||
|
||||
// If we have more than 2% of free memory, we can set up the mempool
|
||||
uint64_t mem_pool_threshold = total_mem / 50; // 2% of total memory
|
||||
mem_pool_threshold =
|
||||
mem_pool_threshold - (mem_pool_threshold % 1024); // Align to 1KB
|
||||
if (mem_pool_threshold < free_mem) {
|
||||
// Get default memory pool
|
||||
cudaMemPool_t default_pool;
|
||||
check_cuda_error(cudaDeviceGetDefaultMemPool(&default_pool, gpu_index));
|
||||
|
||||
// Enable opportunistic reuse
|
||||
int reuse = 1;
|
||||
check_cuda_error(cudaMemPoolSetAttribute(
|
||||
default_pool, cudaMemPoolReuseAllowOpportunistic, &reuse));
|
||||
|
||||
// Prevent memory from being released back to the OS too soon
|
||||
check_cuda_error(cudaMemPoolSetAttribute(
|
||||
default_pool, cudaMemPoolAttrReleaseThreshold,
|
||||
&mem_pool_threshold));
|
||||
|
||||
// Warm up the pool by allocating and freeing a large block
|
||||
cudaStream_t stream;
|
||||
stream = cuda_create_stream(gpu_index);
|
||||
void *warmup_ptr = nullptr;
|
||||
warmup_ptr = cuda_malloc_async(mem_pool_threshold, stream, gpu_index);
|
||||
cuda_drop_async(warmup_ptr, stream, gpu_index);
|
||||
|
||||
// Sync to ensure pool is grown
|
||||
cuda_synchronize_stream(stream, gpu_index);
|
||||
|
||||
// Clean up
|
||||
cuda_destroy_stream(stream, gpu_index);
|
||||
}
|
||||
}
|
||||
// We return to the original gpu_index
|
||||
cuda_set_device(caller_gpu_index);
|
||||
pool_mutex.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_set_device(uint32_t gpu_index) {
|
||||
check_cuda_error(cudaSetDevice(gpu_index));
|
||||
// Mempools are initialized only once in all the GPUS available
|
||||
cuda_setup_mempool(gpu_index);
|
||||
}
|
||||
|
||||
cudaEvent_t cuda_create_event(uint32_t gpu_index) {
|
||||
|
||||
Reference in New Issue
Block a user