diff --git a/.codespellignore b/.codespellignore
index b41990df..acadd68e 100644
--- a/.codespellignore
+++ b/.codespellignore
@@ -1,3 +1,5 @@
 inout
 crate
 lmit
+mut
+uint
diff --git a/README.md b/README.md
index 8ee6e7ec..9094bf77 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,6 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
 > [!NOTE]
 > Developers: We highly recommend reading our [documentation]
-
 > [!TIP]
 > Try out ICICLE by running some [examples] using ICICLE in C++ and our Rust bindings
@@ -43,8 +42,8 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
 - [GCC](https://gcc.gnu.org/install/download.html) version 9, latest version is recommended.
 - Any Nvidia GPU (which supports CUDA Toolkit version 12.0 or above).
 
-> [!NOTE]
-> It is possible to use CUDA 11 for cards which dont support CUDA 12, however we dont officially support this version and in the future there may be issues.
+> [!NOTE]
+> It is possible to use CUDA 11 for cards that don't support CUDA 12; however, we don't officially support CUDA 11, and issues may arise in the future.
 
 ### Accessing Hardware
diff --git a/examples/rust/msm/src/main.rs b/examples/rust/msm/src/main.rs
index 219ea69d..75544194 100644
--- a/examples/rust/msm/src/main.rs
+++ b/examples/rust/msm/src/main.rs
@@ -1,43 +1,20 @@
-use icicle_bn254::curve::{
-    CurveCfg,
-    ScalarCfg,
-    G1Projective,
-    G2CurveCfg,
-    G2Projective
-};
+use icicle_bn254::curve::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg};
 
 use icicle_bls12_377::curve::{
-    CurveCfg as BLS12377CurveCfg,
-    ScalarCfg as BLS12377ScalarCfg,
-    G1Projective as BLS12377G1Projective
+    CurveCfg as BLS12377CurveCfg, G1Projective as BLS12377G1Projective, ScalarCfg as BLS12377ScalarCfg,
 };
 
-use icicle_cuda_runtime::{
-    stream::CudaStream,
-    memory::HostOrDeviceSlice
-};
+use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream};
 
-use icicle_core::{
-    msm,
-    curve::Curve,
-    traits::GenerateRandom
-};
+use icicle_core::{curve::Curve, msm, traits::GenerateRandom};
 
 #[cfg(feature = "arkworks")]
 use icicle_core::traits::ArkConvertible;
 
 #[cfg(feature = "arkworks")]
-use ark_bn254::{
-    G1Projective as Bn254ArkG1Projective,
-    G1Affine as Bn254G1Affine,
-    Fr as Bn254Fr
-};
+use ark_bls12_377::{Fr as Bls12377Fr, G1Affine as Bls12377G1Affine, G1Projective as Bls12377ArkG1Projective};
 #[cfg(feature = "arkworks")]
-use ark_bls12_377::{
-    G1Projective as Bls12377ArkG1Projective,
-    G1Affine as Bls12377G1Affine,
-    Fr as Bls12377Fr
-};
+use ark_bn254::{Fr as Bn254Fr, G1Affine as Bn254G1Affine, G1Projective as Bn254ArkG1Projective};
 
 #[cfg(feature = "arkworks")]
 use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
@@ -67,23 +44,26 @@ fn main() {
     let upper_points = CurveCfg::generate_random_affine_points(upper_size);
     let g2_upper_points = G2CurveCfg::generate_random_affine_points(upper_size);
     let upper_scalars = ScalarCfg::generate_random(upper_size);
-
+
     println!("Generating random inputs on host for bls12377...");
     let upper_points_bls12377 = BLS12377CurveCfg::generate_random_affine_points(upper_size);
     let upper_scalars_bls12377 = BLS12377ScalarCfg::generate_random(upper_size);
 
-    for i in lower_bound..=upper_bound {
+    for i in lower_bound..=upper_bound {
         let log_size = i;
         let size = 1 << log_size;
-        println!("---------------------- MSM size 2^{}={} ------------------------", log_size, size);
+        println!(
+            "---------------------- MSM size 2^{}={} ------------------------",
+            log_size, size
+        );
 
         // Setting Bn254 points and scalars
         let points = HostOrDeviceSlice::Host(upper_points[..size].to_vec());
         let g2_points = HostOrDeviceSlice::Host(g2_upper_points[..size].to_vec());
         let scalars = HostOrDeviceSlice::Host(upper_scalars[..size].to_vec());
-
+
         // Setting bls12377 points and scalars
         // let points_bls12377 = &upper_points_bls12377[..size];
-        let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
+        let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
         let scalars_bls12377 = HostOrDeviceSlice::Host(upper_scalars_bls12377[..size].to_vec());
 
         println!("Configuring bn254 MSM...");
@@ -91,18 +71,24 @@ fn main() {
         let mut g2_msm_results: HostOrDeviceSlice<'_, G2Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
         let stream = CudaStream::create().unwrap();
         let g2_stream = CudaStream::create().unwrap();
-        let mut cfg = msm::get_default_msm_config::<CurveCfg>();
-        let mut g2_cfg = msm::get_default_msm_config::<G2CurveCfg>();
-        cfg.ctx.stream = &stream;
-        g2_cfg.ctx.stream = &g2_stream;
+        let mut cfg = msm::MSMConfig::default();
+        let mut g2_cfg = msm::MSMConfig::default();
+        cfg.ctx
+            .stream = &stream;
+        g2_cfg
+            .ctx
+            .stream = &g2_stream;
         cfg.is_async = true;
         g2_cfg.is_async = true;
 
         println!("Configuring bls12377 MSM...");
-        let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
+        let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> =
+            HostOrDeviceSlice::cuda_malloc(1).unwrap();
         let stream_bls12377 = CudaStream::create().unwrap();
-        let mut cfg_bls12377 = msm::get_default_msm_config::<BLS12377CurveCfg>();
-        cfg_bls12377.ctx.stream = &stream_bls12377;
+        let mut cfg_bls12377 = msm::MSMConfig::default();
+        cfg_bls12377
+            .ctx
+            .stream = &stream_bls12377;
         cfg_bls12377.is_async = true;
 
         println!("Executing bn254 MSM on device...");
@@ -110,22 +96,37 @@ fn main() {
         let start = Instant::now();
         msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap();
         #[cfg(feature = "profile")]
-        println!("ICICLE BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
+        println!(
+            "ICICLE BN254 MSM on size 2^{log_size} took: {} ms",
+            start
+                .elapsed()
+                .as_millis()
+        );
 
         msm::msm(&scalars, &g2_points, &g2_cfg, &mut g2_msm_results).unwrap();
-
         println!("Executing bls12377 MSM on device...");
         #[cfg(feature = "profile")]
         let start = Instant::now();
-        msm::msm(&scalars_bls12377, &points_bls12377, &cfg_bls12377, &mut msm_results_bls12377 ).unwrap();
+        msm::msm(
+            &scalars_bls12377,
+            &points_bls12377,
+            &cfg_bls12377,
+            &mut msm_results_bls12377,
+        )
+        .unwrap();
         #[cfg(feature = "profile")]
-        println!("ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
+        println!(
+            "ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms",
+            start
+                .elapsed()
+                .as_millis()
+        );
 
         println!("Moving results to host..");
         let mut msm_host_result = vec![G1Projective::zero(); 1];
         let mut g2_msm_host_result = vec![G2Projective::zero(); 1];
         let mut msm_host_result_bls12377 = vec![BLS12377G1Projective::zero(); 1];
-
+
         stream
             .synchronize()
             .unwrap();
@@ -140,7 +141,7 @@ fn main() {
             .unwrap();
         println!("bn254 result: {:#?}", msm_host_result);
         println!("G2 bn254 result: {:#?}", g2_msm_host_result);
-
+
         stream_bls12377
             .synchronize()
             .unwrap();
@@ -148,37 +149,70 @@ fn main() {
             .copy_to_host(&mut msm_host_result_bls12377[..])
             .unwrap();
         println!("bls12377 result: {:#?}", msm_host_result_bls12377);
-
+
         #[cfg(feature = "arkworks")]
         {
             println!("Checking against arkworks...");
-            let ark_points: Vec<Bn254G1Affine> = points.as_slice().iter().map(|&point| point.to_ark()).collect();
-            let ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
+            let ark_points: Vec<Bn254G1Affine> = points
+                .as_slice()
+                .iter()
+                .map(|&point| point.to_ark())
+                .collect();
+            let ark_scalars: Vec<Bn254Fr> = scalars
+                .as_slice()
+                .iter()
+                .map(|scalar| scalar.to_ark())
+                .collect();
 
-            let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377.as_slice().iter().map(|point| point.to_ark()).collect();
-            let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
+            let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377
+                .as_slice()
+                .iter()
+                .map(|point| point.to_ark())
+                .collect();
+            let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
+                .as_slice()
+                .iter()
+                .map(|scalar| scalar.to_ark())
+                .collect();
 
             #[cfg(feature = "profile")]
             let start = Instant::now();
             let bn254_ark_msm_res = Bn254ArkG1Projective::msm(&ark_points, &ark_scalars).unwrap();
             println!("Arkworks Bn254 result: {:#?}", bn254_ark_msm_res);
             #[cfg(feature = "profile")]
-            println!("Ark BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
+            println!(
+                "Ark BN254 MSM on size 2^{log_size} took: {} ms",
+                start
+                    .elapsed()
+                    .as_millis()
+            );
 
             #[cfg(feature = "profile")]
             let start = Instant::now();
-            let bls12377_ark_msm_res = Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
+            let bls12377_ark_msm_res =
+                Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
             println!("Arkworks Bls12377 result: {:#?}", bls12377_ark_msm_res);
             #[cfg(feature = "profile")]
-            println!("Ark BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
+            println!(
+                "Ark BLS12377 MSM on size 2^{log_size} took: {} ms",
+                start
+                    .elapsed()
+                    .as_millis()
+            );
 
             let bn254_icicle_msm_res_as_ark = msm_host_result[0].to_ark();
             let bls12377_icicle_msm_res_as_ark = msm_host_result_bls12377[0].to_ark();
 
-            println!("Bn254 MSM is correct: {}", bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark));
-            println!("Bls12377 MSM is correct: {}", bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark));
+            println!(
+                "Bn254 MSM is correct: {}",
+                bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark)
+            );
+            println!(
+                "Bls12377 MSM is correct: {}",
+                bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark)
+            );
         }
-
+
         println!("Cleaning up bn254...");
         stream
            .destroy()
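Note on the config change above: `msm::get_default_msm_config::<C>()` is replaced by a plain `Default` impl. A minimal sketch of the new single-device entry point, using only APIs visible in this PR (error handling reduced to `unwrap` for brevity):

```rust
use icicle_bn254::curve::{CurveCfg, G1Projective, ScalarCfg};
use icicle_core::{curve::Curve, msm, traits::GenerateRandom};
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn bn254_msm_once(log_size: usize) {
    let size = 1 << log_size;
    // Host-resident inputs; the backend stages them onto the device.
    let points = HostOrDeviceSlice::Host(CurveCfg::generate_random_affine_points(size));
    let scalars = HostOrDeviceSlice::Host(ScalarCfg::generate_random(size));
    let mut results: HostOrDeviceSlice<'_, G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();

    // After this PR: a plain Default impl instead of get_default_msm_config.
    let cfg = msm::MSMConfig::default();
    msm::msm(&scalars, &points, &cfg, &mut results).unwrap();

    let mut host_result = vec![G1Projective::zero(); 1];
    results.copy_to_host(&mut host_result[..]).unwrap();
    println!("bn254 result: {:#?}", host_result);
}
```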
println!("Checking against arkworks..."); - let ark_points: Vec = points.as_slice().iter().map(|&point| point.to_ark()).collect(); - let ark_scalars: Vec = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect(); + let ark_points: Vec = points + .as_slice() + .iter() + .map(|&point| point.to_ark()) + .collect(); + let ark_scalars: Vec = scalars + .as_slice() + .iter() + .map(|scalar| scalar.to_ark()) + .collect(); - let ark_points_bls12377: Vec = points_bls12377.as_slice().iter().map(|point| point.to_ark()).collect(); - let ark_scalars_bls12377: Vec = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect(); + let ark_points_bls12377: Vec = points_bls12377 + .as_slice() + .iter() + .map(|point| point.to_ark()) + .collect(); + let ark_scalars_bls12377: Vec = scalars_bls12377 + .as_slice() + .iter() + .map(|scalar| scalar.to_ark()) + .collect(); #[cfg(feature = "profile")] let start = Instant::now(); let bn254_ark_msm_res = Bn254ArkG1Projective::msm(&ark_points, &ark_scalars).unwrap(); println!("Arkworks Bn254 result: {:#?}", bn254_ark_msm_res); #[cfg(feature = "profile")] - println!("Ark BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis()); + println!( + "Ark BN254 MSM on size 2^{log_size} took: {} ms", + start + .elapsed() + .as_millis() + ); #[cfg(feature = "profile")] let start = Instant::now(); - let bls12377_ark_msm_res = Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap(); + let bls12377_ark_msm_res = + Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap(); println!("Arkworks Bls12377 result: {:#?}", bls12377_ark_msm_res); #[cfg(feature = "profile")] - println!("Ark BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis()); + println!( + "Ark BLS12377 MSM on size 2^{log_size} took: {} ms", + start + .elapsed() + .as_millis() + ); let bn254_icicle_msm_res_as_ark = msm_host_result[0].to_ark(); let bls12377_icicle_msm_res_as_ark = msm_host_result_bls12377[0].to_ark(); - println!("Bn254 MSM is correct: {}", bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark)); - println!("Bls12377 MSM is correct: {}", bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark)); + println!( + "Bn254 MSM is correct: {}", + bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark) + ); + println!( + "Bls12377 MSM is correct: {}", + bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark) + ); } - + println!("Cleaning up bn254..."); stream .destroy() diff --git a/examples/rust/ntt/src/main.rs b/examples/rust/ntt/src/main.rs index f39a5804..1de1b195 100644 --- a/examples/rust/ntt/src/main.rs +++ b/examples/rust/ntt/src/main.rs @@ -1,28 +1,18 @@ -use icicle_bn254::curve::{ - ScalarCfg, - ScalarField, -}; +use icicle_bn254::curve::{ScalarCfg, ScalarField}; -use icicle_bls12_377::curve::{ - ScalarCfg as BLS12377ScalarCfg, - ScalarField as BLS12377ScalarField -}; +use icicle_bls12_377::curve::{ScalarCfg as BLS12377ScalarCfg, ScalarField as BLS12377ScalarField}; -use icicle_cuda_runtime::{ - stream::CudaStream, - memory::HostOrDeviceSlice, - device_context::get_default_device_context -}; +use icicle_cuda_runtime::{device_context::DeviceContext, memory::HostOrDeviceSlice, stream::CudaStream}; use icicle_core::{ ntt::{self, NTT}, - traits::{GenerateRandom, FieldImpl} + traits::{FieldImpl, GenerateRandom}, }; use icicle_core::traits::ArkConvertible; -use ark_bn254::Fr as Bn254Fr; use ark_bls12_377::Fr as Bls12377Fr; +use ark_bn254::Fr as Bn254Fr; use ark_ff::FftField; use ark_poly::{EvaluationDomain, 
diff --git a/icicle/appUtils/msm/msm.cu b/icicle/appUtils/msm/msm.cu
index 4c48edb0..7c2e7956 100644
--- a/icicle/appUtils/msm/msm.cu
+++ b/icicle/appUtils/msm/msm.cu
@@ -14,7 +14,6 @@
 #include "primitives/affine.cuh"
 #include "primitives/field.cuh"
 #include "primitives/projective.cuh"
-#include "utils/cuda_utils.cuh"
 #include "utils/error_handler.cuh"
 #include "utils/mont.cuh"
 #include "utils/utils.h"
@@ -240,7 +239,7 @@ namespace msm {
     unsigned large_bucket_index = tid / threads_per_bucket;
     unsigned bucket_segment_index = tid % threads_per_bucket;
     if (tid >= nof_buckets_to_compute * threads_per_bucket) { return; }
-    if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // dont need
+    if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // don't need
       return; // skip zero buckets
     }
     unsigned write_bucket_index = bucket_segment_index * nof_buckets_to_compute + large_bucket_index;
diff --git a/icicle/appUtils/msm/msm.cuh b/icicle/appUtils/msm/msm.cuh
index 152d9b31..8b8498d7 100644
--- a/icicle/appUtils/msm/msm.cuh
+++ b/icicle/appUtils/msm/msm.cuh
@@ -8,7 +8,6 @@
 #include "../../primitives/affine.cuh"
 #include "../../primitives/field.cuh"
 #include "../../primitives/projective.cuh"
-#include "../../utils/cuda_utils.cuh"
 #include "../../utils/device_context.cuh"
 #include "../../utils/error_handler.cuh"
diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu
index 7fea1551..5cd3b72a 100644
--- a/icicle/appUtils/ntt/ntt.cu
+++ b/icicle/appUtils/ntt/ntt.cu
@@ -9,6 +9,8 @@
 #include "utils/utils.h"
 #include "appUtils/ntt/ntt_impl.cuh"
 
+#include <mutex>
+
 namespace ntt {
 
   namespace {
@@ -361,72 +363,88 @@ namespace ntt {
   template <typename S>
   class Domain
   {
-    static inline int max_size = 0;
-    static inline int max_log_size = 0;
-    static inline S* twiddles = nullptr;
-    static inline std::unordered_map<S, int> coset_index = {};
+    // Mutex for protecting access to the domain/device container array
+    static inline std::mutex device_domain_mutex;
+    // The domain-per-device container - assumption is InitDomain is called once per device per program.
 
-    static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
-    static inline S* basic_twiddles = nullptr;    // required by mixed-radix NTT
+    int max_size = 0;
+    int max_log_size = 0;
+    S* twiddles = nullptr;
+    std::unordered_map<S, int> coset_index = {};
+
+    S* internal_twiddles = nullptr; // required by mixed-radix NTT
+    S* basic_twiddles = nullptr;    // required by mixed-radix NTT
 
   public:
     template <typename U>
     friend cudaError_t InitDomain(U primitive_root, device_context::DeviceContext& ctx);
 
-    static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
+    cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
 
     template <typename U, typename E>
     friend cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
   };
 
+  template <typename S>
+  static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
+
   template <typename S>
   cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
   {
     CHK_INIT_IF_RETURN();
 
+    Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
+
     // only generate twiddles if they haven't been generated yet
-    // please note that this is not thread-safe at all,
-    // but it's a singleton that is supposed to be initialized once per program lifetime
-    if (!Domain<S>::twiddles) {
+    // please note that this offers just basic thread-safety,
+    // it's assumed a singleton (non-enforced) that is supposed
+    // to be initialized once per device per program lifetime
+    if (!domain.twiddles) {
+      // Mutex is automatically released when lock goes out of scope, even in case of exceptions
+      std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
+      // double check locking
+      if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
+
       bool found_logn = false;
       S omega = primitive_root;
       unsigned omegas_count = S::get_omegas_count();
       for (int i = 0; i < omegas_count; i++) {
         omega = S::sqr(omega);
         if (!found_logn) {
-          ++Domain<S>::max_log_size;
+          ++domain.max_log_size;
           found_logn = omega == S::one();
           if (found_logn) break;
         }
       }
-      Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
+
+      domain.max_size = (int)pow(2, domain.max_log_size);
       if (omega != S::one()) {
-        throw IcicleError(
+        THROW_ICICLE_ERR(
           IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
       }
 
      // allocate and calculate twiddles on GPU
      // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
      // Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
-      CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
+      CHK_IF_RETURN(cudaMallocManaged(&domain.twiddles, (domain.max_size + 1) * sizeof(S)));
       CHK_IF_RETURN(generate_external_twiddles_generic(
-        primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
-        Domain<S>::max_log_size, ctx.stream));
+        primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
+        ctx.stream));
       CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
 
       const bool is_map_only_powers_of_primitive_root = true;
       if (is_map_only_powers_of_primitive_root) {
         // populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
         // etc.)
-        Domain<S>::coset_index[S::one()] = 0;
-        for (int i = 0; i < Domain<S>::max_log_size; ++i) {
+        domain.coset_index[S::one()] = 0;
+        for (int i = 0; i < domain.max_log_size; ++i) {
           const int index = (int)pow(2, i);
-          Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
+          domain.coset_index[domain.twiddles[index]] = index;
         }
       } else {
         // populate all values
-        for (int i = 0; i < Domain<S>::max_size; ++i) {
-          Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
+        for (int i = 0; i < domain.max_size; ++i) {
+          domain.coset_index[domain.twiddles[i]] = i;
         }
       }
     }
@@ -527,7 +545,9 @@ namespace ntt {
   {
     CHK_INIT_IF_RETURN();
 
-    if (size > Domain<S>::max_size) {
+    Domain<S>& domain = domains_for_devices<S>[config.ctx.device_id];
+
+    if (size > domain.max_size) {
       std::ostringstream oss;
       oss << "NTT size=" << size
           << " is too large for the domain. Consider generating your domain with a higher order root of unity.\n";
@@ -565,7 +585,7 @@ namespace ntt {
     S* coset = nullptr;
     int coset_index = 0;
     try {
-      coset_index = Domain<S>::coset_index.at(config.coset_gen);
+      coset_index = domain.coset_index.at(config.coset_gen);
     } catch (...) {
       // if coset index is not found in the subgroup, compute coset powers on CPU and move them to device
       std::vector<S> h_coset;
@@ -584,12 +604,12 @@ namespace ntt {
     if (is_radix2_algorithm) {
       CHK_IF_RETURN(ntt::radix2_ntt(
-        d_input, d_output, Domain<S>::twiddles, size, Domain<S>::max_size, batch_size, is_inverse, config.ordering,
-        coset, coset_index, stream));
+        d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
+        coset_index, stream));
     } else {
       CHK_IF_RETURN(ntt::mixed_radix_ntt(
-        d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
-        Domain<S>::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream));
+        d_input, d_output, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, size, domain.max_log_size,
+        batch_size, is_inverse, config.ordering, coset, coset_index, stream));
     }
 
     if (!are_outputs_on_device)
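The `InitDomain` rewrite above lazily initializes one domain per device and guards it with a mutex plus a double check. For readers more at home in the Rust wrappers, `OnceLock` gives the same at-most-once-per-device guarantee; this is an illustrative analogy, not code from this PR:

```rust
use std::sync::OnceLock;

const MAX_DEVICES: usize = 32; // mirrors device_context::MAX_DEVICES

// One lazily-initialized twiddle table per device, like domains_for_devices.
const EMPTY: OnceLock<Vec<u64>> = OnceLock::new();
static DOMAINS: [OnceLock<Vec<u64>>; MAX_DEVICES] = [EMPTY; MAX_DEVICES];

fn domain_for(device_id: usize) -> &'static [u64] {
    // get_or_init performs the check / lock / re-check sequence that
    // InitDomain spells out by hand; the closure runs at most once.
    DOMAINS[device_id].get_or_init(|| generate_twiddles(device_id))
}

fn generate_twiddles(_device_id: usize) -> Vec<u64> {
    // Placeholder for the GPU twiddle generation done in InitDomain.
    vec![1; 8]
}
```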
diff --git a/icicle/utils/cuda_utils.cuh b/icicle/utils/cuda_utils.cuh
deleted file mode 100644
index 071e8d89..00000000
--- a/icicle/utils/cuda_utils.cuh
+++ /dev/null
@@ -1,32 +0,0 @@
-#pragma once
-#include <cuda.h>
-
-struct cuda_ctx {
-  int device_id;
-  cudaMemPool_t mempool;
-  cudaStream_t stream;
-
-  cuda_ctx(int gpu_id)
-  {
-    gpu_id = gpu_id;
-    cudaMemPoolProps pool_props;
-    pool_props.allocType = cudaMemAllocationTypePinned;
-    pool_props.handleTypes = cudaMemHandleTypePosixFileDescriptor;
-    pool_props.location.type = cudaMemLocationTypeDevice;
-    pool_props.location.id = device_id;
-
-    cudaMemPoolCreate(&mempool, &pool_props);
-    cudaStreamCreate(&stream);
-  }
-
-  void set_device() { cudaSetDevice(device_id); }
-
-  void sync_stream() { cudaStreamSynchronize(stream); }
-
-  void malloc(void* ptr, size_t bytesize) { cudaMallocFromPoolAsync(&ptr, bytesize, mempool, stream); }
-
-  void free(void* ptr) { cudaFreeAsync(ptr, stream); }
-};
-
-// -- Proposed Function Tops --------------------------------------------------
-// ----------------------------------------------------------------------------
diff --git a/icicle/utils/device_context.cuh b/icicle/utils/device_context.cuh
index 031b7123..9dfd83f2 100644
--- a/icicle/utils/device_context.cuh
+++ b/icicle/utils/device_context.cuh
@@ -6,6 +6,8 @@
 
 namespace device_context {
 
+  constexpr std::size_t MAX_DEVICES = 32;
+
   /**
    * Properties of the device used in icicle functions.
    */
@@ -18,7 +20,7 @@ namespace device_context {
   /**
   * Return default device context that corresponds to using the default stream of the first GPU
   */
-  inline DeviceContext get_default_device_context()
+  inline DeviceContext get_default_device_context() // TODO: naming convention ?
  {
    static cudaStream_t default_stream = (cudaStream_t)0;
    return DeviceContext{
diff --git a/icicle/utils/sharedmem.cuh b/icicle/utils/sharedmem.cuh
index 64469995..9b00d36d 100644
--- a/icicle/utils/sharedmem.cuh
+++ b/icicle/utils/sharedmem.cuh
@@ -1,3 +1,4 @@
+// TODO: remove this file, seems to work without it
 // based on https://leimao.github.io/blog/CUDA-Shared-Memory-Templated-Kernel/
 // may be outdated, but only worked like that
diff --git a/wrappers/rust/icicle-core/Cargo.toml b/wrappers/rust/icicle-core/Cargo.toml
index 21d32eeb..5b92ee08 100644
--- a/wrappers/rust/icicle-core/Cargo.toml
+++ b/wrappers/rust/icicle-core/Cargo.toml
@@ -2,21 +2,23 @@
 name = "icicle-core"
 version = "1.2.0"
 edition = "2021"
-authors = [ "Ingonyama" ]
+authors = ["Ingonyama"]
 description = "A library for GPU ZK acceleration by Ingonyama"
 homepage = "https://www.ingonyama.com"
 repository = "https://github.com/ingonyama-zk/icicle"
 
-[dependencies] 
+[dependencies]
 icicle-cuda-runtime = { path = "../icicle-cuda-runtime" }
 ark-ff = { version = "0.4.0", optional = true }
-ark-ec = { version = "0.4.0", optional = true, features = [ "parallel" ] }
+ark-ec = { version = "0.4.0", optional = true, features = ["parallel"] }
 ark-poly = { version = "0.4.0", optional = true }
 ark-std = { version = "0.4.0", optional = true }
 
+rayon = "1.8.1"
+
 [features]
 default = []
 arkworks = ["ark-ff", "ark-ec", "ark-poly", "ark-std"]
diff --git a/wrappers/rust/icicle-core/src/curve.rs b/wrappers/rust/icicle-core/src/curve.rs
index a688114a..916270e8 100644
--- a/wrappers/rust/icicle-core/src/curve.rs
+++ b/wrappers/rust/icicle-core/src/curve.rs
@@ -284,7 +284,7 @@ macro_rules! impl_curve {
                     points.as_mut_ptr(),
                     points.len(),
                     is_into,
-                    &get_default_device_context() as *const _ as *const DeviceContext,
+                    &DeviceContext::default() as *const _ as *const DeviceContext,
                 )
             }
         }
@@ -298,7 +298,7 @@ macro_rules! impl_curve {
                     points.as_mut_ptr(),
                     points.len(),
                     is_into,
-                    &get_default_device_context() as *const _ as *const DeviceContext,
+                    &DeviceContext::default() as *const _ as *const DeviceContext,
                 )
             }
         }
diff --git a/wrappers/rust/icicle-core/src/field.rs b/wrappers/rust/icicle-core/src/field.rs
index 58e6b9d6..be556545 100644
--- a/wrappers/rust/icicle-core/src/field.rs
+++ b/wrappers/rust/icicle-core/src/field.rs
@@ -15,6 +15,9 @@ pub struct Field<const NUM_LIMBS: usize, F: FieldConfig> {
     p: PhantomData<F>,
 }
 
+unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Send for Field<NUM_LIMBS, F> {}
+unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Sync for Field<NUM_LIMBS, F> {}
+
 impl<const NUM_LIMBS: usize, F: FieldConfig> Display for Field<NUM_LIMBS, F> {
     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
         write!(f, "0x")?;
@@ -165,7 +168,7 @@ macro_rules! impl_scalar_field {
         impl_field!($num_limbs, $field_name, $field_cfg, $ark_equiv);
 
         mod $field_prefix_ident {
-            use crate::curve::{get_default_device_context, $field_name, CudaError, DeviceContext, HostOrDeviceSlice};
+            use crate::curve::{$field_name, CudaError, DeviceContext, HostOrDeviceSlice};
 
             extern "C" {
                 #[link_name = concat!($field_prefix, "GenerateScalars")]
@@ -189,7 +192,7 @@ macro_rules! impl_scalar_field {
                     scalars.as_mut_ptr(),
                     scalars.len(),
                     is_into,
-                    &get_default_device_context() as *const _ as *const DeviceContext,
+                    &DeviceContext::default() as *const _ as *const DeviceContext,
                 )
             }
         }
diff --git a/wrappers/rust/icicle-core/src/msm/mod.rs b/wrappers/rust/icicle-core/src/msm/mod.rs
index cf733da2..c4b68a65 100644
--- a/wrappers/rust/icicle-core/src/msm/mod.rs
+++ b/wrappers/rust/icicle-core/src/msm/mod.rs
@@ -1,6 +1,6 @@
 use crate::curve::{Affine, Curve, Projective};
 use crate::error::IcicleResult;
-use icicle_cuda_runtime::device_context::DeviceContext;
+use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
 
 #[cfg(feature = "arkworks")]
@@ -59,6 +59,33 @@ pub struct MSMConfig<'a> {
     pub is_async: bool,
 }
 
+impl<'a> Default for MSMConfig<'a> {
+    fn default() -> Self {
+        Self::default_for_device(DEFAULT_DEVICE_ID)
+    }
+}
+
+impl<'a> MSMConfig<'a> {
+    pub fn default_for_device(device_id: usize) -> Self {
+        Self {
+            ctx: DeviceContext::default_for_device(device_id),
+            points_size: 0,
+            precompute_factor: 1,
+            c: 0,
+            bitsize: 0,
+            large_bucket_factor: 10,
+            batch_size: 1,
+            are_scalars_on_device: false,
+            are_scalars_montgomery_form: false,
+            are_points_on_device: false,
+            are_points_montgomery_form: false,
+            are_results_on_device: false,
+            is_big_triangle: false,
+            is_async: false,
+        }
+    }
+}
+
 #[doc(hidden)]
 pub trait MSM<C: Curve> {
     fn msm_unchecked(
         scalars: &HostOrDeviceSlice<C::ScalarField>,
         points: &HostOrDeviceSlice<Affine<C>>,
         cfg: &MSMConfig,
         results: &mut HostOrDeviceSlice<Projective<C>>,
     ) -> IcicleResult<()>;
-
-    fn get_default_msm_config() -> MSMConfig<'static>;
 }
 
 /// Computes the multi-scalar multiplication, or MSM: `s1*P1 + s2*P2 + ... + sn*Pn`, or a batch of several MSMs.
@@ -113,11 +138,6 @@ pub fn msm<C: Curve + MSM<C>>(
     C::msm_unchecked(scalars, points, &local_cfg, results)
 }
 
-/// Returns [MSM config](MSMConfig) struct populated with default values.
-pub fn get_default_msm_config<C: Curve + MSM<C>>() -> MSMConfig<'static> {
-    C::get_default_msm_config()
-}
-
 #[macro_export]
 macro_rules! impl_msm {
     (
@@ -161,10 +181,6 @@ macro_rules! impl_msm {
                 .wrap()
                 }
            }
-
-            fn get_default_msm_config() -> MSMConfig<'static> {
-                unsafe { $curve_prefix_indent::default_msm_config() }
-            }
        }
    };
}
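With the trait method gone, callers construct configs directly. A sketch of targeting a specific GPU (the device id and field override are illustrative):

```rust
use icicle_core::msm::MSMConfig;
use icicle_cuda_runtime::device::set_device;

fn msm_config_for(device_id: usize) -> MSMConfig<'static> {
    // The calling thread should be bound to the same device the config names.
    set_device(device_id).unwrap();
    let mut cfg = MSMConfig::default_for_device(device_id);
    cfg.is_async = true; // overrides remain plain field writes
    cfg
}
```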
diff --git a/wrappers/rust/icicle-core/src/msm/tests.rs b/wrappers/rust/icicle-core/src/msm/tests.rs
index 4fb6ca58..30c2719f 100644
--- a/wrappers/rust/icicle-core/src/msm/tests.rs
+++ b/wrappers/rust/icicle-core/src/msm/tests.rs
@@ -1,8 +1,11 @@
-use super::{get_default_msm_config, msm, MSM};
 use crate::curve::{Affine, Curve, Projective};
+use crate::msm::{msm, MSMConfig, MSM};
 use crate::traits::{FieldImpl, GenerateRandom};
+use icicle_cuda_runtime::device::{get_device_count, set_device};
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
 use icicle_cuda_runtime::stream::CudaStream;
+use rayon::iter::IntoParallelIterator;
+use rayon::iter::ParallelIterator;
 
 #[cfg(feature = "arkworks")]
 use crate::traits::ArkConvertible;
@@ -19,61 +22,73 @@ where
     C::ScalarField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as CurveConfig>::ScalarField>,
     C::BaseField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as CurveConfig>::BaseField>,
 {
-    let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
-    let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
-    for test_size in test_sizes {
-        let points = C::generate_random_affine_points(test_size);
-        let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
-        let points_ark: Vec<_> = points
-            .iter()
-            .map(|x| x.to_ark())
-            .collect();
-        let scalars_ark: Vec<_> = scalars
-            .iter()
-            .map(|x| x.to_ark())
-            .collect();
-        // if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
-        // (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
-        let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
+    let device_count = get_device_count().unwrap();
+    (0..device_count) // TODO: this is proto-loadbalancer
+        .into_par_iter()
+        .for_each(move |device_id| {
+            //TODO: currently supported multi-GPU workflow:
+            //      1) User starts child host thread from parent host thread
+            //      2) Calls set_device once with selected device_id (0 is default device .. < device_count)
+            //      3) Perform all operations (without changing device on the thread)
+            //      4) If necessary - export results to parent host thread
 
-        let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
-        let stream = CudaStream::create().unwrap();
-        scalars_d
-            .copy_from_host_async(&scalars_mont, &stream)
-            .unwrap();
+            set_device(device_id).unwrap();
+            let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
+            let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
+            for test_size in test_sizes {
+                let points = C::generate_random_affine_points(test_size);
+                let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
+                let points_ark: Vec<_> = points
+                    .iter()
+                    .map(|x| x.to_ark())
+                    .collect();
+                let scalars_ark: Vec<_> = scalars
+                    .iter()
+                    .map(|x| x.to_ark())
+                    .collect();
+                // if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
+                // (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
+                let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
 
-        let mut cfg = get_default_msm_config::<C>();
-        cfg.ctx
-            .stream = &stream;
-        cfg.is_async = true;
-        cfg.are_scalars_montgomery_form = true;
-        msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
-        // need to make sure that scalars_d weren't mutated by the previous call
-        let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
-        scalars_d
-            .copy_to_host_async(&mut scalars_mont_after, &stream)
-            .unwrap();
-        assert_eq!(scalars_mont, scalars_mont_after);
+                let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
+                let stream = CudaStream::create().unwrap();
+                scalars_d
+                    .copy_from_host_async(&scalars_mont, &stream)
+                    .unwrap();
 
-        let mut msm_host_result = vec![Projective::<C>::zero(); 1];
-        msm_results
-            .copy_to_host(&mut msm_host_result[..])
-            .unwrap();
-        stream
-            .synchronize()
-            .unwrap();
-        stream
-            .destroy()
-            .unwrap();
+                let mut cfg = MSMConfig::default_for_device(device_id);
+                cfg.ctx
+                    .stream = &stream;
+                cfg.is_async = true;
+                cfg.are_scalars_montgomery_form = true;
+                msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
+                // need to make sure that scalars_d weren't mutated by the previous call
+                let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
+                scalars_d
+                    .copy_to_host_async(&mut scalars_mont_after, &stream)
+                    .unwrap();
+                assert_eq!(scalars_mont, scalars_mont_after);
 
-        let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
-            VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
-        let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
-            .to_ark()
-            .into();
-        assert!(msm_res_affine.is_on_curve());
-        assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
-    }
+                let mut msm_host_result = vec![Projective::<C>::zero(); 1];
+                msm_results
+                    .copy_to_host(&mut msm_host_result[..])
+                    .unwrap();
+                stream
+                    .synchronize()
+                    .unwrap();
+                stream
+                    .destroy()
+                    .unwrap();
+
+                let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
+                    VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
+                let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
+                    .to_ark()
+                    .into();
+                assert!(msm_res_affine.is_on_curve());
+                assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
+            }
+        });
 }
 
@@ -104,7 +119,7 @@ pub fn check_msm_batch<C: Curve + MSM<C>>()
         .copy_from_host_async(&points_cloned, &stream)
         .unwrap();
 
-    let mut cfg = get_default_msm_config::<C>();
+    let mut cfg = MSMConfig::default();
     cfg.ctx
         .stream = &stream;
     cfg.is_async = true;
@@ -181,7 +196,7 @@ where
     let mut msm_results = HostOrDeviceSlice::on_host(vec![Projective::<C>::zero(); batch_size]);
 
-    let mut cfg = get_default_msm_config::<C>();
+    let mut cfg = MSMConfig::default();
     if test_size < test_threshold {
         cfg.bitsize = 1;
     }
diff --git a/wrappers/rust/icicle-core/src/ntt/mod.rs b/wrappers/rust/icicle-core/src/ntt/mod.rs
index 6ef6a5b9..3d699e7c 100644
--- a/wrappers/rust/icicle-core/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-core/src/ntt/mod.rs
@@ -1,4 +1,4 @@
-use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
+use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
 
 use crate::{error::IcicleResult, traits::FieldImpl};
@@ -89,11 +89,16 @@ pub struct NTTConfig<'a, S> {
     pub ntt_algorithm: NttAlgorithm,
 }
 
+impl<'a, S: FieldImpl> Default for NTTConfig<'a, S> {
+    fn default() -> Self {
+        Self::default_for_device(DEFAULT_DEVICE_ID)
+    }
+}
+
 impl<'a, S: FieldImpl> NTTConfig<'a, S> {
-    pub fn default_config() -> Self {
-        let ctx = get_default_device_context();
+    pub fn default_for_device(device_id: usize) -> Self {
         NTTConfig {
-            ctx,
+            ctx: DeviceContext::default_for_device(device_id),
             coset_gen: S::one(),
             batch_size: 1,
             ordering: Ordering::kNN,
@@ -114,7 +119,6 @@ pub trait NTT<F: FieldImpl> {
         output: &mut HostOrDeviceSlice<F>,
     ) -> IcicleResult<()>;
     fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
-    fn get_default_ntt_config() -> NTTConfig<'static, F>;
 }
 
 /// Computes the NTT, or a batch of several NTTs.
@@ -169,15 +173,6 @@ where
     <<F as FieldImpl>::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
 }
 
-/// Returns [NTT config](NTTConfig) struct populated with default values.
-pub fn get_default_ntt_config<F>() -> NTTConfig<'static, F>
-where
-    F: FieldImpl,
-    <F as FieldImpl>::Config: NTT<F>,
-{
-    <<F as FieldImpl>::Config as NTT<F>>::get_default_ntt_config()
-}
-
 #[macro_export]
 macro_rules! impl_ntt {
     (
@@ -187,7 +182,7 @@ macro_rules! impl_ntt {
       $field_config:ident
     ) => {
         mod $field_prefix_ident {
-            use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir};
+            use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir, DEFAULT_DEVICE_ID};
 
             extern "C" {
                 #[link_name = concat!($field_prefix, "NTTCuda")]
@@ -226,10 +221,6 @@ macro_rules! impl_ntt {
             fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
                 unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
             }
-
-            fn get_default_ntt_config() -> NTTConfig<'static, $field> {
-                NTTConfig::<$field>::default_config()
-            }
        }
    };
}
@@ -244,31 +235,31 @@ macro_rules! impl_ntt_tests {
 
         #[test]
         fn test_ntt() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
             check_ntt::<$field>()
         }
 
         #[test]
         fn test_ntt_coset_from_subgroup() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
             check_ntt_coset_from_subgroup::<$field>()
         }
 
         #[test]
         fn test_ntt_arbitrary_coset() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
             check_ntt_arbitrary_coset::<$field>()
         }
 
         #[test]
         fn test_ntt_batch() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
             check_ntt_batch::<$field>()
         }
 
         #[test]
         fn test_ntt_device_async() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
+            // init_domain in this test is performed per-device
             check_ntt_device_async::<$field>()
        }
    };
}
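`NTTConfig` now mirrors `MSMConfig`: `Default` for device 0, `default_for_device` for anything else. A small sketch (the batch size value is illustrative):

```rust
use icicle_bn254::curve::ScalarField;
use icicle_core::ntt::NTTConfig;

fn ntt_config_for(device_id: usize) -> NTTConfig<'static, ScalarField> {
    let mut cfg = NTTConfig::<ScalarField>::default_for_device(device_id);
    cfg.batch_size = 16;
    cfg
}
```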
diff --git a/wrappers/rust/icicle-core/src/ntt/tests.rs b/wrappers/rust/icicle-core/src/ntt/tests.rs
index 3819441b..fde182fc 100644
--- a/wrappers/rust/icicle-core/src/ntt/tests.rs
+++ b/wrappers/rust/icicle-core/src/ntt/tests.rs
@@ -1,23 +1,27 @@
 use ark_ff::{FftField, Field as ArkField, One};
 use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
 use ark_std::{ops::Neg, test_rng, UniformRand};
-use icicle_cuda_runtime::device_context::get_default_device_context;
+use icicle_cuda_runtime::device::get_device_count;
+use icicle_cuda_runtime::device::set_device;
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
-use icicle_cuda_runtime::stream::CudaStream;
+use rayon::iter::IntoParallelIterator;
+use rayon::iter::ParallelIterator;
 
 use crate::{
-    ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
+    ntt::{initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
     traits::{ArkConvertible, FieldImpl, GenerateRandom},
 };
 
+use super::NTTConfig;
 use super::NTT;
 
-pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64)
+pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize)
 where
     F::ArkEquivalent: FftField,
     <F as FieldImpl>::Config: NTT<F>,
 {
-    let ctx = get_default_device_context();
+    let ctx = DeviceContext::default_for_device(device_id);
     let ark_rou = F::ArkEquivalent::get_root_of_unity(max_size).unwrap();
     initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
 }
@@ -61,7 +65,7 @@ where
     let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) };
     let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec());
 
-    let mut config = get_default_ntt_config();
+    let mut config = NTTConfig::default();
     for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
         config.ntt_algorithm = alg;
         let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
@@ -107,7 +111,7 @@ where
         .collect::<Vec<_>>();
 
     for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
-        let mut config = get_default_ntt_config();
+        let mut config = NTTConfig::default();
         config.ordering = Ordering::kNR;
         config.ntt_algorithm = alg;
         let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
@@ -188,7 +192,7 @@ where
         .map(|v| F::ArkEquivalent::from_random_bytes(&v.to_bytes_le()).unwrap())
         .collect::<Vec<_>>();
 
-    let mut config = get_default_ntt_config();
+    let mut config = NTTConfig::default();
     config.coset_gen = F::from_ark(coset_gen);
     for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
         config.ordering = Ordering::kNR;
@@ -229,7 +233,7 @@ where
     let batch_sizes = [1, 1 << 4, 100];
     for test_size in test_sizes {
         let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
-        let mut config = get_default_ntt_config();
+        let mut config = NTTConfig::default();
         for batch_size in batch_sizes {
             let scalars = HostOrDeviceSlice::on_host(F::Config::generate_random(test_size * batch_size));
 
@@ -279,55 +283,65 @@ where
     F::ArkEquivalent: FftField,
     <F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
 {
-    let test_sizes = [1 << 4, 1 << 12];
-    let batch_sizes = [1, 1 << 4, 100];
-    for test_size in test_sizes {
-        let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
-        let stream = CudaStream::create().unwrap();
-        let mut config = get_default_ntt_config();
-        for batch_size in batch_sizes {
-            let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
-            let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
-                .iter()
-                .map(|x| x.to_ark())
-                .sum();
-            let mut scalars_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
-            scalars_d
-                .copy_from_host_async(&scalars_h, &stream)
-                .unwrap();
-            let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
+    let device_count = get_device_count().unwrap();
 
-            for coset_gen in coset_generators {
-                for ordering in [Ordering::kNN, Ordering::kRR] {
-                    config.coset_gen = coset_gen;
-                    config.ordering = ordering;
-                    config.batch_size = batch_size as i32;
-                    config.is_async = true;
-                    config
-                        .ctx
-                        .stream = &stream;
-                    for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
-                        config.ntt_algorithm = alg;
-                        ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
-                        ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
-                        let mut intt_result_h = vec![F::zero(); test_size * batch_size];
-                        scalars_d
-                            .copy_to_host_async(&mut intt_result_h, &stream)
-                            .unwrap();
-                        stream
-                            .synchronize()
-                            .unwrap();
-                        assert_eq!(scalars_h, intt_result_h);
-                        if coset_gen == F::one() {
-                            let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
-                            ntt_out_d
-                                .copy_to_host(&mut ntt_result_h)
-                                .unwrap();
-                            assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
+    (0..device_count)
+        .into_par_iter()
+        .for_each(move |device_id| {
+            set_device(device_id).unwrap();
+            init_domain::<F>(1 << 16, device_id); // init domain per device
+            let test_sizes = [1 << 4, 1 << 12];
+            let batch_sizes = [1, 1 << 4, 100];
+            for test_size in test_sizes {
+                let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
+                let mut config = NTTConfig::default_for_device(device_id);
+                let stream = config
+                    .ctx
+                    .stream;
+                for batch_size in batch_sizes {
+                    let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
+                    let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
+                        .iter()
+                        .map(|x| x.to_ark())
+                        .sum();
+                    let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size * batch_size).unwrap();
+                    scalars_d
+                        .copy_from_host(&scalars_h)
+                        .unwrap();
+                    let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
+
+                    for coset_gen in coset_generators {
+                        for ordering in [Ordering::kNN, Ordering::kRR] {
+                            config.coset_gen = coset_gen;
+                            config.ordering = ordering;
+                            config.batch_size = batch_size as i32;
+                            config.is_async = true;
+                            config
+                                .ctx
+                                .stream = &stream;
+                            for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
+                                config.ntt_algorithm = alg;
+                                ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
+                                ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
+                                let mut intt_result_h = vec![F::zero(); test_size * batch_size];
+                                scalars_d
+                                    .copy_to_host_async(&mut intt_result_h, &stream)
+                                    .unwrap();
+                                stream
+                                    .synchronize()
+                                    .unwrap();
+                                assert_eq!(scalars_h, intt_result_h);
+                                if coset_gen == F::one() {
+                                    let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
+                                    ntt_out_d
+                                        .copy_to_host(&mut ntt_result_h)
+                                        .unwrap();
+                                    assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
+                                }
+                            }
                         }
                     }
                 }
             }
-        }
-    }
+        });
 }
diff --git a/wrappers/rust/icicle-core/src/poseidon/mod.rs b/wrappers/rust/icicle-core/src/poseidon/mod.rs
index 42d96eb9..cbcc70e0 100644
--- a/wrappers/rust/icicle-core/src/poseidon/mod.rs
+++ b/wrappers/rust/icicle-core/src/poseidon/mod.rs
@@ -2,7 +2,7 @@
 pub mod tests;
 
 use icicle_cuda_runtime::{
-    device_context::{get_default_device_context, DeviceContext},
+    device_context::{DeviceContext, DEFAULT_DEVICE_ID},
     memory::HostOrDeviceSlice,
 };
 
@@ -60,9 +60,14 @@ pub struct PoseidonConfig<'a> {
 
 impl<'a> Default for PoseidonConfig<'a> {
     fn default() -> Self {
-        let ctx = get_default_device_context();
+        Self::default_for_device(DEFAULT_DEVICE_ID)
+    }
+}
+
+impl<'a> PoseidonConfig<'a> {
+    pub fn default_for_device(device_id: usize) -> Self {
         Self {
-            ctx,
+            ctx: DeviceContext::default_for_device(device_id),
             are_inputs_on_device: false,
             are_outputs_on_device: false,
             input_is_a_state: false,
diff --git a/wrappers/rust/icicle-core/src/poseidon/tests.rs b/wrappers/rust/icicle-core/src/poseidon/tests.rs
index 997e07cc..9225bb9a 100644
--- a/wrappers/rust/icicle-core/src/poseidon/tests.rs
+++ b/wrappers/rust/icicle-core/src/poseidon/tests.rs
@@ -1,5 +1,5 @@
 use crate::traits::FieldImpl;
-use icicle_cuda_runtime::device_context::get_default_device_context;
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
 
 use std::io::Read;
@@ -15,7 +15,7 @@ pub fn init_poseidon<'a, F: FieldImpl>(arity: u32) -> PoseidonConstants<'a, F>
 where
     <F as FieldImpl>::Config: Poseidon<F>,
 {
-    let ctx = get_default_device_context();
+    let ctx = DeviceContext::default();
     load_optimized_poseidon_constants::<F>(arity, &ctx).unwrap()
 }
@@ -71,7 +71,7 @@ where
     let full_rounds_half = 4;
 
-    let ctx = get_default_device_context();
+    let ctx = DeviceContext::default();
     let cargo_manifest_dir = env!("CARGO_MANIFEST_DIR");
     let constants_file = PathBuf::from(cargo_manifest_dir)
         .join("tests")
diff --git a/wrappers/rust/icicle-core/src/traits.rs b/wrappers/rust/icicle-core/src/traits.rs
index 6e55c58e..05eb00ec 100644
--- a/wrappers/rust/icicle-core/src/traits.rs
+++ b/wrappers/rust/icicle-core/src/traits.rs
@@ -18,7 +18,9 @@ pub trait FieldConfig: Debug + PartialEq + Copy + Clone {
     type ArkField: ArkField;
 }
 
-pub trait FieldImpl: Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> {
+pub trait FieldImpl:
+    Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> + Send + Sync
+{
     #[doc(hidden)]
     type Config: FieldConfig;
     type Repr;
diff --git a/wrappers/rust/icicle-core/src/tree/mod.rs b/wrappers/rust/icicle-core/src/tree/mod.rs
index d635c5f3..3340f472 100644
--- a/wrappers/rust/icicle-core/src/tree/mod.rs
+++ b/wrappers/rust/icicle-core/src/tree/mod.rs
@@ -1,5 +1,5 @@
 use icicle_cuda_runtime::{
-    device_context::{get_default_device_context, DeviceContext},
+    device_context::{DeviceContext, DEFAULT_DEVICE_ID},
     memory::HostOrDeviceSlice,
 };
 
@@ -28,9 +28,13 @@ pub struct TreeBuilderConfig<'a> {
 
 impl<'a> Default for TreeBuilderConfig<'a> {
     fn default() -> Self {
-        let ctx = get_default_device_context();
+        Self::default_for_device(DEFAULT_DEVICE_ID)
+    }
+}
+impl<'a> TreeBuilderConfig<'a> {
+    fn default_for_device(device_id: usize) -> Self {
         Self {
-            ctx,
+            ctx: DeviceContext::default_for_device(device_id),
             keep_rows: 0,
             are_inputs_on_device: false,
             is_async: false,
diff --git a/wrappers/rust/icicle-cuda-runtime/build.rs b/wrappers/rust/icicle-cuda-runtime/build.rs
index 4637f498..23dfdc51 100644
--- a/wrappers/rust/icicle-cuda-runtime/build.rs
+++ b/wrappers/rust/icicle-cuda-runtime/build.rs
@@ -45,7 +45,9 @@ fn main() {
         .must_use_type("cudaError")
         // device management
         // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
+        .allowlist_function("cudaGetDevice")
         .allowlist_function("cudaSetDevice")
+        .allowlist_function("cudaGetDeviceCount")
         // error handling
         // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__ERROR.html
         .allowlist_function("cudaGetLastError")
@@ -69,6 +71,7 @@ fn main() {
         .allowlist_function("cudaMemcpyAsync")
         .allowlist_function("cudaMemset")
         .allowlist_function("cudaMemsetAsync")
+        .allowlist_function("cudaDeviceGetDefaultMemPool")
         .rustified_enum("cudaMemcpyKind")
         // Stream Ordered Memory Allocator
         // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html
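The functions allowlisted above back the new `device` module that follows. Expected usage from a host thread, as a sketch (the multi-GPU tests in this PR make the same calls):

```rust
use icicle_cuda_runtime::device::{get_device, get_device_count, set_device};

fn pin_to_last_device() {
    let count = get_device_count().unwrap();
    set_device(count - 1).unwrap();
    assert_eq!(get_device().unwrap(), count - 1);
}
```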
diff --git a/wrappers/rust/icicle-cuda-runtime/src/device.rs b/wrappers/rust/icicle-cuda-runtime/src/device.rs
new file mode 100644
index 00000000..80750b2f
--- /dev/null
+++ b/wrappers/rust/icicle-cuda-runtime/src/device.rs
@@ -0,0 +1,18 @@
+use crate::{
+    bindings::{cudaGetDevice, cudaGetDeviceCount, cudaSetDevice},
+    error::{CudaResult, CudaResultWrap},
+};
+
+pub fn set_device(device_id: usize) -> CudaResult<()> {
+    unsafe { cudaSetDevice(device_id as i32) }.wrap()
+}
+
+pub fn get_device_count() -> CudaResult<usize> {
+    let mut count = 0;
+    unsafe { cudaGetDeviceCount(&mut count) }.wrap_value(count as usize)
+}
+
+pub fn get_device() -> CudaResult<usize> {
+    let mut device_id = 0;
+    unsafe { cudaGetDevice(&mut device_id) }.wrap_value(device_id as usize)
+}
diff --git a/wrappers/rust/icicle-cuda-runtime/src/device_context.rs b/wrappers/rust/icicle-cuda-runtime/src/device_context.rs
index f0f8c411..cfcae98f 100644
--- a/wrappers/rust/icicle-cuda-runtime/src/device_context.rs
+++ b/wrappers/rust/icicle-cuda-runtime/src/device_context.rs
@@ -1,27 +1,47 @@
 use crate::memory::CudaMemPool;
 use crate::stream::CudaStream;
 
-/// Properties of the device used in icicle functions.
+pub const DEFAULT_DEVICE_ID: usize = 0;
+
+use crate::device::get_device;
+
+/// Properties of the device used in Icicle functions.
 #[repr(C)]
 #[derive(Debug, Clone)]
 pub struct DeviceContext<'a> {
-    /// Stream to use. Default value: 0.
+    /// Stream to use. Default value: 0. //TODO: multiple streams per device ?
     pub stream: &'a CudaStream, // Assuming the type is provided by a CUDA binding crate
     /// Index of the currently used GPU. Default value: 0.
     pub device_id: usize,
-    /// Mempool to use. Default value: 0.
+    /// Mempool to use. Default value: 0. //TODO: multiple mempools per device ?
     pub mempool: CudaMemPool, // Assuming the type is provided by a CUDA binding crate
 }
 
-pub fn get_default_device_context() -> DeviceContext<'static> {
-    static default_stream: CudaStream = CudaStream {
-        handle: std::ptr::null_mut(),
-    };
-    DeviceContext {
-        stream: &default_stream,
-        device_id: 0,
-        mempool: 0,
+impl Default for DeviceContext<'_> {
+    fn default() -> Self {
+        Self::default_for_device(DEFAULT_DEVICE_ID)
+    }
+}
+
+impl DeviceContext<'_> {
+    /// Default for device_id
+    pub fn default_for_device(device_id: usize) -> DeviceContext<'static> {
+        static default_stream: CudaStream = CudaStream {
+            handle: std::ptr::null_mut(),
+        };
+        DeviceContext {
+            stream: &default_stream,
+            device_id,
+            mempool: std::ptr::null_mut(),
+        }
+    }
+}
+
+pub fn check_device(device_id: i32) {
+    match device_id == get_device().unwrap() as i32 {
+        true => (),
+        false => panic!("Attempt to use on a different device"),
     }
 }
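`check_device` deliberately panics instead of returning a `CudaResult`, so using a buffer from the wrong device fails loudly rather than silently touching the wrong GPU. A sketch of the intended call pattern (hypothetical helper, not part of this PR):

```rust
use icicle_cuda_runtime::device::set_device;
use icicle_cuda_runtime::device_context::check_device;

fn assert_on_device(device_id: usize) {
    set_device(device_id).unwrap();
    // Panics if the current device differs from device_id.
    check_device(device_id as i32);
}
```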
diff --git a/wrappers/rust/icicle-cuda-runtime/src/lib.rs b/wrappers/rust/icicle-cuda-runtime/src/lib.rs
index 85d95c16..103fd56b 100644
--- a/wrappers/rust/icicle-cuda-runtime/src/lib.rs
+++ b/wrappers/rust/icicle-cuda-runtime/src/lib.rs
@@ -3,6 +3,7 @@
 #[allow(dead_code)]
 mod bindings;
 
+pub mod device;
 pub mod device_context;
 pub mod error;
 pub mod memory;
diff --git a/wrappers/rust/icicle-cuda-runtime/src/memory.rs b/wrappers/rust/icicle-cuda-runtime/src/memory.rs
index eb7886cb..0e12fd82 100644
--- a/wrappers/rust/icicle-cuda-runtime/src/memory.rs
+++ b/wrappers/rust/icicle-cuda-runtime/src/memory.rs
@@ -1,4 +1,8 @@
-use crate::bindings::{cudaFree, cudaMalloc, cudaMallocAsync, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
+use crate::bindings::{
+    cudaFree, cudaMalloc, cudaMallocAsync, cudaMemPool_t, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind,
+};
+use crate::device::get_device;
+use crate::device_context::check_device;
 use crate::error::{CudaError, CudaResult, CudaResultWrap};
 use crate::stream::CudaStream;
 use std::mem::{size_of, MaybeUninit};
@@ -8,60 +12,69 @@ use std::slice::from_raw_parts_mut;
 
 pub enum HostOrDeviceSlice<'a, T> {
     Host(Vec<T>),
-    Device(&'a mut [T]),
+    Device(&'a mut [T], i32),
 }
 
 impl<'a, T> HostOrDeviceSlice<'a, T> {
+    // Function to get the device_id for Device variant
+    pub fn get_device_id(&self) -> Option<i32> {
+        match self {
+            HostOrDeviceSlice::Device(_, device_id) => Some(*device_id),
+            HostOrDeviceSlice::Host(_) => None,
+        }
+    }
+
     pub fn len(&self) -> usize {
         match self {
-            Self::Device(s) => s.len(),
+            Self::Device(s, _) => s.len(),
             Self::Host(v) => v.len(),
         }
     }
 
     pub fn is_empty(&self) -> bool {
         match self {
-            Self::Device(s) => s.is_empty(),
+            Self::Device(s, _) => s.is_empty(),
             Self::Host(v) => v.is_empty(),
         }
     }
 
     pub fn is_on_device(&self) -> bool {
         match self {
-            Self::Device(_) => true,
+            Self::Device(_, _) => true,
             Self::Host(_) => false,
         }
     }
 
     pub fn as_mut_slice(&mut self) -> &mut [T] {
         match self {
-            Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
+            Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
             Self::Host(v) => v.as_mut_slice(),
         }
     }
 
     pub fn as_slice(&self) -> &[T] {
         match self {
-            Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
+            Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
             Self::Host(v) => v.as_slice(),
         }
     }
 
     pub fn as_ptr(&self) -> *const T {
         match self {
-            Self::Device(s) => s.as_ptr(),
+            Self::Device(s, _) => s.as_ptr(),
             Self::Host(v) => v.as_ptr(),
         }
     }
 
     pub fn as_mut_ptr(&mut self) -> *mut T {
         match self {
-            Self::Device(s) => s.as_mut_ptr(),
+            Self::Device(s, _) => s.as_mut_ptr(),
            Self::Host(v) => v.as_mut_ptr(),
        }
    }
 
     pub fn on_host(src: Vec<T>) -> Self {
+        //TODO: HostOrDeviceSlice on_host() with slice input without actually copying the data
         Self::Host(src)
     }
 
@@ -70,16 +83,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
             .checked_mul(size_of::<T>())
             .unwrap_or(0);
         if size == 0 {
-            return Err(CudaError::cudaErrorMemoryAllocation);
+            return Err(CudaError::cudaErrorMemoryAllocation); //TODO: only CUDA backend should return CudaError
         }
 
         let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
         unsafe {
             cudaMalloc(device_ptr.as_mut_ptr(), size).wrap()?;
-            Ok(Self::Device(from_raw_parts_mut(
-                device_ptr.assume_init() as *mut T,
-                count,
-            )))
+            Ok(Self::Device(
+                from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
+                get_device().unwrap() as i32,
+            ))
         }
     }
 
@@ -94,16 +107,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
         let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
         unsafe {
             cudaMallocAsync(device_ptr.as_mut_ptr(), size, stream.handle as *mut _ as *mut _).wrap()?;
-            Ok(Self::Device(from_raw_parts_mut(
-                device_ptr.assume_init() as *mut T,
-                count,
-            )))
+            Ok(Self::Device(
+                from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
+                get_device().unwrap() as i32,
+            ))
         }
     }
 
     pub fn copy_from_host(&mut self, val: &[T]) -> CudaResult<()> {
         match self {
-            Self::Device(_) => {}
+            Self::Device(_, device_id) => check_device(*device_id),
             Self::Host(_) => panic!("Need device memory to copy into, and not host"),
         };
         assert!(
@@ -127,7 +140,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
     pub fn copy_to_host(&self, val: &mut [T]) -> CudaResult<()> {
         match self {
-            Self::Device(_) => {}
+            Self::Device(_, device_id) => check_device(*device_id),
             Self::Host(_) => panic!("Need device memory to copy from, and not host"),
         };
         assert!(
@@ -151,7 +164,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
     pub fn copy_from_host_async(&mut self, val: &[T], stream: &CudaStream) -> CudaResult<()> {
         match self {
-            Self::Device(_) => {}
+            Self::Device(_, device_id) => check_device(*device_id),
             Self::Host(_) => panic!("Need device memory to copy into, and not host"),
         };
         assert!(
@@ -176,7 +189,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
     pub fn copy_to_host_async(&self, val: &mut [T], stream: &CudaStream) -> CudaResult<()> {
         match self {
-            Self::Device(_) => {}
+            Self::Device(_, device_id) => check_device(*device_id),
             Self::Host(_) => panic!("Need device memory to copy from, and not host"),
         };
         assert!(
@@ -209,7 +222,7 @@ macro_rules! impl_index {
             fn index(&self, index: $t) -> &Self::Output {
                 match self {
-                    Self::Device(s) => s.index(index),
+                    Self::Device(s, _) => s.index(index),
                     Self::Host(v) => v.index(index),
                 }
             }
@@ -219,7 +232,7 @@ macro_rules! impl_index {
             fn index_mut(&mut self, index: $t) -> &mut Self::Output {
                 match self {
-                    Self::Device(s) => s.index_mut(index),
+                    Self::Device(s, _) => s.index_mut(index),
                     Self::Host(v) => v.index_mut(index),
                 }
             }
@@ -239,7 +252,8 @@ impl_index! {
 impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
     fn drop(&mut self) {
         match self {
-            Self::Device(s) => {
+            Self::Device(s, device_id) => {
+                check_device(*device_id);
                 if s.is_empty() {
                     return;
                 }
@@ -256,4 +270,4 @@ impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
 }
 
 #[allow(non_camel_case_types)]
-pub type CudaMemPool = usize; // This is a placeholder, TODO: actually make this into a proper CUDA wrapper
+pub type CudaMemPool = cudaMemPool_t;
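Because `Device` slices now record the device they were allocated on, a host-device round trip looks like this (a sketch; assumes the default device 0 is current, so `get_device_id` returns `Some(0)`):

```rust
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn round_trip() {
    let data = vec![1u32, 2, 3, 4];
    let mut d_buf: HostOrDeviceSlice<'_, u32> = HostOrDeviceSlice::cuda_malloc(data.len()).unwrap();
    d_buf.copy_from_host(&data).unwrap();
    assert_eq!(d_buf.get_device_id(), Some(0)); // id captured at allocation time
    let mut back = vec![0u32; data.len()];
    d_buf.copy_to_host(&mut back).unwrap();
    assert_eq!(data, back);
}
```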
diff --git a/wrappers/rust/icicle-curves/icicle-bls12-377/src/curve.rs b/wrappers/rust/icicle-curves/icicle-bls12-377/src/curve.rs
index 3a3779d4..afbc72aa 100644
--- a/wrappers/rust/icicle-curves/icicle-bls12-377/src/curve.rs
+++ b/wrappers/rust/icicle-curves/icicle-bls12-377/src/curve.rs
@@ -6,7 +6,7 @@
 use icicle_core::curve::{Affine, Curve, Projective};
 use icicle_core::field::{Field, MontgomeryConvertibleField};
 use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
 use icicle_core::{impl_curve, impl_field, impl_scalar_field};
-use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
diff --git a/wrappers/rust/icicle-curves/icicle-bls12-377/src/ntt/mod.rs b/wrappers/rust/icicle-curves/icicle-bls12-377/src/ntt/mod.rs
index 543f9fb4..bf8dfbd3 100644
--- a/wrappers/rust/icicle-curves/icicle-bls12-377/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-curves/icicle-bls12-377/src/ntt/mod.rs
@@ -7,6 +7,7 @@
 use icicle_core::impl_ntt;
 use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
 use icicle_core::traits::IcicleResultWrap;
 use icicle_cuda_runtime::device_context::DeviceContext;
+use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -17,6 +18,7 @@ impl_ntt!("bw6_761", bw6_761, BaseField, BaseCfg);
 #[cfg(test)]
 pub(crate) mod tests {
     use crate::curve::ScalarField;
+    use crate::ntt::DEFAULT_DEVICE_ID;
     use icicle_core::impl_ntt_tests;
     use icicle_core::ntt::tests::*;
     use std::sync::OnceLock;
diff --git a/wrappers/rust/icicle-curves/icicle-bls12-381/src/curve.rs b/wrappers/rust/icicle-curves/icicle-bls12-381/src/curve.rs
index d4a6c2f2..c227b631 100644
--- a/wrappers/rust/icicle-curves/icicle-bls12-381/src/curve.rs
+++ b/wrappers/rust/icicle-curves/icicle-bls12-381/src/curve.rs
@@ -6,7 +6,7 @@
 use icicle_core::curve::{Affine, Curve, Projective};
 use icicle_core::field::{Field, MontgomeryConvertibleField};
 use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
 use icicle_core::{impl_curve, impl_field, impl_scalar_field};
-use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
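The curve crates now import only `DeviceContext`. Their call sites are not shown in this diff, but a hypothetical site that previously used the removed free function would migrate roughly like this:

```rust
use icicle_cuda_runtime::device_context::DeviceContext;

// Before: let ctx = get_default_device_context();
// After (hypothetical call site, for illustration only):
fn default_ctx<'a>() -> DeviceContext<'a> {
    DeviceContext::default()
}
```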
diff --git a/wrappers/rust/icicle-curves/icicle-bls12-381/src/ntt/mod.rs b/wrappers/rust/icicle-curves/icicle-bls12-381/src/ntt/mod.rs
index 755e7834..9e2e28c6 100644
--- a/wrappers/rust/icicle-curves/icicle-bls12-381/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-curves/icicle-bls12-381/src/ntt/mod.rs
@@ -5,6 +5,7 @@
 use icicle_core::impl_ntt;
 use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
 use icicle_core::traits::IcicleResultWrap;
 use icicle_cuda_runtime::device_context::DeviceContext;
+use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -13,6 +14,7 @@ impl_ntt!("bls12_381", bls12_381, ScalarField, ScalarCfg);
 #[cfg(test)]
 pub(crate) mod tests {
     use crate::curve::ScalarField;
+    use crate::ntt::DEFAULT_DEVICE_ID;
     use icicle_core::impl_ntt_tests;
     use icicle_core::ntt::tests::*;
     use std::sync::OnceLock;
diff --git a/wrappers/rust/icicle-curves/icicle-bn254/src/curve.rs b/wrappers/rust/icicle-curves/icicle-bn254/src/curve.rs
index 2e76c239..b53264fb 100644
--- a/wrappers/rust/icicle-curves/icicle-bn254/src/curve.rs
+++ b/wrappers/rust/icicle-curves/icicle-bn254/src/curve.rs
@@ -6,7 +6,7 @@
 use icicle_core::curve::{Affine, Curve, Projective};
 use icicle_core::field::{Field, MontgomeryConvertibleField};
 use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
 use icicle_core::{impl_curve, impl_field, impl_scalar_field};
-use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
diff --git a/wrappers/rust/icicle-curves/icicle-bn254/src/ntt/mod.rs b/wrappers/rust/icicle-curves/icicle-bn254/src/ntt/mod.rs
index 639c4ce1..f91deaf6 100644
--- a/wrappers/rust/icicle-curves/icicle-bn254/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-curves/icicle-bn254/src/ntt/mod.rs
@@ -5,6 +5,7 @@
 use icicle_core::impl_ntt;
 use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
 use icicle_core::traits::IcicleResultWrap;
 use icicle_cuda_runtime::device_context::DeviceContext;
+use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -13,6 +14,7 @@ impl_ntt!("bn254", bn254, ScalarField, ScalarCfg);
 #[cfg(test)]
 pub(crate) mod tests {
     use crate::curve::ScalarField;
+    use crate::ntt::DEFAULT_DEVICE_ID;
     use icicle_core::impl_ntt_tests;
     use icicle_core::ntt::tests::*;
     use std::sync::OnceLock;
diff --git a/wrappers/rust/icicle-curves/icicle-bw6-761/src/curve.rs b/wrappers/rust/icicle-curves/icicle-bw6-761/src/curve.rs
index a9eb864e..0500f1a9 100644
--- a/wrappers/rust/icicle-curves/icicle-bw6-761/src/curve.rs
+++ b/wrappers/rust/icicle-curves/icicle-bw6-761/src/curve.rs
@@ -7,7 +7,7 @@
 use icicle_core::curve::{Affine, Curve, Projective};
 use icicle_core::field::Field;
 use icicle_core::traits::FieldConfig;
 use icicle_core::{impl_curve, impl_field};
-use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
+use icicle_cuda_runtime::device_context::DeviceContext;
 use icicle_cuda_runtime::error::CudaError;
 use icicle_cuda_runtime::memory::HostOrDeviceSlice;
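The NTT test modules now pull in `DEFAULT_DEVICE_ID`, presumably so the shared test macros can build per-device contexts. The `impl_ntt_tests!` body is not part of this diff, so the following is an illustrative sketch of the pattern only:

```rust
use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};

#[test]
fn context_targets_default_device() {
    let ctx = DeviceContext::default_for_device(DEFAULT_DEVICE_ID);
    assert_eq!(ctx.device_id, DEFAULT_DEVICE_ID);
    // ...build an NTTConfig around `ctx` and run the transform...
}
```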
diff --git a/wrappers/rust/icicle-curves/icicle-bw6-761/src/ntt/mod.rs b/wrappers/rust/icicle-curves/icicle-bw6-761/src/ntt/mod.rs
index 7e3867e8..fb673824 100644
--- a/wrappers/rust/icicle-curves/icicle-bw6-761/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-curves/icicle-bw6-761/src/ntt/mod.rs
@@ -3,6 +3,7 @@ pub(crate) mod tests {
     use crate::curve::ScalarField;
     use icicle_core::impl_ntt_tests;
     use icicle_core::ntt::tests::*;
+    use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
     use std::sync::OnceLock;
 
     impl_ntt_tests!(ScalarField);
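Finally, `check_device` is public, so the same guard that now protects copies and drops can be invoked directly, e.g. before handing a device slice to code that may have switched the active GPU. A sketch using only the APIs added above:

```rust
use icicle_cuda_runtime::device_context::check_device;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn assert_usable_here<T>(slice: &HostOrDeviceSlice<'_, T>) {
    if let Some(device_id) = slice.get_device_id() {
        // Panics with "Attempt to use on a different device" on a mismatch.
        check_device(device_id);
    }
}
```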