From 406020bda67f0869ebbf17d07162c4b5bed867a0 Mon Sep 17 00:00:00 2001 From: Yuval Shekel Date: Wed, 3 Apr 2024 13:09:13 +0000 Subject: [PATCH] fix: NTT release domain linkage --- .../c++/polynomial_multiplication/example.cu | 6 +-- icicle/appUtils/ntt/ntt.cu | 48 ++++++++++--------- icicle/appUtils/ntt/ntt.cuh | 16 +++++++ icicle/appUtils/ntt/tests/verification.cu | 2 + 4 files changed, 46 insertions(+), 26 deletions(-) diff --git a/examples/c++/polynomial_multiplication/example.cu b/examples/c++/polynomial_multiplication/example.cu index b9b0bd68..dfaa5e87 100644 --- a/examples/c++/polynomial_multiplication/example.cu +++ b/examples/c++/polynomial_multiplication/example.cu @@ -84,7 +84,7 @@ int main(int argc, char** argv) // (4) multiply A,B CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream)); - vec_ops::VecOpsConfig config { + vec_ops::VecOpsConfig config{ ntt_config.ctx, true, // is_a_on_device true, // is_b_on_device @@ -92,8 +92,7 @@ int main(int argc, char** argv) false, // is_montgomery false // is_async }; - CHK_IF_RETURN( - vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu)); + CHK_IF_RETURN(vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu)); // (5) INTT (in place) ntt_config.are_inputs_on_device = true; @@ -118,6 +117,7 @@ int main(int argc, char** argv) benchmark(false); // warmup benchmark(true, 20); + ntt::ReleaseDomain(ntt_config.ctx); CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); return 0; diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu index 8dc0b1b7..493409d2 100644 --- a/icicle/appUtils/ntt/ntt.cu +++ b/icicle/appUtils/ntt/ntt.cu @@ -394,7 +394,8 @@ namespace ntt { template friend cudaError_t InitDomain(U primitive_root, device_context::DeviceContext& ctx, bool fast_tw); - cudaError_t ReleaseDomain(device_context::DeviceContext& ctx); + template + friend cudaError_t ReleaseDomain(device_context::DeviceContext& ctx); template friend cudaError_t NTT(const E* input, int size, NTTDir dir, NTTConfig& config, E* output); @@ -488,32 +489,33 @@ namespace ntt { } template - cudaError_t Domain::ReleaseDomain(device_context::DeviceContext& ctx) + cudaError_t ReleaseDomain(device_context::DeviceContext& ctx) { CHK_INIT_IF_RETURN(); - max_size = 0; - max_log_size = 0; - cudaFreeAsync(twiddles, ctx.stream); - twiddles = nullptr; - cudaFreeAsync(internal_twiddles, ctx.stream); - internal_twiddles = nullptr; - cudaFreeAsync(basic_twiddles, ctx.stream); - basic_twiddles = nullptr; - coset_index.clear(); + Domain& domain = domains_for_devices[ctx.device_id]; - cudaFreeAsync(fast_external_twiddles, ctx.stream); - fast_external_twiddles = nullptr; - cudaFreeAsync(fast_internal_twiddles, ctx.stream); - fast_internal_twiddles = nullptr; - cudaFreeAsync(fast_basic_twiddles, ctx.stream); - fast_basic_twiddles = nullptr; - cudaFreeAsync(fast_external_twiddles_inv, ctx.stream); - fast_external_twiddles_inv = nullptr; - cudaFreeAsync(fast_internal_twiddles_inv, ctx.stream); - fast_internal_twiddles_inv = nullptr; - cudaFreeAsync(fast_basic_twiddles_inv, ctx.stream); - fast_basic_twiddles_inv = nullptr; + domain.max_size = 0; + domain.max_log_size = 0; + domain.twiddles = nullptr; // allocated via cudaMallocManaged(...) so released without calling cudaFree(...) + CHK_IF_RETURN(cudaFreeAsync(domain.internal_twiddles, ctx.stream)); + domain.internal_twiddles = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.basic_twiddles, ctx.stream)); + domain.basic_twiddles = nullptr; + domain.coset_index.clear(); + + CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles, ctx.stream)); + domain.fast_external_twiddles = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles, ctx.stream)); + domain.fast_internal_twiddles = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles, ctx.stream)); + domain.fast_basic_twiddles = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles_inv, ctx.stream)); + domain.fast_external_twiddles_inv = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles_inv, ctx.stream)); + domain.fast_internal_twiddles_inv = nullptr; + CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream)); + domain.fast_basic_twiddles_inv = nullptr; return CHK_LAST(); } diff --git a/icicle/appUtils/ntt/ntt.cuh b/icicle/appUtils/ntt/ntt.cuh index e77da86d..7b9e4454 100644 --- a/icicle/appUtils/ntt/ntt.cuh +++ b/icicle/appUtils/ntt/ntt.cuh @@ -40,6 +40,22 @@ namespace ntt { template cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode = false); + /** + * Releases and deallocates resources associated with the domain initialized for performing NTTs. + * This function should be called to clean up resources once they are no longer needed. + * It's important to note that after calling this function, any operation that relies on the released domain will + * fail unless InitDomain is called again to reinitialize the resources. Therefore, ensure that ReleaseDomain is + * only called when the operations requiring the NTT domain are completely finished and the domain is no longer + * needed. + * Also note that it is releasing the domain associated to the specific device. + * @param ctx Details related to the device context such as its id and stream id. + * @return `cudaSuccess` if the resource release was successful, indicating that the domain and its associated + * resources have been properly deallocated. Returns an error code otherwise, indicating failure to release + * the resources. The error code can be used to diagnose the problem. + * */ + template + cudaError_t ReleaseDomain(device_context::DeviceContext& ctx); + /** * @enum NTTDir * Whether to perform normal forward NTT, or inverse NTT (iNTT). Mathematically, forward NTT computes polynomial diff --git a/icicle/appUtils/ntt/tests/verification.cu b/icicle/appUtils/ntt/tests/verification.cu index 751ffe09..98f9360c 100644 --- a/icicle/appUtils/ntt/tests/verification.cu +++ b/icicle/appUtils/ntt/tests/verification.cu @@ -195,5 +195,7 @@ int main(int argc, char** argv) CHK_IF_RETURN(cudaFree(GpuOutputOld)); CHK_IF_RETURN(cudaFree(GpuOutputNew)); + ntt::ReleaseDomain(ntt_config.ctx); + return CHK_LAST(); } \ No newline at end of file