Merge branch 'dev' into poseidon-examples-no-cuda

This commit is contained in:
DmytroTym
2024-02-15 08:27:28 +02:00
committed by GitHub
34 changed files with 544 additions and 366 deletions

View File

@@ -1,43 +1,20 @@
use icicle_bn254::curve::{
CurveCfg,
ScalarCfg,
G1Projective,
G2CurveCfg,
G2Projective
};
use icicle_bn254::curve::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg};
use icicle_bls12_377::curve::{
CurveCfg as BLS12377CurveCfg,
ScalarCfg as BLS12377ScalarCfg,
G1Projective as BLS12377G1Projective
CurveCfg as BLS12377CurveCfg, G1Projective as BLS12377G1Projective, ScalarCfg as BLS12377ScalarCfg,
};
use icicle_cuda_runtime::{
stream::CudaStream,
memory::HostOrDeviceSlice
};
use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream};
use icicle_core::{
msm,
curve::Curve,
traits::GenerateRandom
};
use icicle_core::{curve::Curve, msm, traits::GenerateRandom};
#[cfg(feature = "arkworks")]
use icicle_core::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_bn254::{
G1Projective as Bn254ArkG1Projective,
G1Affine as Bn254G1Affine,
Fr as Bn254Fr
};
use ark_bls12_377::{Fr as Bls12377Fr, G1Affine as Bls12377G1Affine, G1Projective as Bls12377ArkG1Projective};
#[cfg(feature = "arkworks")]
use ark_bls12_377::{
G1Projective as Bls12377ArkG1Projective,
G1Affine as Bls12377G1Affine,
Fr as Bls12377Fr
};
use ark_bn254::{Fr as Bn254Fr, G1Affine as Bn254G1Affine, G1Projective as Bn254ArkG1Projective};
#[cfg(feature = "arkworks")]
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
@@ -67,23 +44,26 @@ fn main() {
let upper_points = CurveCfg::generate_random_affine_points(upper_size);
let g2_upper_points = G2CurveCfg::generate_random_affine_points(upper_size);
let upper_scalars = ScalarCfg::generate_random(upper_size);
println!("Generating random inputs on host for bls12377...");
let upper_points_bls12377 = BLS12377CurveCfg::generate_random_affine_points(upper_size);
let upper_scalars_bls12377 = BLS12377ScalarCfg::generate_random(upper_size);
for i in lower_bound..=upper_bound {
for i in lower_bound..=upper_bound {
let log_size = i;
let size = 1 << log_size;
println!("---------------------- MSM size 2^{}={} ------------------------", log_size, size);
println!(
"---------------------- MSM size 2^{}={} ------------------------",
log_size, size
);
// Setting Bn254 points and scalars
let points = HostOrDeviceSlice::Host(upper_points[..size].to_vec());
let g2_points = HostOrDeviceSlice::Host(g2_upper_points[..size].to_vec());
let scalars = HostOrDeviceSlice::Host(upper_scalars[..size].to_vec());
// Setting bls12377 points and scalars
// let points_bls12377 = &upper_points_bls12377[..size];
let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
let scalars_bls12377 = HostOrDeviceSlice::Host(upper_scalars_bls12377[..size].to_vec());
println!("Configuring bn254 MSM...");
@@ -91,18 +71,24 @@ fn main() {
let mut g2_msm_results: HostOrDeviceSlice<'_, G2Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
let stream = CudaStream::create().unwrap();
let g2_stream = CudaStream::create().unwrap();
let mut cfg = msm::get_default_msm_config::<CurveCfg>();
let mut g2_cfg = msm::get_default_msm_config::<G2CurveCfg>();
cfg.ctx.stream = &stream;
g2_cfg.ctx.stream = &g2_stream;
let mut cfg = msm::MSMConfig::default();
let mut g2_cfg = msm::MSMConfig::default();
cfg.ctx
.stream = &stream;
g2_cfg
.ctx
.stream = &g2_stream;
cfg.is_async = true;
g2_cfg.is_async = true;
println!("Configuring bls12377 MSM...");
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> =
HostOrDeviceSlice::cuda_malloc(1).unwrap();
let stream_bls12377 = CudaStream::create().unwrap();
let mut cfg_bls12377 = msm::get_default_msm_config::<BLS12377CurveCfg>();
cfg_bls12377.ctx.stream = &stream_bls12377;
let mut cfg_bls12377 = msm::MSMConfig::default();
cfg_bls12377
.ctx
.stream = &stream_bls12377;
cfg_bls12377.is_async = true;
println!("Executing bn254 MSM on device...");
@@ -110,22 +96,37 @@ fn main() {
let start = Instant::now();
msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"ICICLE BN254 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
msm::msm(&scalars, &g2_points, &g2_cfg, &mut g2_msm_results).unwrap();
println!("Executing bls12377 MSM on device...");
#[cfg(feature = "profile")]
let start = Instant::now();
msm::msm(&scalars_bls12377, &points_bls12377, &cfg_bls12377, &mut msm_results_bls12377 ).unwrap();
msm::msm(
&scalars_bls12377,
&points_bls12377,
&cfg_bls12377,
&mut msm_results_bls12377,
)
.unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
println!("Moving results to host..");
let mut msm_host_result = vec![G1Projective::zero(); 1];
let mut g2_msm_host_result = vec![G2Projective::zero(); 1];
let mut msm_host_result_bls12377 = vec![BLS12377G1Projective::zero(); 1];
stream
.synchronize()
.unwrap();
@@ -140,7 +141,7 @@ fn main() {
.unwrap();
println!("bn254 result: {:#?}", msm_host_result);
println!("G2 bn254 result: {:#?}", g2_msm_host_result);
stream_bls12377
.synchronize()
.unwrap();
@@ -148,37 +149,70 @@ fn main() {
.copy_to_host(&mut msm_host_result_bls12377[..])
.unwrap();
println!("bls12377 result: {:#?}", msm_host_result_bls12377);
#[cfg(feature = "arkworks")]
{
println!("Checking against arkworks...");
let ark_points: Vec<Bn254G1Affine> = points.as_slice().iter().map(|&point| point.to_ark()).collect();
let ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let ark_points: Vec<Bn254G1Affine> = points
.as_slice()
.iter()
.map(|&point| point.to_ark())
.collect();
let ark_scalars: Vec<Bn254Fr> = scalars
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377.as_slice().iter().map(|point| point.to_ark()).collect();
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377
.as_slice()
.iter()
.map(|point| point.to_ark())
.collect();
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
#[cfg(feature = "profile")]
let start = Instant::now();
let bn254_ark_msm_res = Bn254ArkG1Projective::msm(&ark_points, &ark_scalars).unwrap();
println!("Arkworks Bn254 result: {:#?}", bn254_ark_msm_res);
#[cfg(feature = "profile")]
println!("Ark BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BN254 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
#[cfg(feature = "profile")]
let start = Instant::now();
let bls12377_ark_msm_res = Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
let bls12377_ark_msm_res =
Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
println!("Arkworks Bls12377 result: {:#?}", bls12377_ark_msm_res);
#[cfg(feature = "profile")]
println!("Ark BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BLS12377 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
let bn254_icicle_msm_res_as_ark = msm_host_result[0].to_ark();
let bls12377_icicle_msm_res_as_ark = msm_host_result_bls12377[0].to_ark();
println!("Bn254 MSM is correct: {}", bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark));
println!("Bls12377 MSM is correct: {}", bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark));
println!(
"Bn254 MSM is correct: {}",
bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark)
);
println!(
"Bls12377 MSM is correct: {}",
bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark)
);
}
println!("Cleaning up bn254...");
stream
.destroy()

View File

@@ -1,28 +1,18 @@
use icicle_bn254::curve::{
ScalarCfg,
ScalarField,
};
use icicle_bn254::curve::{ScalarCfg, ScalarField};
use icicle_bls12_377::curve::{
ScalarCfg as BLS12377ScalarCfg,
ScalarField as BLS12377ScalarField
};
use icicle_bls12_377::curve::{ScalarCfg as BLS12377ScalarCfg, ScalarField as BLS12377ScalarField};
use icicle_cuda_runtime::{
stream::CudaStream,
memory::HostOrDeviceSlice,
device_context::get_default_device_context
};
use icicle_cuda_runtime::{device_context::DeviceContext, memory::HostOrDeviceSlice, stream::CudaStream};
use icicle_core::{
ntt::{self, NTT},
traits::{GenerateRandom, FieldImpl}
traits::{FieldImpl, GenerateRandom},
};
use icicle_core::traits::ArkConvertible;
use ark_bn254::Fr as Bn254Fr;
use ark_bls12_377::Fr as Bls12377Fr;
use ark_bn254::Fr as Bn254Fr;
use ark_ff::FftField;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use ark_std::cmp::{Ord, Ordering};
@@ -45,37 +35,52 @@ fn main() {
println!("Running Icicle Examples: Rust NTT");
let log_size = args.size;
let size = 1 << log_size;
println!("---------------------- NTT size 2^{}={} ------------------------", log_size, size);
println!(
"---------------------- NTT size 2^{}={} ------------------------",
log_size, size
);
// Setting Bn254 points and scalars
println!("Generating random inputs on host for bn254...");
let scalars = HostOrDeviceSlice::Host(ScalarCfg::generate_random(size));
let mut ntt_results: HostOrDeviceSlice<'_, ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
// Setting bls12377 points and scalars
println!("Generating random inputs on host for bls12377...");
let scalars_bls12377 = HostOrDeviceSlice::Host(BLS12377ScalarCfg::generate_random(size));
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> =
HostOrDeviceSlice::cuda_malloc(size).unwrap();
println!("Setting up bn254 Domain...");
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
let ctx = get_default_device_context();
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(
size.try_into()
.unwrap(),
)
.unwrap();
let ctx = DeviceContext::default();
ScalarCfg::initialize_domain(ScalarField::from_ark(icicle_omega), &ctx).unwrap();
println!("Configuring bn254 NTT...");
let stream = CudaStream::create().unwrap();
let mut cfg = ntt::get_default_ntt_config::<ScalarField>();
cfg.ctx.stream = &stream;
let mut cfg = ntt::NTTConfig::default();
cfg.ctx
.stream = &stream;
cfg.is_async = true;
println!("Setting up bls12377 Domain...");
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(
size.try_into()
.unwrap(),
)
.unwrap();
// reusing ctx from above
BLS12377ScalarCfg::initialize_domain(BLS12377ScalarField::from_ark(icicle_omega), &ctx).unwrap();
println!("Configuring bls12377 NTT...");
let stream_bls12377 = CudaStream::create().unwrap();
let mut cfg_bls12377 = ntt::get_default_ntt_config::<BLS12377ScalarField>();
cfg_bls12377.ctx.stream = &stream_bls12377;
let mut cfg_bls12377 = ntt::NTTConfig::default();
cfg_bls12377
.ctx
.stream = &stream_bls12377;
cfg_bls12377.is_async = true;
println!("Executing bn254 NTT on device...");
@@ -83,14 +88,30 @@ fn main() {
let start = Instant::now();
ntt::ntt(&scalars, ntt::NTTDir::kForward, &cfg, &mut ntt_results).unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BN254 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
println!(
"ICICLE BN254 NTT on size 2^{log_size} took: {} μs",
start
.elapsed()
.as_micros()
);
println!("Executing bls12377 NTT on device...");
#[cfg(feature = "profile")]
let start = Instant::now();
ntt::ntt(&scalars_bls12377, ntt::NTTDir::kForward, &cfg_bls12377, &mut ntt_results_bls12377).unwrap();
ntt::ntt(
&scalars_bls12377,
ntt::NTTDir::kForward,
&cfg_bls12377,
&mut ntt_results_bls12377,
)
.unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
println!(
"ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs",
start
.elapsed()
.as_micros()
);
println!("Moving results to host..");
stream
@@ -100,7 +121,7 @@ fn main() {
ntt_results
.copy_to_host(&mut host_bn254_results[..])
.unwrap();
stream_bls12377
.synchronize()
.unwrap();
@@ -108,25 +129,43 @@ fn main() {
ntt_results_bls12377
.copy_to_host(&mut host_bls12377_results[..])
.unwrap();
println!("Checking against arkworks...");
let mut ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let mut ark_scalars: Vec<Bn254Fr> = scalars
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let bn254_domain = <Radix2EvaluationDomain<Bn254Fr> as EvaluationDomain<Bn254Fr>>::new(size).unwrap();
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let bls12_377_domain = <Radix2EvaluationDomain<Bls12377Fr> as EvaluationDomain<Bls12377Fr>>::new(size).unwrap();
#[cfg(feature = "profile")]
let start = Instant::now();
bn254_domain.fft_in_place(&mut ark_scalars);
#[cfg(feature = "profile")]
println!("Ark BN254 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BN254 NTT on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
#[cfg(feature = "profile")]
let start = Instant::now();
bls12_377_domain.fft_in_place(&mut ark_scalars_bls12377);
#[cfg(feature = "profile")]
println!("Ark BLS12377 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BLS12377 NTT on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
host_bn254_results
.iter()
@@ -135,7 +174,7 @@ fn main() {
assert_eq!(ark_scalar.cmp(&icicle_scalar.to_ark()), Ordering::Equal);
});
println!("Bn254 NTT is correct");
host_bls12377_results
.iter()
.zip(ark_scalars_bls12377.iter())
@@ -144,7 +183,7 @@ fn main() {
});
println!("Bls12377 NTT is correct");
println!("Cleaning up bn254...");
stream
.destroy()