add ntt benchmark for rust

This commit is contained in:
Yuval Shekel
2024-07-31 19:35:40 +03:00
parent d08088a4bb
commit cccd6c1679
9 changed files with 173 additions and 1 deletions

View File

@@ -427,4 +427,138 @@ macro_rules! impl_ntt_tests {
};
}
// TODO Yuval : benchmarks
/// Generates a complete Criterion benchmark target for NTT over one field.
///
/// Arguments:
/// * `$field_prefix` - string literal used to name the benchmark group
///   (the group id is `"<prefix> NTT"`).
/// * `$field` - the field type; it is passed as both type parameters of
///   `benchmark_ntt` and used to initialize the NTT domain.
///
/// The expansion ends with `criterion_group!` / `criterion_main!`, so it is
/// meant to be invoked once, at the top level of a bench file.
///
/// Environment variables read at run time:
/// * `BENCH_TARGET` - device type to benchmark on. When unset, `"CUDA"` is used
///   if that device is available, otherwise `"CPU"`.
/// * `MAX_LOG2` - log2 of the largest total element count (NTT size x batch size)
///   and of the NTT domain size. Defaults to 25 when unset or not a valid `u32`.
#[macro_export]
macro_rules! impl_ntt_bench {
    (
      $field_prefix:literal,
      $field:ident
    ) => {
        use std::{env, sync::OnceLock};
        use criterion::{black_box, criterion_group, criterion_main, Criterion};
        use icicle_runtime::{memory::{HostSlice,HostOrDeviceSlice},device::Device,is_device_available, get_active_device, set_device, runtime::load_backend_from_env_or_default};
        use icicle_core::{
            ntt::{NTTConfig, NTTInitDomainConfig, NTTDir, NttAlgorithm, Ordering, NTTDomain, ntt, NTT},
            traits::{GenerateRandom,FieldImpl},
            vec_ops::VecOps,
        };

        /// Runs a single (batched) NTT from `input` into `batch_ntt_result`.
        ///
        /// This is the unit of work timed by the benchmark. It panics if `ntt`
        /// returns an error. `_seed` is not read; the caller passes a
        /// `black_box`ed value through it.
        fn ntt_for_bench<T, F: FieldImpl>(
            input: &(impl HostOrDeviceSlice<F> + ?Sized),
            batch_ntt_result: &mut (impl HostOrDeviceSlice<F> + ?Sized),
            dir: NTTDir,
            config: &mut NTTConfig<F>,
            _seed: u32,
        ) where
            <F as FieldImpl>::Config: NTT<F, F> + GenerateRandom<F> + VecOps<F>,
        {
            ntt(input, dir, config, batch_ntt_result).unwrap();
        }

        /// Ensures the NTT domain is initialized at most once per process.
        static INIT: OnceLock<()> = OnceLock::new();

        /// Loads the backends and selects the device the benchmark runs on.
        ///
        /// The device type comes from the `BENCH_TARGET` environment variable;
        /// when it is not set, CUDA is preferred if available and CPU is the
        /// fallback. Device index 0 is always used. Panics if the device cannot
        /// be set.
        fn load_and_init_backend_device() {
            // Try loading from /opt/icicle/backend or env ${ICICLE_BACKEND_INSTALL_DIR}.
            // Whatever the call returns is deliberately discarded, as before.
            let _ = load_backend_from_env_or_default();

            let target = match env::var("BENCH_TARGET") {
                Ok(requested) => requested,
                // Not defined: try CUDA first, fall back to CPU.
                Err(_) => {
                    if is_device_available(&Device::new("CUDA", 0)) {
                        "CUDA".to_string()
                    } else {
                        "CPU".to_string()
                    }
                }
            };

            let device = Device::new(&target, 0);
            set_device(&device).unwrap();
            println!("ICICLE benchmark with {:?}", device);
        }

        /// Registers NTT benchmarks for every size / batch / direction / ordering
        /// combination.
        ///
        /// NTT sizes range over 2^13..=2^max_log2 and batch sizes over 2^7..2^17
        /// (exclusive); combinations whose total element count exceeds 2^max_log2
        /// are skipped. Each benchmark is named
        /// `"<ordering> <dir> <ntt size> x <batch size>"`.
        fn benchmark_ntt<T, F: FieldImpl>(c: &mut Criterion)
        where
            <F as FieldImpl>::Config: NTT<F, F> + GenerateRandom<F> + VecOps<F>,
        {
            use criterion::SamplingMode;
            use icicle_core::ntt::tests::init_domain;

            // Default for the MAX_LOG2 environment variable: max length = 2 ^ MAX_LOG2.
            const MAX_LOG2: u32 = 25;
            const FAST_TWIDDLES_MODE: bool = false;

            load_and_init_backend_device();

            let group_id = format!("{} NTT", $field_prefix);
            let mut group = c.benchmark_group(&group_id);
            group.sampling_mode(SamplingMode::Flat);
            group.sample_size(10);

            // An unset or unparsable MAX_LOG2 falls back to the default.
            let max_log2 = env::var("MAX_LOG2")
                .ok()
                .and_then(|value| value.parse::<u32>().ok())
                .unwrap_or(MAX_LOG2);

            // The domain is sized for the largest NTT that can be benchmarked.
            INIT.get_or_init(move || init_domain::<$field>(1 << max_log2, FAST_TWIDDLES_MODE));

            let max_total_size = 1usize << max_log2;

            for test_size_log2 in 13u32..=max_log2 {
                for batch_size_log2 in 7u32..17u32 {
                    let test_size = 1usize << test_size_log2;
                    let batch_size = 1usize << batch_size_log2;
                    let full_size = batch_size * test_size;

                    if full_size > max_total_size {
                        continue;
                    }

                    // Input and output buffers are allocated once per (size, batch)
                    // pair and reused by every direction / ordering below.
                    let scalars = F::Config::generate_random(full_size);
                    let input = HostSlice::from_slice(&scalars);

                    let mut batch_ntt_result = vec![F::zero(); full_size];
                    let batch_ntt_result = HostSlice::from_mut_slice(&mut batch_ntt_result);

                    let mut config = NTTConfig::<F>::default();
                    // The batch size does not depend on direction or ordering.
                    config.batch_size = batch_size as i32;

                    for dir in [NTTDir::kForward, NTTDir::kInverse] {
                        for ordering in [
                            Ordering::kNN,
                            Ordering::kNR,
                            Ordering::kRN,
                            Ordering::kRR,
                            Ordering::kNM,
                            Ordering::kMN,
                        ] {
                            config.ordering = ordering;

                            let bench_descr = format!(
                                "{:?} {:?} {} x {}",
                                ordering, dir, test_size, batch_size
                            );

                            group.bench_function(&bench_descr, |b| {
                                b.iter(|| {
                                    ntt_for_bench::<F, F>(
                                        input,
                                        batch_ntt_result,
                                        dir,
                                        &mut config,
                                        black_box(1),
                                    )
                                })
                            });
                        }
                    }
                }
            }

            group.finish();
        }

        criterion_group!(benches, benchmark_ntt<$field, $field>);
        criterion_main!(benches);
    };
}

View File

@@ -27,3 +27,10 @@ g2 = ["icicle-core/g2"]
ec_ntt = ["icicle-core/ec_ntt"]
cuda_backend = ["icicle-runtime/cuda_backend"]
pull_cuda_backend = ["icicle-runtime/pull_cuda_backend"]
[[bench]]
name = "ntt"
harness = false

View File

@@ -0,0 +1,5 @@
// NTT benchmark target for the bls12_377 scalar field.
// `impl_ntt_bench!` expands to the benchmark functions plus the
// `criterion_group!` / `criterion_main!` invocations.
use icicle_bls12_377::curve::ScalarField;
use icicle_core::impl_ntt_bench;
impl_ntt_bench!("bls12_377", ScalarField);

View File

@@ -25,3 +25,7 @@ g2 = ["icicle-core/g2"]
ec_ntt = ["icicle-core/ec_ntt"]
cuda_backend = ["icicle-runtime/cuda_backend"]
pull_cuda_backend = ["icicle-runtime/pull_cuda_backend"]
[[bench]]
name = "ntt"
harness = false

View File

@@ -0,0 +1,5 @@
// NTT benchmark target for the bls12_381 scalar field.
// `impl_ntt_bench!` expands to the benchmark functions plus the
// `criterion_group!` / `criterion_main!` invocations.
use icicle_bls12_381::curve::ScalarField;
use icicle_core::impl_ntt_bench;
impl_ntt_bench!("bls12_381", ScalarField);

View File

@@ -25,3 +25,7 @@ g2 = ["icicle-core/g2"]
ec_ntt = ["icicle-core/ec_ntt"]
cuda_backend = ["icicle-runtime/cuda_backend"]
pull_cuda_backend = ["icicle-runtime/pull_cuda_backend"]
[[bench]]
name = "ntt"
harness = false

View File

@@ -0,0 +1,5 @@
// NTT benchmark target for the bn254 scalar field.
// `impl_ntt_bench!` expands to the benchmark functions plus the
// `criterion_group!` / `criterion_main!` invocations.
use icicle_bn254::curve::ScalarField;
use icicle_core::impl_ntt_bench;
impl_ntt_bench!("bn254", ScalarField);

View File

@@ -24,3 +24,6 @@ cmake = "0.1.50"
default = []
g2 = ["icicle-bls12-377/bw6-761-g2"]
[[bench]]
name = "ntt"
harness = false

View File

@@ -0,0 +1,5 @@
// NTT benchmark target for the bw6_761 scalar field.
// `impl_ntt_bench!` expands to the benchmark functions plus the
// `criterion_group!` / `criterion_main!` invocations.
use icicle_bw6_761::curve::ScalarField;
use icicle_core::impl_ntt_bench;
impl_ntt_bench!("bw6_761", ScalarField);