mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
fix: NTT release domain linkage
This commit is contained in:
@@ -84,7 +84,7 @@ int main(int argc, char** argv)
|
||||
|
||||
// (4) multiply A,B
|
||||
CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
|
||||
vec_ops::VecOpsConfig<test_data> config {
|
||||
vec_ops::VecOpsConfig<test_data> config{
|
||||
ntt_config.ctx,
|
||||
true, // is_a_on_device
|
||||
true, // is_b_on_device
|
||||
@@ -92,8 +92,7 @@ int main(int argc, char** argv)
|
||||
false, // is_montgomery
|
||||
false // is_async
|
||||
};
|
||||
CHK_IF_RETURN(
|
||||
vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu));
|
||||
CHK_IF_RETURN(vec_ops::Mul(GpuA, GpuB, NTT_SIZE, config, MulGpu));
|
||||
|
||||
// (5) INTT (in place)
|
||||
ntt_config.are_inputs_on_device = true;
|
||||
@@ -118,6 +117,7 @@ int main(int argc, char** argv)
|
||||
benchmark(false); // warmup
|
||||
benchmark(true, 20);
|
||||
|
||||
ntt::ReleaseDomain<test_scalar>(ntt_config.ctx);
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -394,7 +394,8 @@ namespace ntt {
|
||||
template <typename U>
|
||||
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx, bool fast_tw);
|
||||
|
||||
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
template <typename U>
|
||||
friend cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
|
||||
template <typename U, typename E>
|
||||
friend cudaError_t NTT<U, E>(const E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
|
||||
@@ -488,32 +489,33 @@ namespace ntt {
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
cudaError_t Domain<S>::ReleaseDomain(device_context::DeviceContext& ctx)
|
||||
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
max_size = 0;
|
||||
max_log_size = 0;
|
||||
cudaFreeAsync(twiddles, ctx.stream);
|
||||
twiddles = nullptr;
|
||||
cudaFreeAsync(internal_twiddles, ctx.stream);
|
||||
internal_twiddles = nullptr;
|
||||
cudaFreeAsync(basic_twiddles, ctx.stream);
|
||||
basic_twiddles = nullptr;
|
||||
coset_index.clear();
|
||||
Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
|
||||
|
||||
cudaFreeAsync(fast_external_twiddles, ctx.stream);
|
||||
fast_external_twiddles = nullptr;
|
||||
cudaFreeAsync(fast_internal_twiddles, ctx.stream);
|
||||
fast_internal_twiddles = nullptr;
|
||||
cudaFreeAsync(fast_basic_twiddles, ctx.stream);
|
||||
fast_basic_twiddles = nullptr;
|
||||
cudaFreeAsync(fast_external_twiddles_inv, ctx.stream);
|
||||
fast_external_twiddles_inv = nullptr;
|
||||
cudaFreeAsync(fast_internal_twiddles_inv, ctx.stream);
|
||||
fast_internal_twiddles_inv = nullptr;
|
||||
cudaFreeAsync(fast_basic_twiddles_inv, ctx.stream);
|
||||
fast_basic_twiddles_inv = nullptr;
|
||||
domain.max_size = 0;
|
||||
domain.max_log_size = 0;
|
||||
domain.twiddles = nullptr; // allocated via cudaMallocManaged(...) so released without calling cudaFree(...)
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.internal_twiddles, ctx.stream));
|
||||
domain.internal_twiddles = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.basic_twiddles, ctx.stream));
|
||||
domain.basic_twiddles = nullptr;
|
||||
domain.coset_index.clear();
|
||||
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles, ctx.stream));
|
||||
domain.fast_external_twiddles = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles, ctx.stream));
|
||||
domain.fast_internal_twiddles = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles, ctx.stream));
|
||||
domain.fast_basic_twiddles = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_external_twiddles_inv, ctx.stream));
|
||||
domain.fast_external_twiddles_inv = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_internal_twiddles_inv, ctx.stream));
|
||||
domain.fast_internal_twiddles_inv = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream));
|
||||
domain.fast_basic_twiddles_inv = nullptr;
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
@@ -40,6 +40,22 @@ namespace ntt {
|
||||
template <typename S>
|
||||
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode = false);
|
||||
|
||||
/**
|
||||
* Releases and deallocates resources associated with the domain initialized for performing NTTs.
|
||||
* This function should be called to clean up resources once they are no longer needed.
|
||||
* It's important to note that after calling this function, any operation that relies on the released domain will
|
||||
* fail unless InitDomain is called again to reinitialize the resources. Therefore, ensure that ReleaseDomain is
|
||||
* only called when the operations requiring the NTT domain are completely finished and the domain is no longer
|
||||
* needed.
|
||||
* Also note that it is releasing the domain associated to the specific device.
|
||||
* @param ctx Details related to the device context such as its id and stream id.
|
||||
* @return `cudaSuccess` if the resource release was successful, indicating that the domain and its associated
|
||||
* resources have been properly deallocated. Returns an error code otherwise, indicating failure to release
|
||||
* the resources. The error code can be used to diagnose the problem.
|
||||
* */
|
||||
template <typename S>
|
||||
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
|
||||
/**
|
||||
* @enum NTTDir
|
||||
* Whether to perform normal forward NTT, or inverse NTT (iNTT). Mathematically, forward NTT computes polynomial
|
||||
|
||||
@@ -195,5 +195,7 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(cudaFree(GpuOutputOld));
|
||||
CHK_IF_RETURN(cudaFree(GpuOutputNew));
|
||||
|
||||
ntt::ReleaseDomain<test_scalar>(ntt_config.ctx);
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
Reference in New Issue
Block a user