mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-10 07:57:56 -05:00
@@ -1,3 +1,5 @@
|
||||
inout
|
||||
crate
|
||||
lmit
|
||||
mut
|
||||
uint
|
||||
|
||||
@@ -32,7 +32,6 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
|
||||
|
||||
> [!NOTE]
|
||||
> Developers: We highly recommend reading our [documentation]
|
||||
|
||||
> [!TIP]
|
||||
> Try out ICICLE by running some [examples] using ICICLE in C++ and our Rust bindings
|
||||
|
||||
@@ -43,8 +42,8 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
|
||||
- [GCC](https://gcc.gnu.org/install/download.html) version 9, latest version is recommended.
|
||||
- Any Nvidia GPU (which supports CUDA Toolkit version 12.0 or above).
|
||||
|
||||
> [!NOTE]
|
||||
> It is possible to use CUDA 11 for cards which dont support CUDA 12, however we dont officially support this version and in the future there may be issues.
|
||||
> [!NOTE]
|
||||
> It is possible to use CUDA 11 for cards which don't support CUDA 12, however we don't officially support this version and in the future there may be issues.
|
||||
|
||||
### Accessing Hardware
|
||||
|
||||
|
||||
@@ -1,43 +1,20 @@
|
||||
use icicle_bn254::curve::{
|
||||
CurveCfg,
|
||||
ScalarCfg,
|
||||
G1Projective,
|
||||
G2CurveCfg,
|
||||
G2Projective
|
||||
};
|
||||
use icicle_bn254::curve::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg};
|
||||
|
||||
use icicle_bls12_377::curve::{
|
||||
CurveCfg as BLS12377CurveCfg,
|
||||
ScalarCfg as BLS12377ScalarCfg,
|
||||
G1Projective as BLS12377G1Projective
|
||||
CurveCfg as BLS12377CurveCfg, G1Projective as BLS12377G1Projective, ScalarCfg as BLS12377ScalarCfg,
|
||||
};
|
||||
|
||||
use icicle_cuda_runtime::{
|
||||
stream::CudaStream,
|
||||
memory::HostOrDeviceSlice
|
||||
};
|
||||
use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream};
|
||||
|
||||
use icicle_core::{
|
||||
msm,
|
||||
curve::Curve,
|
||||
traits::GenerateRandom
|
||||
};
|
||||
use icicle_core::{curve::Curve, msm, traits::GenerateRandom};
|
||||
|
||||
#[cfg(feature = "arkworks")]
|
||||
use icicle_core::traits::ArkConvertible;
|
||||
|
||||
#[cfg(feature = "arkworks")]
|
||||
use ark_bn254::{
|
||||
G1Projective as Bn254ArkG1Projective,
|
||||
G1Affine as Bn254G1Affine,
|
||||
Fr as Bn254Fr
|
||||
};
|
||||
use ark_bls12_377::{Fr as Bls12377Fr, G1Affine as Bls12377G1Affine, G1Projective as Bls12377ArkG1Projective};
|
||||
#[cfg(feature = "arkworks")]
|
||||
use ark_bls12_377::{
|
||||
G1Projective as Bls12377ArkG1Projective,
|
||||
G1Affine as Bls12377G1Affine,
|
||||
Fr as Bls12377Fr
|
||||
};
|
||||
use ark_bn254::{Fr as Bn254Fr, G1Affine as Bn254G1Affine, G1Projective as Bn254ArkG1Projective};
|
||||
#[cfg(feature = "arkworks")]
|
||||
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
|
||||
|
||||
@@ -67,23 +44,26 @@ fn main() {
|
||||
let upper_points = CurveCfg::generate_random_affine_points(upper_size);
|
||||
let g2_upper_points = G2CurveCfg::generate_random_affine_points(upper_size);
|
||||
let upper_scalars = ScalarCfg::generate_random(upper_size);
|
||||
|
||||
|
||||
println!("Generating random inputs on host for bls12377...");
|
||||
let upper_points_bls12377 = BLS12377CurveCfg::generate_random_affine_points(upper_size);
|
||||
let upper_scalars_bls12377 = BLS12377ScalarCfg::generate_random(upper_size);
|
||||
|
||||
for i in lower_bound..=upper_bound {
|
||||
for i in lower_bound..=upper_bound {
|
||||
let log_size = i;
|
||||
let size = 1 << log_size;
|
||||
println!("---------------------- MSM size 2^{}={} ------------------------", log_size, size);
|
||||
println!(
|
||||
"---------------------- MSM size 2^{}={} ------------------------",
|
||||
log_size, size
|
||||
);
|
||||
// Setting Bn254 points and scalars
|
||||
let points = HostOrDeviceSlice::Host(upper_points[..size].to_vec());
|
||||
let g2_points = HostOrDeviceSlice::Host(g2_upper_points[..size].to_vec());
|
||||
let scalars = HostOrDeviceSlice::Host(upper_scalars[..size].to_vec());
|
||||
|
||||
|
||||
// Setting bls12377 points and scalars
|
||||
// let points_bls12377 = &upper_points_bls12377[..size];
|
||||
let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
|
||||
let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
|
||||
let scalars_bls12377 = HostOrDeviceSlice::Host(upper_scalars_bls12377[..size].to_vec());
|
||||
|
||||
println!("Configuring bn254 MSM...");
|
||||
@@ -91,18 +71,24 @@ fn main() {
|
||||
let mut g2_msm_results: HostOrDeviceSlice<'_, G2Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
|
||||
let stream = CudaStream::create().unwrap();
|
||||
let g2_stream = CudaStream::create().unwrap();
|
||||
let mut cfg = msm::get_default_msm_config::<CurveCfg>();
|
||||
let mut g2_cfg = msm::get_default_msm_config::<G2CurveCfg>();
|
||||
cfg.ctx.stream = &stream;
|
||||
g2_cfg.ctx.stream = &g2_stream;
|
||||
let mut cfg = msm::MSMConfig::default();
|
||||
let mut g2_cfg = msm::MSMConfig::default();
|
||||
cfg.ctx
|
||||
.stream = &stream;
|
||||
g2_cfg
|
||||
.ctx
|
||||
.stream = &g2_stream;
|
||||
cfg.is_async = true;
|
||||
g2_cfg.is_async = true;
|
||||
|
||||
println!("Configuring bls12377 MSM...");
|
||||
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
|
||||
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> =
|
||||
HostOrDeviceSlice::cuda_malloc(1).unwrap();
|
||||
let stream_bls12377 = CudaStream::create().unwrap();
|
||||
let mut cfg_bls12377 = msm::get_default_msm_config::<BLS12377CurveCfg>();
|
||||
cfg_bls12377.ctx.stream = &stream_bls12377;
|
||||
let mut cfg_bls12377 = msm::MSMConfig::default();
|
||||
cfg_bls12377
|
||||
.ctx
|
||||
.stream = &stream_bls12377;
|
||||
cfg_bls12377.is_async = true;
|
||||
|
||||
println!("Executing bn254 MSM on device...");
|
||||
@@ -110,22 +96,37 @@ fn main() {
|
||||
let start = Instant::now();
|
||||
msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap();
|
||||
#[cfg(feature = "profile")]
|
||||
println!("ICICLE BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"ICICLE BN254 MSM on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
msm::msm(&scalars, &g2_points, &g2_cfg, &mut g2_msm_results).unwrap();
|
||||
|
||||
|
||||
println!("Executing bls12377 MSM on device...");
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
msm::msm(&scalars_bls12377, &points_bls12377, &cfg_bls12377, &mut msm_results_bls12377 ).unwrap();
|
||||
msm::msm(
|
||||
&scalars_bls12377,
|
||||
&points_bls12377,
|
||||
&cfg_bls12377,
|
||||
&mut msm_results_bls12377,
|
||||
)
|
||||
.unwrap();
|
||||
#[cfg(feature = "profile")]
|
||||
println!("ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
|
||||
println!("Moving results to host..");
|
||||
let mut msm_host_result = vec![G1Projective::zero(); 1];
|
||||
let mut g2_msm_host_result = vec![G2Projective::zero(); 1];
|
||||
let mut msm_host_result_bls12377 = vec![BLS12377G1Projective::zero(); 1];
|
||||
|
||||
|
||||
stream
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
@@ -140,7 +141,7 @@ fn main() {
|
||||
.unwrap();
|
||||
println!("bn254 result: {:#?}", msm_host_result);
|
||||
println!("G2 bn254 result: {:#?}", g2_msm_host_result);
|
||||
|
||||
|
||||
stream_bls12377
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
@@ -148,37 +149,70 @@ fn main() {
|
||||
.copy_to_host(&mut msm_host_result_bls12377[..])
|
||||
.unwrap();
|
||||
println!("bls12377 result: {:#?}", msm_host_result_bls12377);
|
||||
|
||||
|
||||
#[cfg(feature = "arkworks")]
|
||||
{
|
||||
println!("Checking against arkworks...");
|
||||
let ark_points: Vec<Bn254G1Affine> = points.as_slice().iter().map(|&point| point.to_ark()).collect();
|
||||
let ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
|
||||
let ark_points: Vec<Bn254G1Affine> = points
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|&point| point.to_ark())
|
||||
.collect();
|
||||
let ark_scalars: Vec<Bn254Fr> = scalars
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|scalar| scalar.to_ark())
|
||||
.collect();
|
||||
|
||||
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377.as_slice().iter().map(|point| point.to_ark()).collect();
|
||||
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
|
||||
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|point| point.to_ark())
|
||||
.collect();
|
||||
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|scalar| scalar.to_ark())
|
||||
.collect();
|
||||
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
let bn254_ark_msm_res = Bn254ArkG1Projective::msm(&ark_points, &ark_scalars).unwrap();
|
||||
println!("Arkworks Bn254 result: {:#?}", bn254_ark_msm_res);
|
||||
#[cfg(feature = "profile")]
|
||||
println!("Ark BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"Ark BN254 MSM on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
let bls12377_ark_msm_res = Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
|
||||
let bls12377_ark_msm_res =
|
||||
Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
|
||||
println!("Arkworks Bls12377 result: {:#?}", bls12377_ark_msm_res);
|
||||
#[cfg(feature = "profile")]
|
||||
println!("Ark BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"Ark BLS12377 MSM on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
|
||||
let bn254_icicle_msm_res_as_ark = msm_host_result[0].to_ark();
|
||||
let bls12377_icicle_msm_res_as_ark = msm_host_result_bls12377[0].to_ark();
|
||||
|
||||
println!("Bn254 MSM is correct: {}", bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark));
|
||||
println!("Bls12377 MSM is correct: {}", bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark));
|
||||
println!(
|
||||
"Bn254 MSM is correct: {}",
|
||||
bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark)
|
||||
);
|
||||
println!(
|
||||
"Bls12377 MSM is correct: {}",
|
||||
bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
println!("Cleaning up bn254...");
|
||||
stream
|
||||
.destroy()
|
||||
|
||||
@@ -1,28 +1,18 @@
|
||||
use icicle_bn254::curve::{
|
||||
ScalarCfg,
|
||||
ScalarField,
|
||||
};
|
||||
use icicle_bn254::curve::{ScalarCfg, ScalarField};
|
||||
|
||||
use icicle_bls12_377::curve::{
|
||||
ScalarCfg as BLS12377ScalarCfg,
|
||||
ScalarField as BLS12377ScalarField
|
||||
};
|
||||
use icicle_bls12_377::curve::{ScalarCfg as BLS12377ScalarCfg, ScalarField as BLS12377ScalarField};
|
||||
|
||||
use icicle_cuda_runtime::{
|
||||
stream::CudaStream,
|
||||
memory::HostOrDeviceSlice,
|
||||
device_context::get_default_device_context
|
||||
};
|
||||
use icicle_cuda_runtime::{device_context::DeviceContext, memory::HostOrDeviceSlice, stream::CudaStream};
|
||||
|
||||
use icicle_core::{
|
||||
ntt::{self, NTT},
|
||||
traits::{GenerateRandom, FieldImpl}
|
||||
traits::{FieldImpl, GenerateRandom},
|
||||
};
|
||||
|
||||
use icicle_core::traits::ArkConvertible;
|
||||
|
||||
use ark_bn254::Fr as Bn254Fr;
|
||||
use ark_bls12_377::Fr as Bls12377Fr;
|
||||
use ark_bn254::Fr as Bn254Fr;
|
||||
use ark_ff::FftField;
|
||||
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
|
||||
use ark_std::cmp::{Ord, Ordering};
|
||||
@@ -45,37 +35,52 @@ fn main() {
|
||||
println!("Running Icicle Examples: Rust NTT");
|
||||
let log_size = args.size;
|
||||
let size = 1 << log_size;
|
||||
println!("---------------------- NTT size 2^{}={} ------------------------", log_size, size);
|
||||
println!(
|
||||
"---------------------- NTT size 2^{}={} ------------------------",
|
||||
log_size, size
|
||||
);
|
||||
// Setting Bn254 points and scalars
|
||||
println!("Generating random inputs on host for bn254...");
|
||||
let scalars = HostOrDeviceSlice::Host(ScalarCfg::generate_random(size));
|
||||
let mut ntt_results: HostOrDeviceSlice<'_, ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
|
||||
|
||||
|
||||
// Setting bls12377 points and scalars
|
||||
println!("Generating random inputs on host for bls12377...");
|
||||
let scalars_bls12377 = HostOrDeviceSlice::Host(BLS12377ScalarCfg::generate_random(size));
|
||||
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
|
||||
|
||||
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> =
|
||||
HostOrDeviceSlice::cuda_malloc(size).unwrap();
|
||||
|
||||
println!("Setting up bn254 Domain...");
|
||||
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
|
||||
let ctx = get_default_device_context();
|
||||
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(
|
||||
size.try_into()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let ctx = DeviceContext::default();
|
||||
ScalarCfg::initialize_domain(ScalarField::from_ark(icicle_omega), &ctx).unwrap();
|
||||
|
||||
println!("Configuring bn254 NTT...");
|
||||
let stream = CudaStream::create().unwrap();
|
||||
let mut cfg = ntt::get_default_ntt_config::<ScalarField>();
|
||||
cfg.ctx.stream = &stream;
|
||||
let mut cfg = ntt::NTTConfig::default();
|
||||
cfg.ctx
|
||||
.stream = &stream;
|
||||
cfg.is_async = true;
|
||||
|
||||
println!("Setting up bls12377 Domain...");
|
||||
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
|
||||
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(
|
||||
size.try_into()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
// reusing ctx from above
|
||||
BLS12377ScalarCfg::initialize_domain(BLS12377ScalarField::from_ark(icicle_omega), &ctx).unwrap();
|
||||
|
||||
println!("Configuring bls12377 NTT...");
|
||||
let stream_bls12377 = CudaStream::create().unwrap();
|
||||
let mut cfg_bls12377 = ntt::get_default_ntt_config::<BLS12377ScalarField>();
|
||||
cfg_bls12377.ctx.stream = &stream_bls12377;
|
||||
let mut cfg_bls12377 = ntt::NTTConfig::default();
|
||||
cfg_bls12377
|
||||
.ctx
|
||||
.stream = &stream_bls12377;
|
||||
cfg_bls12377.is_async = true;
|
||||
|
||||
println!("Executing bn254 NTT on device...");
|
||||
@@ -83,14 +88,30 @@ fn main() {
|
||||
let start = Instant::now();
|
||||
ntt::ntt(&scalars, ntt::NTTDir::kForward, &cfg, &mut ntt_results).unwrap();
|
||||
#[cfg(feature = "profile")]
|
||||
println!("ICICLE BN254 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
|
||||
println!(
|
||||
"ICICLE BN254 NTT on size 2^{log_size} took: {} μs",
|
||||
start
|
||||
.elapsed()
|
||||
.as_micros()
|
||||
);
|
||||
|
||||
println!("Executing bls12377 NTT on device...");
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
ntt::ntt(&scalars_bls12377, ntt::NTTDir::kForward, &cfg_bls12377, &mut ntt_results_bls12377).unwrap();
|
||||
ntt::ntt(
|
||||
&scalars_bls12377,
|
||||
ntt::NTTDir::kForward,
|
||||
&cfg_bls12377,
|
||||
&mut ntt_results_bls12377,
|
||||
)
|
||||
.unwrap();
|
||||
#[cfg(feature = "profile")]
|
||||
println!("ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
|
||||
println!(
|
||||
"ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs",
|
||||
start
|
||||
.elapsed()
|
||||
.as_micros()
|
||||
);
|
||||
|
||||
println!("Moving results to host..");
|
||||
stream
|
||||
@@ -100,7 +121,7 @@ fn main() {
|
||||
ntt_results
|
||||
.copy_to_host(&mut host_bn254_results[..])
|
||||
.unwrap();
|
||||
|
||||
|
||||
stream_bls12377
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
@@ -108,25 +129,43 @@ fn main() {
|
||||
ntt_results_bls12377
|
||||
.copy_to_host(&mut host_bls12377_results[..])
|
||||
.unwrap();
|
||||
|
||||
|
||||
println!("Checking against arkworks...");
|
||||
let mut ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
|
||||
let mut ark_scalars: Vec<Bn254Fr> = scalars
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|scalar| scalar.to_ark())
|
||||
.collect();
|
||||
let bn254_domain = <Radix2EvaluationDomain<Bn254Fr> as EvaluationDomain<Bn254Fr>>::new(size).unwrap();
|
||||
|
||||
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
|
||||
|
||||
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|scalar| scalar.to_ark())
|
||||
.collect();
|
||||
let bls12_377_domain = <Radix2EvaluationDomain<Bls12377Fr> as EvaluationDomain<Bls12377Fr>>::new(size).unwrap();
|
||||
|
||||
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
bn254_domain.fft_in_place(&mut ark_scalars);
|
||||
#[cfg(feature = "profile")]
|
||||
println!("Ark BN254 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"Ark BN254 NTT on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
|
||||
#[cfg(feature = "profile")]
|
||||
let start = Instant::now();
|
||||
bls12_377_domain.fft_in_place(&mut ark_scalars_bls12377);
|
||||
#[cfg(feature = "profile")]
|
||||
println!("Ark BLS12377 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
|
||||
println!(
|
||||
"Ark BLS12377 NTT on size 2^{log_size} took: {} ms",
|
||||
start
|
||||
.elapsed()
|
||||
.as_millis()
|
||||
);
|
||||
|
||||
host_bn254_results
|
||||
.iter()
|
||||
@@ -135,7 +174,7 @@ fn main() {
|
||||
assert_eq!(ark_scalar.cmp(&icicle_scalar.to_ark()), Ordering::Equal);
|
||||
});
|
||||
println!("Bn254 NTT is correct");
|
||||
|
||||
|
||||
host_bls12377_results
|
||||
.iter()
|
||||
.zip(ark_scalars_bls12377.iter())
|
||||
@@ -144,7 +183,7 @@ fn main() {
|
||||
});
|
||||
|
||||
println!("Bls12377 NTT is correct");
|
||||
|
||||
|
||||
println!("Cleaning up bn254...");
|
||||
stream
|
||||
.destroy()
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "primitives/affine.cuh"
|
||||
#include "primitives/field.cuh"
|
||||
#include "primitives/projective.cuh"
|
||||
#include "utils/cuda_utils.cuh"
|
||||
#include "utils/error_handler.cuh"
|
||||
#include "utils/mont.cuh"
|
||||
#include "utils/utils.h"
|
||||
@@ -240,7 +239,7 @@ namespace msm {
|
||||
unsigned large_bucket_index = tid / threads_per_bucket;
|
||||
unsigned bucket_segment_index = tid % threads_per_bucket;
|
||||
if (tid >= nof_buckets_to_compute * threads_per_bucket) { return; }
|
||||
if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // dont need
|
||||
if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // don't need
|
||||
return; // skip zero buckets
|
||||
}
|
||||
unsigned write_bucket_index = bucket_segment_index * nof_buckets_to_compute + large_bucket_index;
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "../../primitives/affine.cuh"
|
||||
#include "../../primitives/field.cuh"
|
||||
#include "../../primitives/projective.cuh"
|
||||
#include "../../utils/cuda_utils.cuh"
|
||||
#include "../../utils/device_context.cuh"
|
||||
#include "../../utils/error_handler.cuh"
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "utils/utils.h"
|
||||
#include "appUtils/ntt/ntt_impl.cuh"
|
||||
|
||||
#include <mutex>
|
||||
|
||||
namespace ntt {
|
||||
|
||||
namespace {
|
||||
@@ -361,72 +363,88 @@ namespace ntt {
|
||||
template <typename S>
|
||||
class Domain
|
||||
{
|
||||
static inline int max_size = 0;
|
||||
static inline int max_log_size = 0;
|
||||
static inline S* twiddles = nullptr;
|
||||
static inline std::unordered_map<S, int> coset_index = {};
|
||||
// Mutex for protecting access to the domain/device container array
|
||||
static inline std::mutex device_domain_mutex;
|
||||
// The domain-per-device container - assumption is InitDomain is called once per device per program.
|
||||
|
||||
static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
|
||||
static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT
|
||||
int max_size = 0;
|
||||
int max_log_size = 0;
|
||||
S* twiddles = nullptr;
|
||||
std::unordered_map<S, int> coset_index = {};
|
||||
|
||||
S* internal_twiddles = nullptr; // required by mixed-radix NTT
|
||||
S* basic_twiddles = nullptr; // required by mixed-radix NTT
|
||||
|
||||
public:
|
||||
template <typename U>
|
||||
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
|
||||
|
||||
static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
|
||||
template <typename U, typename E>
|
||||
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
|
||||
};
|
||||
|
||||
template <typename S>
|
||||
static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
|
||||
|
||||
template <typename S>
|
||||
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
|
||||
|
||||
// only generate twiddles if they haven't been generated yet
|
||||
// please note that this is not thread-safe at all,
|
||||
// but it's a singleton that is supposed to be initialized once per program lifetime
|
||||
if (!Domain<S>::twiddles) {
|
||||
// please note that this offers just basic thread-safety,
|
||||
// it's assumed a singleton (non-enforced) that is supposed
|
||||
// to be initialized once per device per program lifetime
|
||||
if (!domain.twiddles) {
|
||||
// Mutex is automatically released when lock goes out of scope, even in case of exceptions
|
||||
std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
|
||||
// double check locking
|
||||
if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
|
||||
|
||||
bool found_logn = false;
|
||||
S omega = primitive_root;
|
||||
unsigned omegas_count = S::get_omegas_count();
|
||||
for (int i = 0; i < omegas_count; i++) {
|
||||
omega = S::sqr(omega);
|
||||
if (!found_logn) {
|
||||
++Domain<S>::max_log_size;
|
||||
++domain.max_log_size;
|
||||
found_logn = omega == S::one();
|
||||
if (found_logn) break;
|
||||
}
|
||||
}
|
||||
Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
|
||||
|
||||
domain.max_size = (int)pow(2, domain.max_log_size);
|
||||
if (omega != S::one()) {
|
||||
throw IcicleError(
|
||||
THROW_ICICLE_ERR(
|
||||
IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
|
||||
}
|
||||
|
||||
// allocate and calculate twiddles on GPU
|
||||
// Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
|
||||
// Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
|
||||
CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
|
||||
CHK_IF_RETURN(cudaMallocManaged(&domain.twiddles, (domain.max_size + 1) * sizeof(S)));
|
||||
CHK_IF_RETURN(generate_external_twiddles_generic(
|
||||
primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
|
||||
Domain<S>::max_log_size, ctx.stream));
|
||||
primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
|
||||
ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
|
||||
|
||||
const bool is_map_only_powers_of_primitive_root = true;
|
||||
if (is_map_only_powers_of_primitive_root) {
|
||||
// populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
|
||||
// etc.)
|
||||
Domain<S>::coset_index[S::one()] = 0;
|
||||
for (int i = 0; i < Domain<S>::max_log_size; ++i) {
|
||||
domain.coset_index[S::one()] = 0;
|
||||
for (int i = 0; i < domain.max_log_size; ++i) {
|
||||
const int index = (int)pow(2, i);
|
||||
Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
|
||||
domain.coset_index[domain.twiddles[index]] = index;
|
||||
}
|
||||
} else {
|
||||
// populate all values
|
||||
for (int i = 0; i < Domain<S>::max_size; ++i) {
|
||||
Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
|
||||
for (int i = 0; i < domain.max_size; ++i) {
|
||||
domain.coset_index[domain.twiddles[i]] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -527,7 +545,9 @@ namespace ntt {
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
if (size > Domain<S>::max_size) {
|
||||
Domain<S>& domain = domains_for_devices<S>[config.ctx.device_id];
|
||||
|
||||
if (size > domain.max_size) {
|
||||
std::ostringstream oss;
|
||||
oss << "NTT size=" << size
|
||||
<< " is too large for the domain. Consider generating your domain with a higher order root of unity.\n";
|
||||
@@ -565,7 +585,7 @@ namespace ntt {
|
||||
S* coset = nullptr;
|
||||
int coset_index = 0;
|
||||
try {
|
||||
coset_index = Domain<S>::coset_index.at(config.coset_gen);
|
||||
coset_index = domain.coset_index.at(config.coset_gen);
|
||||
} catch (...) {
|
||||
// if coset index is not found in the subgroup, compute coset powers on CPU and move them to device
|
||||
std::vector<S> h_coset;
|
||||
@@ -584,12 +604,12 @@ namespace ntt {
|
||||
|
||||
if (is_radix2_algorithm) {
|
||||
CHK_IF_RETURN(ntt::radix2_ntt(
|
||||
d_input, d_output, Domain<S>::twiddles, size, Domain<S>::max_size, batch_size, is_inverse, config.ordering,
|
||||
coset, coset_index, stream));
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
|
||||
Domain<S>::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream));
|
||||
d_input, d_output, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, size, domain.max_log_size,
|
||||
batch_size, is_inverse, config.ordering, coset, coset_index, stream));
|
||||
}
|
||||
|
||||
if (!are_outputs_on_device)
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
struct cuda_ctx {
|
||||
int device_id;
|
||||
cudaMemPool_t mempool;
|
||||
cudaStream_t stream;
|
||||
|
||||
cuda_ctx(int gpu_id)
|
||||
{
|
||||
gpu_id = gpu_id;
|
||||
cudaMemPoolProps pool_props;
|
||||
pool_props.allocType = cudaMemAllocationTypePinned;
|
||||
pool_props.handleTypes = cudaMemHandleTypePosixFileDescriptor;
|
||||
pool_props.location.type = cudaMemLocationTypeDevice;
|
||||
pool_props.location.id = device_id;
|
||||
|
||||
cudaMemPoolCreate(&mempool, &pool_props);
|
||||
cudaStreamCreate(&stream);
|
||||
}
|
||||
|
||||
void set_device() { cudaSetDevice(device_id); }
|
||||
|
||||
void sync_stream() { cudaStreamSynchronize(stream); }
|
||||
|
||||
void malloc(void* ptr, size_t bytesize) { cudaMallocFromPoolAsync(&ptr, bytesize, mempool, stream); }
|
||||
|
||||
void free(void* ptr) { cudaFreeAsync(ptr, stream); }
|
||||
};
|
||||
|
||||
// -- Proposed Function Tops --------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
namespace device_context {
|
||||
|
||||
constexpr std::size_t MAX_DEVICES = 32;
|
||||
|
||||
/**
|
||||
* Properties of the device used in icicle functions.
|
||||
*/
|
||||
@@ -18,7 +20,7 @@ namespace device_context {
|
||||
/**
|
||||
* Return default device context that corresponds to using the default stream of the first GPU
|
||||
*/
|
||||
inline DeviceContext get_default_device_context()
|
||||
inline DeviceContext get_default_device_context() // TODO: naming convention ?
|
||||
{
|
||||
static cudaStream_t default_stream = (cudaStream_t)0;
|
||||
return DeviceContext{
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// TODO: remove this file, seems working without it
|
||||
// based on https://leimao.github.io/blog/CUDA-Shared-Memory-Templated-Kernel/
|
||||
// may be outdated, but only worked like that
|
||||
|
||||
|
||||
@@ -2,21 +2,23 @@
|
||||
name = "icicle-core"
|
||||
version = "1.2.0"
|
||||
edition = "2021"
|
||||
authors = [ "Ingonyama" ]
|
||||
authors = ["Ingonyama"]
|
||||
description = "A library for GPU ZK acceleration by Ingonyama"
|
||||
homepage = "https://www.ingonyama.com"
|
||||
repository = "https://github.com/ingonyama-zk/icicle"
|
||||
|
||||
|
||||
[dependencies]
|
||||
[dependencies]
|
||||
|
||||
icicle-cuda-runtime = { path = "../icicle-cuda-runtime" }
|
||||
|
||||
ark-ff = { version = "0.4.0", optional = true }
|
||||
ark-ec = { version = "0.4.0", optional = true, features = [ "parallel" ] }
|
||||
ark-ec = { version = "0.4.0", optional = true, features = ["parallel"] }
|
||||
ark-poly = { version = "0.4.0", optional = true }
|
||||
ark-std = { version = "0.4.0", optional = true }
|
||||
|
||||
rayon = "1.8.1"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
arkworks = ["ark-ff", "ark-ec", "ark-poly", "ark-std"]
|
||||
|
||||
@@ -284,7 +284,7 @@ macro_rules! impl_curve {
|
||||
points.as_mut_ptr(),
|
||||
points.len(),
|
||||
is_into,
|
||||
&get_default_device_context() as *const _ as *const DeviceContext,
|
||||
&DeviceContext::default() as *const _ as *const DeviceContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -298,7 +298,7 @@ macro_rules! impl_curve {
|
||||
points.as_mut_ptr(),
|
||||
points.len(),
|
||||
is_into,
|
||||
&get_default_device_context() as *const _ as *const DeviceContext,
|
||||
&DeviceContext::default() as *const _ as *const DeviceContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@ pub struct Field<const NUM_LIMBS: usize, F: FieldConfig> {
|
||||
p: PhantomData<F>,
|
||||
}
|
||||
|
||||
// SAFETY: NOTE(review) - unverified from this view. This impl asserts that `Field` may be
// moved to another thread for every `F: FieldConfig`. Only the `p: PhantomData<F>` member is
// visible here; confirm the remaining members are plain owned data (no raw pointers).
unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Send for Field<NUM_LIMBS, F> {}
|
||||
// SAFETY: NOTE(review) - unverified from this view. This impl asserts that `&Field` may be
// shared between threads for every `F: FieldConfig`; confirm there is no interior mutability.
unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Sync for Field<NUM_LIMBS, F> {}
|
||||
|
||||
impl<const NUM_LIMBS: usize, F: FieldConfig> Display for Field<NUM_LIMBS, F> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "0x")?;
|
||||
@@ -165,7 +168,7 @@ macro_rules! impl_scalar_field {
|
||||
impl_field!($num_limbs, $field_name, $field_cfg, $ark_equiv);
|
||||
|
||||
mod $field_prefix_ident {
|
||||
use crate::curve::{get_default_device_context, $field_name, CudaError, DeviceContext, HostOrDeviceSlice};
|
||||
use crate::curve::{$field_name, CudaError, DeviceContext, HostOrDeviceSlice};
|
||||
|
||||
extern "C" {
|
||||
#[link_name = concat!($field_prefix, "GenerateScalars")]
|
||||
@@ -189,7 +192,7 @@ macro_rules! impl_scalar_field {
|
||||
scalars.as_mut_ptr(),
|
||||
scalars.len(),
|
||||
is_into,
|
||||
&get_default_device_context() as *const _ as *const DeviceContext,
|
||||
&DeviceContext::default() as *const _ as *const DeviceContext,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::curve::{Affine, Curve, Projective};
|
||||
use crate::error::IcicleResult;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
#[cfg(feature = "arkworks")]
|
||||
@@ -59,6 +59,33 @@ pub struct MSMConfig<'a> {
|
||||
pub is_async: bool,
|
||||
}
|
||||
|
||||
impl<'a> Default for MSMConfig<'a> {
|
||||
fn default() -> Self {
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> MSMConfig<'a> {
|
||||
pub fn default_for_device(device_id: usize) -> Self {
|
||||
Self {
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
points_size: 0,
|
||||
precompute_factor: 1,
|
||||
c: 0,
|
||||
bitsize: 0,
|
||||
large_bucket_factor: 10,
|
||||
batch_size: 1,
|
||||
are_scalars_on_device: false,
|
||||
are_scalars_montgomery_form: false,
|
||||
are_points_on_device: false,
|
||||
are_points_montgomery_form: false,
|
||||
are_results_on_device: false,
|
||||
is_big_triangle: false,
|
||||
is_async: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub trait MSM<C: Curve> {
|
||||
fn msm_unchecked(
|
||||
@@ -67,8 +94,6 @@ pub trait MSM<C: Curve> {
|
||||
cfg: &MSMConfig,
|
||||
results: &mut HostOrDeviceSlice<Projective<C>>,
|
||||
) -> IcicleResult<()>;
|
||||
|
||||
fn get_default_msm_config() -> MSMConfig<'static>;
|
||||
}
|
||||
|
||||
/// Computes the multi-scalar multiplication, or MSM: `s1*P1 + s2*P2 + ... + sn*Pn`, or a batch of several MSMs.
|
||||
@@ -113,11 +138,6 @@ pub fn msm<C: Curve + MSM<C>>(
|
||||
C::msm_unchecked(scalars, points, &local_cfg, results)
|
||||
}
|
||||
|
||||
/// Returns [MSM config](MSMConfig) struct populated with default values.
|
||||
pub fn get_default_msm_config<C: Curve + MSM<C>>() -> MSMConfig<'static> {
|
||||
C::get_default_msm_config()
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_msm {
|
||||
(
|
||||
@@ -161,10 +181,6 @@ macro_rules! impl_msm {
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn get_default_msm_config() -> MSMConfig<'static> {
|
||||
unsafe { $curve_prefix_indent::default_msm_config() }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use super::{get_default_msm_config, msm, MSM};
|
||||
use crate::curve::{Affine, Curve, Projective};
|
||||
use crate::msm::{msm, MSMConfig, MSM};
|
||||
use crate::traits::{FieldImpl, GenerateRandom};
|
||||
use icicle_cuda_runtime::device::{get_device_count, set_device};
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
use icicle_cuda_runtime::stream::CudaStream;
|
||||
use rayon::iter::IntoParallelIterator;
|
||||
use rayon::iter::ParallelIterator;
|
||||
|
||||
#[cfg(feature = "arkworks")]
|
||||
use crate::traits::ArkConvertible;
|
||||
@@ -19,61 +22,73 @@ where
|
||||
C::ScalarField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as ArkCurveConfig>::ScalarField>,
|
||||
C::BaseField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as ArkCurveConfig>::BaseField>,
|
||||
{
|
||||
let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
|
||||
let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
|
||||
for test_size in test_sizes {
|
||||
let points = C::generate_random_affine_points(test_size);
|
||||
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
|
||||
let points_ark: Vec<_> = points
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.collect();
|
||||
let scalars_ark: Vec<_> = scalars
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.collect();
|
||||
// if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
|
||||
// (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
|
||||
let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
|
||||
let device_count = get_device_count().unwrap();
|
||||
(0..device_count) // TODO: this is proto-loadbalancer
|
||||
.into_par_iter()
|
||||
.for_each(move |device_id| {
|
||||
//TODO: currently supported multi-GPU workflow:
|
||||
// 1) User starts child host thread from parent host thread
|
||||
// 2) Calls set_device once with selected device_id (0 is default device .. < device_count)
|
||||
// 3) Perform all operations (without changing device on the thread)
|
||||
// 4) If necessary - export results to parent host thread
|
||||
|
||||
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
|
||||
let stream = CudaStream::create().unwrap();
|
||||
scalars_d
|
||||
.copy_from_host_async(&scalars_mont, &stream)
|
||||
.unwrap();
|
||||
set_device(device_id).unwrap();
|
||||
let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
|
||||
let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
|
||||
for test_size in test_sizes {
|
||||
let points = C::generate_random_affine_points(test_size);
|
||||
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
|
||||
let points_ark: Vec<_> = points
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.collect();
|
||||
let scalars_ark: Vec<_> = scalars
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.collect();
|
||||
// if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
|
||||
// (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
|
||||
let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
|
||||
|
||||
let mut cfg = get_default_msm_config::<C>();
|
||||
cfg.ctx
|
||||
.stream = &stream;
|
||||
cfg.is_async = true;
|
||||
cfg.are_scalars_montgomery_form = true;
|
||||
msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
|
||||
// need to make sure that scalars_d weren't mutated by the previous call
|
||||
let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
|
||||
scalars_d
|
||||
.copy_to_host_async(&mut scalars_mont_after, &stream)
|
||||
.unwrap();
|
||||
assert_eq!(scalars_mont, scalars_mont_after);
|
||||
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
|
||||
let stream = CudaStream::create().unwrap();
|
||||
scalars_d
|
||||
.copy_from_host_async(&scalars_mont, &stream)
|
||||
.unwrap();
|
||||
|
||||
let mut msm_host_result = vec![Projective::<C>::zero(); 1];
|
||||
msm_results
|
||||
.copy_to_host(&mut msm_host_result[..])
|
||||
.unwrap();
|
||||
stream
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
stream
|
||||
.destroy()
|
||||
.unwrap();
|
||||
let mut cfg = MSMConfig::default_for_device(device_id);
|
||||
cfg.ctx
|
||||
.stream = &stream;
|
||||
cfg.is_async = true;
|
||||
cfg.are_scalars_montgomery_form = true;
|
||||
msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
|
||||
// need to make sure that scalars_d weren't mutated by the previous call
|
||||
let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
|
||||
scalars_d
|
||||
.copy_to_host_async(&mut scalars_mont_after, &stream)
|
||||
.unwrap();
|
||||
assert_eq!(scalars_mont, scalars_mont_after);
|
||||
|
||||
let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
|
||||
VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
|
||||
let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
|
||||
.to_ark()
|
||||
.into();
|
||||
assert!(msm_res_affine.is_on_curve());
|
||||
assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
|
||||
}
|
||||
let mut msm_host_result = vec![Projective::<C>::zero(); 1];
|
||||
msm_results
|
||||
.copy_to_host(&mut msm_host_result[..])
|
||||
.unwrap();
|
||||
stream
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
stream
|
||||
.destroy()
|
||||
.unwrap();
|
||||
|
||||
let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
|
||||
VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
|
||||
let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
|
||||
.to_ark()
|
||||
.into();
|
||||
assert!(msm_res_affine.is_on_curve());
|
||||
assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn check_msm_batch<C: Curve + MSM<C>>()
|
||||
@@ -104,7 +119,7 @@ where
|
||||
.copy_from_host_async(&points_cloned, &stream)
|
||||
.unwrap();
|
||||
|
||||
let mut cfg = get_default_msm_config::<C>();
|
||||
let mut cfg = MSMConfig::default();
|
||||
cfg.ctx
|
||||
.stream = &stream;
|
||||
cfg.is_async = true;
|
||||
@@ -181,7 +196,7 @@ where
|
||||
|
||||
let mut msm_results = HostOrDeviceSlice::on_host(vec![Projective::<C>::zero(); batch_size]);
|
||||
|
||||
let mut cfg = get_default_msm_config::<C>();
|
||||
let mut cfg = MSMConfig::default();
|
||||
if test_size < test_threshold {
|
||||
cfg.bitsize = 1;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
|
||||
use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
use crate::{error::IcicleResult, traits::FieldImpl};
|
||||
@@ -89,11 +89,16 @@ pub struct NTTConfig<'a, S> {
|
||||
pub ntt_algorithm: NttAlgorithm,
|
||||
}
|
||||
|
||||
impl<'a, S: FieldImpl> Default for NTTConfig<'a, S> {
|
||||
fn default() -> Self {
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S: FieldImpl> NTTConfig<'a, S> {
|
||||
pub fn default_config() -> Self {
|
||||
let ctx = get_default_device_context();
|
||||
pub fn default_for_device(device_id: usize) -> Self {
|
||||
NTTConfig {
|
||||
ctx,
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
coset_gen: S::one(),
|
||||
batch_size: 1,
|
||||
ordering: Ordering::kNN,
|
||||
@@ -114,7 +119,6 @@ pub trait NTT<F: FieldImpl> {
|
||||
output: &mut HostOrDeviceSlice<F>,
|
||||
) -> IcicleResult<()>;
|
||||
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
|
||||
fn get_default_ntt_config() -> NTTConfig<'static, F>;
|
||||
}
|
||||
|
||||
/// Computes the NTT, or a batch of several NTTs.
|
||||
@@ -169,15 +173,6 @@ where
|
||||
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
|
||||
}
|
||||
|
||||
/// Returns [NTT config](NTTConfig) struct populated with default values.
|
||||
pub fn get_default_ntt_config<F>() -> NTTConfig<'static, F>
|
||||
where
|
||||
F: FieldImpl,
|
||||
<F as FieldImpl>::Config: NTT<F>,
|
||||
{
|
||||
<<F as FieldImpl>::Config as NTT<F>>::get_default_ntt_config()
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_ntt {
|
||||
(
|
||||
@@ -187,7 +182,7 @@ macro_rules! impl_ntt {
|
||||
$field_config:ident
|
||||
) => {
|
||||
mod $field_prefix_ident {
|
||||
use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir};
|
||||
use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir, DEFAULT_DEVICE_ID};
|
||||
|
||||
extern "C" {
|
||||
#[link_name = concat!($field_prefix, "NTTCuda")]
|
||||
@@ -226,10 +221,6 @@ macro_rules! impl_ntt {
|
||||
fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
|
||||
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
|
||||
}
|
||||
|
||||
fn get_default_ntt_config() -> NTTConfig<'static, $field> {
|
||||
NTTConfig::<$field>::default_config()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -244,31 +235,31 @@ macro_rules! impl_ntt_tests {
|
||||
|
||||
#[test]
|
||||
fn test_ntt() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
|
||||
check_ntt::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntt_coset_from_subgroup() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
|
||||
check_ntt_coset_from_subgroup::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntt_arbitrary_coset() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
|
||||
check_ntt_arbitrary_coset::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntt_batch() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
|
||||
check_ntt_batch::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntt_device_async() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
|
||||
// in this test, init_domain is performed per-device
|
||||
check_ntt_device_async::<$field>()
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
use ark_ff::{FftField, Field as ArkField, One};
|
||||
use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
|
||||
use ark_std::{ops::Neg, test_rng, UniformRand};
|
||||
use icicle_cuda_runtime::device_context::get_default_device_context;
|
||||
use icicle_cuda_runtime::device::get_device_count;
|
||||
use icicle_cuda_runtime::device::set_device;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
use icicle_cuda_runtime::stream::CudaStream;
|
||||
use rayon::iter::IntoParallelIterator;
|
||||
use rayon::iter::ParallelIterator;
|
||||
|
||||
use crate::{
|
||||
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
|
||||
ntt::{initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
|
||||
traits::{ArkConvertible, FieldImpl, GenerateRandom},
|
||||
};
|
||||
|
||||
use super::NTTConfig;
|
||||
use super::NTT;
|
||||
|
||||
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64)
|
||||
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize)
|
||||
where
|
||||
F::ArkEquivalent: FftField,
|
||||
<F as FieldImpl>::Config: NTT<F>,
|
||||
{
|
||||
let ctx = get_default_device_context();
|
||||
let ctx = DeviceContext::default_for_device(device_id);
|
||||
let ark_rou = F::ArkEquivalent::get_root_of_unity(max_size).unwrap();
|
||||
initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
|
||||
}
|
||||
@@ -61,7 +65,7 @@ where
|
||||
let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) };
|
||||
let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec());
|
||||
|
||||
let mut config = get_default_ntt_config();
|
||||
let mut config = NTTConfig::default();
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
config.ntt_algorithm = alg;
|
||||
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
|
||||
@@ -107,7 +111,7 @@ where
|
||||
.collect::<Vec<F::ArkEquivalent>>();
|
||||
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
let mut config = get_default_ntt_config();
|
||||
let mut config = NTTConfig::default();
|
||||
config.ordering = Ordering::kNR;
|
||||
config.ntt_algorithm = alg;
|
||||
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
|
||||
@@ -188,7 +192,7 @@ where
|
||||
.map(|v| F::ArkEquivalent::from_random_bytes(&v.to_bytes_le()).unwrap())
|
||||
.collect::<Vec<F::ArkEquivalent>>();
|
||||
|
||||
let mut config = get_default_ntt_config();
|
||||
let mut config = NTTConfig::default();
|
||||
config.coset_gen = F::from_ark(coset_gen);
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
config.ordering = Ordering::kNR;
|
||||
@@ -229,7 +233,7 @@ where
|
||||
let batch_sizes = [1, 1 << 4, 100];
|
||||
for test_size in test_sizes {
|
||||
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
|
||||
let mut config = get_default_ntt_config();
|
||||
let mut config = NTTConfig::default();
|
||||
for batch_size in batch_sizes {
|
||||
let scalars = HostOrDeviceSlice::on_host(F::Config::generate_random(test_size * batch_size));
|
||||
|
||||
@@ -279,55 +283,65 @@ where
|
||||
F::ArkEquivalent: FftField,
|
||||
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
|
||||
{
|
||||
let test_sizes = [1 << 4, 1 << 12];
|
||||
let batch_sizes = [1, 1 << 4, 100];
|
||||
for test_size in test_sizes {
|
||||
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
|
||||
let stream = CudaStream::create().unwrap();
|
||||
let mut config = get_default_ntt_config();
|
||||
for batch_size in batch_sizes {
|
||||
let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
|
||||
let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.sum();
|
||||
let mut scalars_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
|
||||
scalars_d
|
||||
.copy_from_host_async(&scalars_h, &stream)
|
||||
.unwrap();
|
||||
let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
|
||||
let device_count = get_device_count().unwrap();
|
||||
|
||||
for coset_gen in coset_generators {
|
||||
for ordering in [Ordering::kNN, Ordering::kRR] {
|
||||
config.coset_gen = coset_gen;
|
||||
config.ordering = ordering;
|
||||
config.batch_size = batch_size as i32;
|
||||
config.is_async = true;
|
||||
config
|
||||
.ctx
|
||||
.stream = &stream;
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
config.ntt_algorithm = alg;
|
||||
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
|
||||
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
|
||||
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
|
||||
scalars_d
|
||||
.copy_to_host_async(&mut intt_result_h, &stream)
|
||||
.unwrap();
|
||||
stream
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
assert_eq!(scalars_h, intt_result_h);
|
||||
if coset_gen == F::one() {
|
||||
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
|
||||
ntt_out_d
|
||||
.copy_to_host(&mut ntt_result_h)
|
||||
.unwrap();
|
||||
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
|
||||
(0..device_count)
|
||||
.into_par_iter()
|
||||
.for_each(move |device_id| {
|
||||
set_device(device_id).unwrap();
|
||||
init_domain::<F>(1 << 16, device_id); // init domain per device
|
||||
let test_sizes = [1 << 4, 1 << 12];
|
||||
let batch_sizes = [1, 1 << 4, 100];
|
||||
for test_size in test_sizes {
|
||||
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
|
||||
let mut config = NTTConfig::default_for_device(device_id);
|
||||
let stream = config
|
||||
.ctx
|
||||
.stream;
|
||||
for batch_size in batch_sizes {
|
||||
let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
|
||||
let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
|
||||
.iter()
|
||||
.map(|x| x.to_ark())
|
||||
.sum();
|
||||
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size * batch_size).unwrap();
|
||||
scalars_d
|
||||
.copy_from_host(&scalars_h)
|
||||
.unwrap();
|
||||
let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
|
||||
|
||||
for coset_gen in coset_generators {
|
||||
for ordering in [Ordering::kNN, Ordering::kRR] {
|
||||
config.coset_gen = coset_gen;
|
||||
config.ordering = ordering;
|
||||
config.batch_size = batch_size as i32;
|
||||
config.is_async = true;
|
||||
config
|
||||
.ctx
|
||||
.stream = &stream;
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
config.ntt_algorithm = alg;
|
||||
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
|
||||
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
|
||||
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
|
||||
scalars_d
|
||||
.copy_to_host_async(&mut intt_result_h, &stream)
|
||||
.unwrap();
|
||||
stream
|
||||
.synchronize()
|
||||
.unwrap();
|
||||
assert_eq!(scalars_h, intt_result_h);
|
||||
if coset_gen == F::one() {
|
||||
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
|
||||
ntt_out_d
|
||||
.copy_to_host(&mut ntt_result_h)
|
||||
.unwrap();
|
||||
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
pub mod tests;
|
||||
|
||||
use icicle_cuda_runtime::{
|
||||
device_context::{get_default_device_context, DeviceContext},
|
||||
device_context::{DeviceContext, DEFAULT_DEVICE_ID},
|
||||
memory::HostOrDeviceSlice,
|
||||
};
|
||||
|
||||
@@ -60,9 +60,14 @@ pub struct PoseidonConfig<'a> {
|
||||
|
||||
impl<'a> Default for PoseidonConfig<'a> {
|
||||
fn default() -> Self {
|
||||
let ctx = get_default_device_context();
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> PoseidonConfig<'a> {
|
||||
pub fn default_for_device(device_id: usize) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
are_inputs_on_device: false,
|
||||
are_outputs_on_device: false,
|
||||
input_is_a_state: false,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::traits::FieldImpl;
|
||||
use icicle_cuda_runtime::device_context::get_default_device_context;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
use std::io::Read;
|
||||
@@ -15,7 +15,7 @@ pub fn init_poseidon<'a, F: FieldImpl>(arity: u32) -> PoseidonConstants<'a, F>
|
||||
where
|
||||
<F as FieldImpl>::Config: Poseidon<F>,
|
||||
{
|
||||
let ctx = get_default_device_context();
|
||||
let ctx = DeviceContext::default();
|
||||
|
||||
load_optimized_poseidon_constants::<F>(arity, &ctx).unwrap()
|
||||
}
|
||||
@@ -71,7 +71,7 @@ where
|
||||
|
||||
let full_rounds_half = 4;
|
||||
|
||||
let ctx = get_default_device_context();
|
||||
let ctx = DeviceContext::default();
|
||||
let cargo_manifest_dir = env!("CARGO_MANIFEST_DIR");
|
||||
let constants_file = PathBuf::from(cargo_manifest_dir)
|
||||
.join("tests")
|
||||
|
||||
@@ -18,7 +18,9 @@ pub trait FieldConfig: Debug + PartialEq + Copy + Clone {
|
||||
type ArkField: ArkField;
|
||||
}
|
||||
|
||||
pub trait FieldImpl: Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> {
|
||||
pub trait FieldImpl:
|
||||
Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> + Send + Sync
|
||||
{
|
||||
#[doc(hidden)]
|
||||
type Config: FieldConfig;
|
||||
type Repr;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use icicle_cuda_runtime::{
|
||||
device_context::{get_default_device_context, DeviceContext},
|
||||
device_context::{DeviceContext, DEFAULT_DEVICE_ID},
|
||||
memory::HostOrDeviceSlice,
|
||||
};
|
||||
|
||||
@@ -28,9 +28,13 @@ pub struct TreeBuilderConfig<'a> {
|
||||
|
||||
impl<'a> Default for TreeBuilderConfig<'a> {
|
||||
fn default() -> Self {
|
||||
let ctx = get_default_device_context();
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
impl<'a> TreeBuilderConfig<'a> {
|
||||
fn default_for_device(device_id: usize) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
keep_rows: 0,
|
||||
are_inputs_on_device: false,
|
||||
is_async: false,
|
||||
|
||||
@@ -45,7 +45,9 @@ fn main() {
|
||||
.must_use_type("cudaError")
|
||||
// device management
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
|
||||
.allowlist_function("cudaGetDevice")
|
||||
.allowlist_function("cudaSetDevice")
|
||||
.allowlist_function("cudaGetDeviceCount")
|
||||
// error handling
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__ERROR.html
|
||||
.allowlist_function("cudaGetLastError")
|
||||
@@ -69,6 +71,7 @@ fn main() {
|
||||
.allowlist_function("cudaMemcpyAsync")
|
||||
.allowlist_function("cudaMemset")
|
||||
.allowlist_function("cudaMemsetAsync")
|
||||
.allowlist_function("cudaDeviceGetDefaultMemPool")
|
||||
.rustified_enum("cudaMemcpyKind")
|
||||
// Stream Ordered Memory Allocator
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html
|
||||
|
||||
18
wrappers/rust/icicle-cuda-runtime/src/device.rs
Normal file
18
wrappers/rust/icicle-cuda-runtime/src/device.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use crate::{
|
||||
bindings::{cudaGetDevice, cudaGetDeviceCount, cudaSetDevice},
|
||||
error::{CudaResult, CudaResultWrap},
|
||||
};
|
||||
|
||||
pub fn set_device(device_id: usize) -> CudaResult<()> {
|
||||
unsafe { cudaSetDevice(device_id as i32) }.wrap()
|
||||
}
|
||||
|
||||
pub fn get_device_count() -> CudaResult<usize> {
|
||||
let mut count = 0;
|
||||
unsafe { cudaGetDeviceCount(&mut count) }.wrap_value(count as usize)
|
||||
}
|
||||
|
||||
pub fn get_device() -> CudaResult<usize> {
|
||||
let mut device_id = 0;
|
||||
unsafe { cudaGetDevice(&mut device_id) }.wrap_value(device_id as usize)
|
||||
}
|
||||
@@ -1,27 +1,47 @@
|
||||
use crate::memory::CudaMemPool;
|
||||
use crate::stream::CudaStream;
|
||||
|
||||
/// Properties of the device used in icicle functions.
|
||||
pub const DEFAULT_DEVICE_ID: usize = 0;
|
||||
|
||||
use crate::device::get_device;
|
||||
|
||||
/// Properties of the device used in Icicle functions.
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeviceContext<'a> {
|
||||
/// Stream to use. Default value: 0.
|
||||
/// Stream to use. Default value: 0. //TODO: multiple streams per device ?
|
||||
pub stream: &'a CudaStream, // Assuming the type is provided by a CUDA binding crate
|
||||
|
||||
/// Index of the currently used GPU. Default value: 0.
|
||||
pub device_id: usize,
|
||||
|
||||
/// Mempool to use. Default value: 0.
|
||||
/// Mempool to use. Default value: 0. //TODO: multiple mempools per device ?
|
||||
pub mempool: CudaMemPool, // Assuming the type is provided by a CUDA binding crate
|
||||
}
|
||||
|
||||
pub fn get_default_device_context() -> DeviceContext<'static> {
|
||||
static default_stream: CudaStream = CudaStream {
|
||||
handle: std::ptr::null_mut(),
|
||||
};
|
||||
DeviceContext {
|
||||
stream: &default_stream,
|
||||
device_id: 0,
|
||||
mempool: 0,
|
||||
impl Default for DeviceContext<'_> {
|
||||
fn default() -> Self {
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
|
||||
impl DeviceContext<'_> {
|
||||
/// Default for device_id
|
||||
pub fn default_for_device(device_id: usize) -> DeviceContext<'static> {
|
||||
static default_stream: CudaStream = CudaStream {
|
||||
handle: std::ptr::null_mut(),
|
||||
};
|
||||
DeviceContext {
|
||||
stream: &default_stream,
|
||||
device_id,
|
||||
mempool: std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_device(device_id: i32) {
|
||||
match device_id == get_device().unwrap() as i32 {
|
||||
true => (),
|
||||
false => panic!("Attempt to use on a different device"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#[allow(dead_code)]
|
||||
mod bindings;
|
||||
pub mod device;
|
||||
pub mod device_context;
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use crate::bindings::{cudaFree, cudaMalloc, cudaMallocAsync, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
|
||||
use crate::bindings::{
|
||||
cudaFree, cudaMalloc, cudaMallocAsync, cudaMemPool_t, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind,
|
||||
};
|
||||
use crate::device::get_device;
|
||||
use crate::device_context::check_device;
|
||||
use crate::error::{CudaError, CudaResult, CudaResultWrap};
|
||||
use crate::stream::CudaStream;
|
||||
use std::mem::{size_of, MaybeUninit};
|
||||
@@ -8,60 +12,69 @@ use std::slice::from_raw_parts_mut;
|
||||
|
||||
pub enum HostOrDeviceSlice<'a, T> {
|
||||
Host(Vec<T>),
|
||||
Device(&'a mut [T]),
|
||||
Device(&'a mut [T], i32),
|
||||
}
|
||||
|
||||
impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
// Function to get the device_id for Device variant
|
||||
pub fn get_device_id(&self) -> Option<i32> {
|
||||
match self {
|
||||
HostOrDeviceSlice::Device(_, device_id) => Some(*device_id),
|
||||
HostOrDeviceSlice::Host(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
Self::Device(s) => s.len(),
|
||||
Self::Device(s, _) => s.len(),
|
||||
Self::Host(v) => v.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
Self::Device(s) => s.is_empty(),
|
||||
Self::Device(s, _) => s.is_empty(),
|
||||
Self::Host(v) => v.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_on_device(&self) -> bool {
|
||||
match self {
|
||||
Self::Device(_) => true,
|
||||
Self::Device(_, _) => true,
|
||||
Self::Host(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut_slice(&mut self) -> &mut [T] {
|
||||
match self {
|
||||
Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
|
||||
Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
|
||||
Self::Host(v) => v.as_mut_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_slice(&self) -> &[T] {
|
||||
match self {
|
||||
Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
|
||||
Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
|
||||
Self::Host(v) => v.as_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const T {
|
||||
match self {
|
||||
Self::Device(s) => s.as_ptr(),
|
||||
Self::Device(s, _) => s.as_ptr(),
|
||||
Self::Host(v) => v.as_ptr(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&mut self) -> *mut T {
|
||||
match self {
|
||||
Self::Device(s) => s.as_mut_ptr(),
|
||||
Self::Device(s, _) => s.as_mut_ptr(),
|
||||
Self::Host(v) => v.as_mut_ptr(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_host(src: Vec<T>) -> Self {
|
||||
//TODO: HostOrDeviceSlice on_host() with slice input without actually copying the data
|
||||
Self::Host(src)
|
||||
}
|
||||
|
||||
@@ -70,16 +83,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
.checked_mul(size_of::<T>())
|
||||
.unwrap_or(0);
|
||||
if size == 0 {
|
||||
return Err(CudaError::cudaErrorMemoryAllocation);
|
||||
return Err(CudaError::cudaErrorMemoryAllocation); //TODO: only CUDA backend should return CudaError
|
||||
}
|
||||
|
||||
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
|
||||
unsafe {
|
||||
cudaMalloc(device_ptr.as_mut_ptr(), size).wrap()?;
|
||||
Ok(Self::Device(from_raw_parts_mut(
|
||||
device_ptr.assume_init() as *mut T,
|
||||
count,
|
||||
)))
|
||||
Ok(Self::Device(
|
||||
from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
|
||||
get_device().unwrap() as i32,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,16 +107,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
|
||||
unsafe {
|
||||
cudaMallocAsync(device_ptr.as_mut_ptr(), size, stream.handle as *mut _ as *mut _).wrap()?;
|
||||
Ok(Self::Device(from_raw_parts_mut(
|
||||
device_ptr.assume_init() as *mut T,
|
||||
count,
|
||||
)))
|
||||
Ok(Self::Device(
|
||||
from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
|
||||
get_device().unwrap() as i32,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn copy_from_host(&mut self, val: &[T]) -> CudaResult<()> {
|
||||
match self {
|
||||
Self::Device(_) => {}
|
||||
Self::Device(_, device_id) => check_device(*device_id),
|
||||
Self::Host(_) => panic!("Need device memory to copy into, and not host"),
|
||||
};
|
||||
assert!(
|
||||
@@ -127,7 +140,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
|
||||
pub fn copy_to_host(&self, val: &mut [T]) -> CudaResult<()> {
|
||||
match self {
|
||||
Self::Device(_) => {}
|
||||
Self::Device(_, device_id) => check_device(*device_id),
|
||||
Self::Host(_) => panic!("Need device memory to copy from, and not host"),
|
||||
};
|
||||
assert!(
|
||||
@@ -151,7 +164,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
|
||||
pub fn copy_from_host_async(&mut self, val: &[T], stream: &CudaStream) -> CudaResult<()> {
|
||||
match self {
|
||||
Self::Device(_) => {}
|
||||
Self::Device(_, device_id) => check_device(*device_id),
|
||||
Self::Host(_) => panic!("Need device memory to copy into, and not host"),
|
||||
};
|
||||
assert!(
|
||||
@@ -176,7 +189,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
|
||||
|
||||
pub fn copy_to_host_async(&self, val: &mut [T], stream: &CudaStream) -> CudaResult<()> {
|
||||
match self {
|
||||
Self::Device(_) => {}
|
||||
Self::Device(_, device_id) => check_device(*device_id),
|
||||
Self::Host(_) => panic!("Need device memory to copy from, and not host"),
|
||||
};
|
||||
assert!(
|
||||
@@ -209,7 +222,7 @@ macro_rules! impl_index {
|
||||
|
||||
fn index(&self, index: $t) -> &Self::Output {
|
||||
match self {
|
||||
Self::Device(s) => s.index(index),
|
||||
Self::Device(s, _) => s.index(index),
|
||||
Self::Host(v) => v.index(index),
|
||||
}
|
||||
}
|
||||
@@ -219,7 +232,7 @@ macro_rules! impl_index {
|
||||
{
|
||||
fn index_mut(&mut self, index: $t) -> &mut Self::Output {
|
||||
match self {
|
||||
Self::Device(s) => s.index_mut(index),
|
||||
Self::Device(s,_) => s.index_mut(index),
|
||||
Self::Host(v) => v.index_mut(index),
|
||||
}
|
||||
}
|
||||
@@ -239,7 +252,8 @@ impl_index! {
|
||||
impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
Self::Device(s) => {
|
||||
Self::Device(s, device_id) => {
|
||||
check_device(*device_id);
|
||||
if s.is_empty() {
|
||||
return;
|
||||
}
|
||||
@@ -256,4 +270,4 @@ impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
pub type CudaMemPool = usize; // This is a placeholder, TODO: actually make this into a proper CUDA wrapper
|
||||
pub type CudaMemPool = cudaMemPool_t;
|
||||
|
||||
@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
|
||||
use icicle_core::field::{Field, MontgomeryConvertibleField};
|
||||
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
|
||||
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
|
||||
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use icicle_core::impl_ntt;
|
||||
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -17,6 +18,7 @@ impl_ntt!("bw6_761", bw6_761, BaseField, BaseCfg);
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::curve::ScalarField;
|
||||
use crate::ntt::DEFAULT_DEVICE_ID;
|
||||
use icicle_core::impl_ntt_tests;
|
||||
use icicle_core::ntt::tests::*;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
|
||||
use icicle_core::field::{Field, MontgomeryConvertibleField};
|
||||
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
|
||||
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
|
||||
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use icicle_core::impl_ntt;
|
||||
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -13,6 +14,7 @@ impl_ntt!("bls12_381", bls12_381, ScalarField, ScalarCfg);
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::curve::ScalarField;
|
||||
use crate::ntt::DEFAULT_DEVICE_ID;
|
||||
use icicle_core::impl_ntt_tests;
|
||||
use icicle_core::ntt::tests::*;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
|
||||
use icicle_core::field::{Field, MontgomeryConvertibleField};
|
||||
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
|
||||
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
|
||||
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use icicle_core::impl_ntt;
|
||||
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -13,6 +14,7 @@ impl_ntt!("bn254", bn254, ScalarField, ScalarCfg);
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::curve::ScalarField;
|
||||
use crate::ntt::DEFAULT_DEVICE_ID;
|
||||
use icicle_core::impl_ntt_tests;
|
||||
use icicle_core::ntt::tests::*;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
@@ -7,7 +7,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
|
||||
use icicle_core::field::Field;
|
||||
use icicle_core::traits::FieldConfig;
|
||||
use icicle_core::{impl_curve, impl_field};
|
||||
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ pub(crate) mod tests {
|
||||
use crate::curve::ScalarField;
|
||||
use icicle_core::impl_ntt_tests;
|
||||
use icicle_core::ntt::tests::*;
|
||||
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
impl_ntt_tests!(ScalarField);
|
||||
|
||||
Reference in New Issue
Block a user