fix: NTT release domain linkage

This commit is contained in:
Yuval Shekel
2024-04-03 13:09:13 +00:00
committed by yshekel
parent 25ac705c3b
commit 406020bda6
4 changed files with 46 additions and 26 deletions

View File

@@ -84,7 +84,7 @@ int main(int argc, char** argv)
// (4) multiply A,B
CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
vec_ops::VecOpsConfig<test_data> config {
vec_ops::VecOpsConfig<test_data> config{
ntt_config.ctx,
true, // is_a_on_device
true, // is_b_on_device
@@ -92,8 +92,7 @@ int main(int argc, char** argv)
false, // is_montgomery
false // is_async
};
CHK_IF_RETURN(
vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu));
CHK_IF_RETURN(vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu));
// (5) INTT (in place)
ntt_config.are_inputs_on_device = true;
@@ -118,6 +117,7 @@ int main(int argc, char** argv)
benchmark(false); // warmup
benchmark(true, 20);
ntt::ReleaseDomain<test_scalar>(ntt_config.ctx);
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
return 0;

View File

@@ -394,7 +394,8 @@ namespace ntt {
template <typename U>
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx, bool fast_tw);
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
template <typename U>
friend cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
template <typename U, typename E>
friend cudaError_t NTT<U, E>(const E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
@@ -488,32 +489,33 @@ namespace ntt {
}
template <typename S>
cudaError_t Domain<S>::ReleaseDomain(device_context::DeviceContext& ctx)
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx)
{
CHK_INIT_IF_RETURN();
max_size = 0;
max_log_size = 0;
cudaFreeAsync(twiddles, ctx.stream);
twiddles = nullptr;
cudaFreeAsync(internal_twiddles, ctx.stream);
internal_twiddles = nullptr;
cudaFreeAsync(basic_twiddles, ctx.stream);
basic_twiddles = nullptr;
coset_index.clear();
Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
cudaFreeAsync(fast_external_twiddles, ctx.stream);
fast_external_twiddles = nullptr;
cudaFreeAsync(fast_internal_twiddles, ctx.stream);
fast_internal_twiddles = nullptr;
cudaFreeAsync(fast_basic_twiddles, ctx.stream);
fast_basic_twiddles = nullptr;
cudaFreeAsync(fast_external_twiddles_inv, ctx.stream);
fast_external_twiddles_inv = nullptr;
cudaFreeAsync(fast_internal_twiddles_inv, ctx.stream);
fast_internal_twiddles_inv = nullptr;
cudaFreeAsync(fast_basic_twiddles_inv, ctx.stream);
fast_basic_twiddles_inv = nullptr;
domain.max_size = 0;
domain.max_log_size = 0;
domain.twiddles = nullptr; // allocated via cudaMallocManaged(...) so released without calling cudaFree(...)
CHK_IF_RETURN(cudaFreeAsync(domain.internal_twiddles, ctx.stream));
domain.internal_twiddles = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.basic_twiddles, ctx.stream));
domain.basic_twiddles = nullptr;
domain.coset_index.clear();
CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles, ctx.stream));
domain.fast_external_twiddles = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles, ctx.stream));
domain.fast_internal_twiddles = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles, ctx.stream));
domain.fast_basic_twiddles = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles_inv, ctx.stream));
domain.fast_external_twiddles_inv = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles_inv, ctx.stream));
domain.fast_internal_twiddles_inv = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream));
domain.fast_basic_twiddles_inv = nullptr;
return CHK_LAST();
}

View File

@@ -40,6 +40,22 @@ namespace ntt {
template <typename S>
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode = false);
/**
* Releases and deallocates resources associated with the domain initialized for performing NTTs.
* This function should be called to clean up resources once they are no longer needed.
* It's important to note that after calling this function, any operation that relies on the released domain will
* fail unless InitDomain is called again to reinitialize the resources. Therefore, ensure that ReleaseDomain is
* only called when the operations requiring the NTT domain are completely finished and the domain is no longer
* needed.
* Also note that it is releasing the domain associated to the specific device.
* @param ctx Details related to the device context such as its id and stream id.
* @return `cudaSuccess` if the resource release was successful, indicating that the domain and its associated
* resources have been properly deallocated. Returns an error code otherwise, indicating failure to release
* the resources. The error code can be used to diagnose the problem.
* */
template <typename S>
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
/**
* @enum NTTDir
* Whether to perform normal forward NTT, or inverse NTT (iNTT). Mathematically, forward NTT computes polynomial

View File

@@ -195,5 +195,7 @@ int main(int argc, char** argv)
CHK_IF_RETURN(cudaFree(GpuOutputOld));
CHK_IF_RETURN(cudaFree(GpuOutputNew));
ntt::ReleaseDomain<test_scalar>(ntt_config.ctx);
return CHK_LAST();
}