mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-13 08:38:03 -05:00
Compare commits
4 Commits
create-pul
...
as/cuda_st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70fa68bf52 | ||
|
|
7faafd6602 | ||
|
|
8c55f6b8d7 | ||
|
|
f6b1929a8d |
@@ -6,6 +6,15 @@
|
||||
#include <cstdlib>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#define CUDA_STREAM_POOL
|
||||
|
||||
// Logical roles for dedicated CUDA streams.
// NOTE(review): this is a plain (unscoped) enum, so KEY/ALLOC/TEMP_HELPER
// leak into the enclosing scope. The explicit sequential values suggest the
// enumerators are used as indices into a stream array/pool — TODO confirm
// against the code that consumes CudaStreamType (not visible here).
enum CudaStreamType
{
  KEY = 0,
  ALLOC = 1,
  TEMP_HELPER = 2,
};
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define check_cuda_error(ans) \
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "device.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mutex>
|
||||
@@ -6,13 +8,20 @@
|
||||
#include <cuda_profiler_api.h>
|
||||
#endif
|
||||
|
||||
#ifdef CUDA_STREAM_POOL
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#endif
|
||||
|
||||
/// Returns the index of the CUDA device currently active on this host
/// thread, as an unsigned value. Aborts via check_cuda_error on failure.
uint32_t cuda_get_device() {
  int current_device = -1;
  check_cuda_error(cudaGetDevice(&current_device));
  return (uint32_t)current_device;
}
|
||||
// Serializes one-time, process-wide mempool setup in cuda_setup_mempool.
std::mutex pool_mutex;

// Set exactly once when the per-GPU memory pools have been configured.
// NOTE(review): the captured diff contained BOTH `bool mem_pools_enabled`
// and this atomic version (old/new lines conflated); only the atomic one is
// kept — `compare_exchange_strong` below requires it. Brace-initialization
// is used because `std::atomic<bool> x = false;` is copy-initialization,
// which is ill-formed before C++17 (std::atomic is not copyable/movable).
std::atomic<bool> mem_pools_enabled{false};
|
||||
|
||||
// We use memory pools to reduce some overhead of memory allocations due
|
||||
// to our scratch/release pattern. This function is the simplest way of using
|
||||
@@ -29,13 +38,13 @@ bool mem_pools_enabled = false;
|
||||
// We tested more complex configurations of mempools, but they did not yield
|
||||
// better results.
|
||||
// One-time, process-wide setup of CUDA memory pools on every GPU, then
// restores the caller's device. Uses pools to reduce alloc overhead from
// the scratch/release pattern (see comments upstream of this function).
//
// NOTE(review): this block conflates TWO versions of the function from a
// diff whose +/- markers were lost — a pool_mutex/double-checked guard AND
// an atomic compare_exchange guard. As written the two do not compose;
// reconcile against the upstream repository before relying on this text.
void cuda_setup_mempool(uint32_t caller_gpu_index) {
  if (!mem_pools_enabled) {
    pool_mutex.lock();
    if (mem_pools_enabled)
      return; // If mem pools are already enabled, we don't need to do anything
              // NOTE(review): this early return leaks pool_mutex — it is
              // locked above and never unlocked on this path.

    // We do it only once for all GPUs
    mem_pools_enabled = true;
  bool pools_not_initialized = false;
  bool pools_initialized = true;

  // if pools_not_initialized is found, mem_pools_enabled is set to pools_initialized
  // and the if body runs
  if (mem_pools_enabled.compare_exchange_strong(pools_not_initialized, pools_initialized)) {
    uint32_t num_gpus = cuda_get_number_of_gpus();
    for (uint32_t gpu_index = 0; gpu_index < num_gpus; gpu_index++) {
      cuda_set_device(gpu_index);
      // NOTE(review): lines elided here — the captured diff has a hunk gap
      // (the per-GPU pool-attribute configuration is not visible), and at
      // least one closing brace appears to be lost in that gap.
    }
    // We return to the original gpu_index
    cuda_set_device(caller_gpu_index);
    pool_mutex.unlock();
  }
}
|
||||
|
||||
@@ -115,18 +123,90 @@ void cuda_event_destroy(cudaEvent_t event, uint32_t gpu_index) {
|
||||
check_cuda_error(cudaEventDestroy(event));
|
||||
}
|
||||
|
||||
#ifdef CUDA_STREAM_POOL
|
||||
// A CUDA stream paired with the GPU (device index) it was created on.
// NOTE: aggregate-initialized positionally elsewhere in this file
// (`CudaBoundStream{stream, gpu_index}`), so member order matters.
struct CudaBoundStream
{
  cudaStream_t stream;  // stream handle created on `gpu_index`
  uint32_t gpu_index;   // device the stream is bound to
};
|
||||
|
||||
// Process-lifetime pool of CUDA streams for a single GPU.
//
// create_stream() hands out streams round-robin from a lazily created,
// fixed-size set, so callers SHARE at most MAX_STREAMS streams instead of
// creating one per request. destroy_stream() is deliberately a no-op:
// pooled streams live for the lifetime of the process.
//
// Thread-safety: create_stream locks mutex_pools for the whole call.
//
// Changes vs. previous revision: removed the never-referenced
// `poolTransfer` member (dead data), and made MAX_STREAMS a
// `static constexpr` (a per-instance `const` member wasted space in every
// pool and made the class non-copy-assignable for no benefit).
class CudaStreamPool
{
  std::vector<CudaBoundStream> poolCompute; // filled lazily on first request

  std::mutex mutex_pools; // guards poolCompute and nextStream

  size_t nextStream = 0; // round-robin cursor into poolCompute

  // Number of streams created per GPU pool.
  static constexpr size_t MAX_STREAMS = 8;

public:
  // Returns a (shared) non-blocking stream bound to `gpu_index`.
  // The pool is created on the first call; every later call must pass the
  // same gpu_index (enforced by PANIC_IF_FALSE).
  cudaStream_t create_stream(uint32_t gpu_index)
  {
    std::lock_guard<std::mutex> lock(mutex_pools);
    if (poolCompute.empty())
    {
      poolCompute.reserve(MAX_STREAMS);

      cuda_set_device(gpu_index);
      for (size_t i = 0; i < MAX_STREAMS; i++)
      {
        cudaStream_t stream;
        check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
        poolCompute.push_back(CudaBoundStream{stream, gpu_index});
      }
    }

    PANIC_IF_FALSE(gpu_index == poolCompute[nextStream].gpu_index, "Bad gpu in stream pool");
    cudaStream_t res = poolCompute[nextStream].stream;
    nextStream = (nextStream + 1) % poolCompute.size();
    return res;
  }

  // Intentionally a no-op: pooled streams are shared across callers and are
  // reclaimed only at process exit. Parameters are kept for interface
  // symmetry with the non-pooled path.
  void destroy_stream(cudaStream_t stream, uint32_t gpu_index)
  {
    (void)stream;
    (void)gpu_index;
  }
};
|
||||
|
||||
|
||||
// Lazily creates and owns one CudaStreamPool per GPU index.
class CudaMultiStreamPool {
  std::unordered_map<uint32_t, CudaStreamPool> per_gpu_pools;
  std::mutex pools_mutex; // serializes lookup/creation in per_gpu_pools

public:
  // Returns the pool for `gpu_index`, default-constructing it on first use.
  // The returned reference stays valid afterwards: std::unordered_map never
  // invalidates references to existing elements.
  CudaStreamPool &get(uint32_t gpu_index) {
    std::lock_guard<std::mutex> hold(pools_mutex);
    auto found = per_gpu_pools.find(gpu_index);
    if (found != per_gpu_pools.end())
      return found->second;
    // Not present yet: operator[] value-constructs the pool in place.
    return per_gpu_pools[gpu_index];
  }
};
|
||||
|
||||
CudaMultiStreamPool gCudaStreamPool;
|
||||
#endif
|
||||
|
||||
|
||||
/// Unsafe function to create a CUDA stream, must check first that GPU exists
cudaStream_t cuda_create_stream(uint32_t gpu_index) {
  // Both paths select the device first (the original pooled path noted this
  // "will initialize the mempool" — presumably via a set-device hook; the
  // mechanism is not visible here).
  cuda_set_device(gpu_index);
#ifdef CUDA_STREAM_POOL
  // Pooled build: hand out a shared stream from the per-GPU pool.
  return gCudaStreamPool.get(gpu_index).create_stream(gpu_index);
#else
  // Non-pooled build: create a fresh non-blocking stream for the caller.
  cudaStream_t new_stream;
  check_cuda_error(cudaStreamCreateWithFlags(&new_stream, cudaStreamNonBlocking));
  return new_stream;
#endif
}
|
||||
|
||||
/// Unsafe function to destroy CUDA stream, must check first the GPU exists.
/// Pooled build: delegates to the per-GPU pool, whose destroy is a no-op
/// (pooled streams are shared and live for the process lifetime).
/// Non-pooled build: destroys the stream on its GPU via the CUDA runtime.
void cuda_destroy_stream(cudaStream_t stream, uint32_t gpu_index) {
#ifdef CUDA_STREAM_POOL
  gCudaStreamPool.get(gpu_index).destroy_stream(stream, gpu_index);
#else
  cuda_set_device(gpu_index);
  check_cuda_error(cudaStreamDestroy(stream));
#endif
}
|
||||
|
||||
void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index) {
|
||||
|
||||
Reference in New Issue
Block a user