multi card support (#356)

multi-GPU support
VitaliiH
2024-02-14 22:29:30 +01:00
committed by GitHub
parent 0d70a0c003
commit 774250926c
34 changed files with 544 additions and 366 deletions

View File

@@ -1,3 +1,5 @@
inout
crate
lmit
mut
uint

View File

@@ -32,7 +32,6 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
> [!NOTE]
> Developers: We highly recommend reading our [documentation]
> [!TIP]
> Try out ICICLE by running some [examples] in C++ or via our Rust bindings
@@ -43,8 +42,8 @@ ICICLE is a CUDA implementation of general functions widely used in ZKP.
- [GCC](https://gcc.gnu.org/install/download.html) version 9 or later; the latest version is recommended.
- Any Nvidia GPU that supports CUDA Toolkit version 12.0 or above.
> [!NOTE]
> It is possible to use CUDA 11 for cards which dont support CUDA 12, however we dont officially support this version and in the future there may be issues.
> [!NOTE]
> It is possible to use CUDA 11 for cards that don't support CUDA 12; however, we don't officially support this version and there may be issues in the future.
### Accessing Hardware

View File

@@ -1,43 +1,20 @@
use icicle_bn254::curve::{
CurveCfg,
ScalarCfg,
G1Projective,
G2CurveCfg,
G2Projective
};
use icicle_bn254::curve::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg};
use icicle_bls12_377::curve::{
CurveCfg as BLS12377CurveCfg,
ScalarCfg as BLS12377ScalarCfg,
G1Projective as BLS12377G1Projective
CurveCfg as BLS12377CurveCfg, G1Projective as BLS12377G1Projective, ScalarCfg as BLS12377ScalarCfg,
};
use icicle_cuda_runtime::{
stream::CudaStream,
memory::HostOrDeviceSlice
};
use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream};
use icicle_core::{
msm,
curve::Curve,
traits::GenerateRandom
};
use icicle_core::{curve::Curve, msm, traits::GenerateRandom};
#[cfg(feature = "arkworks")]
use icicle_core::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_bn254::{
G1Projective as Bn254ArkG1Projective,
G1Affine as Bn254G1Affine,
Fr as Bn254Fr
};
use ark_bls12_377::{Fr as Bls12377Fr, G1Affine as Bls12377G1Affine, G1Projective as Bls12377ArkG1Projective};
#[cfg(feature = "arkworks")]
use ark_bls12_377::{
G1Projective as Bls12377ArkG1Projective,
G1Affine as Bls12377G1Affine,
Fr as Bls12377Fr
};
use ark_bn254::{Fr as Bn254Fr, G1Affine as Bn254G1Affine, G1Projective as Bn254ArkG1Projective};
#[cfg(feature = "arkworks")]
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
@@ -67,23 +44,26 @@ fn main() {
let upper_points = CurveCfg::generate_random_affine_points(upper_size);
let g2_upper_points = G2CurveCfg::generate_random_affine_points(upper_size);
let upper_scalars = ScalarCfg::generate_random(upper_size);
println!("Generating random inputs on host for bls12377...");
let upper_points_bls12377 = BLS12377CurveCfg::generate_random_affine_points(upper_size);
let upper_scalars_bls12377 = BLS12377ScalarCfg::generate_random(upper_size);
for i in lower_bound..=upper_bound {
let log_size = i;
let size = 1 << log_size;
println!("---------------------- MSM size 2^{}={} ------------------------", log_size, size);
println!(
"---------------------- MSM size 2^{}={} ------------------------",
log_size, size
);
// Setting Bn254 points and scalars
let points = HostOrDeviceSlice::Host(upper_points[..size].to_vec());
let g2_points = HostOrDeviceSlice::Host(g2_upper_points[..size].to_vec());
let scalars = HostOrDeviceSlice::Host(upper_scalars[..size].to_vec());
// Setting bls12377 points and scalars
// let points_bls12377 = &upper_points_bls12377[..size];
let points_bls12377 = HostOrDeviceSlice::Host(upper_points_bls12377[..size].to_vec()); // &upper_points_bls12377[..size];
let scalars_bls12377 = HostOrDeviceSlice::Host(upper_scalars_bls12377[..size].to_vec());
println!("Configuring bn254 MSM...");
@@ -91,18 +71,24 @@ fn main() {
let mut g2_msm_results: HostOrDeviceSlice<'_, G2Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
let stream = CudaStream::create().unwrap();
let g2_stream = CudaStream::create().unwrap();
let mut cfg = msm::get_default_msm_config::<CurveCfg>();
let mut g2_cfg = msm::get_default_msm_config::<G2CurveCfg>();
cfg.ctx.stream = &stream;
g2_cfg.ctx.stream = &g2_stream;
let mut cfg = msm::MSMConfig::default();
let mut g2_cfg = msm::MSMConfig::default();
cfg.ctx
.stream = &stream;
g2_cfg
.ctx
.stream = &g2_stream;
cfg.is_async = true;
g2_cfg.is_async = true;
println!("Configuring bls12377 MSM...");
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
let mut msm_results_bls12377: HostOrDeviceSlice<'_, BLS12377G1Projective> =
HostOrDeviceSlice::cuda_malloc(1).unwrap();
let stream_bls12377 = CudaStream::create().unwrap();
let mut cfg_bls12377 = msm::get_default_msm_config::<BLS12377CurveCfg>();
cfg_bls12377.ctx.stream = &stream_bls12377;
let mut cfg_bls12377 = msm::MSMConfig::default();
cfg_bls12377
.ctx
.stream = &stream_bls12377;
cfg_bls12377.is_async = true;
println!("Executing bn254 MSM on device...");
@@ -110,22 +96,37 @@ fn main() {
let start = Instant::now();
msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"ICICLE BN254 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
msm::msm(&scalars, &g2_points, &g2_cfg, &mut g2_msm_results).unwrap();
println!("Executing bls12377 MSM on device...");
#[cfg(feature = "profile")]
let start = Instant::now();
msm::msm(&scalars_bls12377, &points_bls12377, &cfg_bls12377, &mut msm_results_bls12377 ).unwrap();
msm::msm(
&scalars_bls12377,
&points_bls12377,
&cfg_bls12377,
&mut msm_results_bls12377,
)
.unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"ICICLE BLS12377 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
println!("Moving results to host..");
let mut msm_host_result = vec![G1Projective::zero(); 1];
let mut g2_msm_host_result = vec![G2Projective::zero(); 1];
let mut msm_host_result_bls12377 = vec![BLS12377G1Projective::zero(); 1];
stream
.synchronize()
.unwrap();
@@ -140,7 +141,7 @@ fn main() {
.unwrap();
println!("bn254 result: {:#?}", msm_host_result);
println!("G2 bn254 result: {:#?}", g2_msm_host_result);
stream_bls12377
.synchronize()
.unwrap();
@@ -148,37 +149,70 @@ fn main() {
.copy_to_host(&mut msm_host_result_bls12377[..])
.unwrap();
println!("bls12377 result: {:#?}", msm_host_result_bls12377);
#[cfg(feature = "arkworks")]
{
println!("Checking against arkworks...");
let ark_points: Vec<Bn254G1Affine> = points.as_slice().iter().map(|&point| point.to_ark()).collect();
let ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let ark_points: Vec<Bn254G1Affine> = points
.as_slice()
.iter()
.map(|&point| point.to_ark())
.collect();
let ark_scalars: Vec<Bn254Fr> = scalars
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377.as_slice().iter().map(|point| point.to_ark()).collect();
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let ark_points_bls12377: Vec<Bls12377G1Affine> = points_bls12377
.as_slice()
.iter()
.map(|point| point.to_ark())
.collect();
let ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
#[cfg(feature = "profile")]
let start = Instant::now();
let bn254_ark_msm_res = Bn254ArkG1Projective::msm(&ark_points, &ark_scalars).unwrap();
println!("Arkworks Bn254 result: {:#?}", bn254_ark_msm_res);
#[cfg(feature = "profile")]
println!("Ark BN254 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BN254 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
#[cfg(feature = "profile")]
let start = Instant::now();
let bls12377_ark_msm_res = Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
let bls12377_ark_msm_res =
Bls12377ArkG1Projective::msm(&ark_points_bls12377, &ark_scalars_bls12377).unwrap();
println!("Arkworks Bls12377 result: {:#?}", bls12377_ark_msm_res);
#[cfg(feature = "profile")]
println!("Ark BLS12377 MSM on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BLS12377 MSM on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
let bn254_icicle_msm_res_as_ark = msm_host_result[0].to_ark();
let bls12377_icicle_msm_res_as_ark = msm_host_result_bls12377[0].to_ark();
println!("Bn254 MSM is correct: {}", bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark));
println!("Bls12377 MSM is correct: {}", bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark));
println!(
"Bn254 MSM is correct: {}",
bn254_ark_msm_res.eq(&bn254_icicle_msm_res_as_ark)
);
println!(
"Bls12377 MSM is correct: {}",
bls12377_ark_msm_res.eq(&bls12377_icicle_msm_res_as_ark)
);
}
println!("Cleaning up bn254...");
stream
.destroy()

View File

@@ -1,28 +1,18 @@
use icicle_bn254::curve::{
ScalarCfg,
ScalarField,
};
use icicle_bn254::curve::{ScalarCfg, ScalarField};
use icicle_bls12_377::curve::{
ScalarCfg as BLS12377ScalarCfg,
ScalarField as BLS12377ScalarField
};
use icicle_bls12_377::curve::{ScalarCfg as BLS12377ScalarCfg, ScalarField as BLS12377ScalarField};
use icicle_cuda_runtime::{
stream::CudaStream,
memory::HostOrDeviceSlice,
device_context::get_default_device_context
};
use icicle_cuda_runtime::{device_context::DeviceContext, memory::HostOrDeviceSlice, stream::CudaStream};
use icicle_core::{
ntt::{self, NTT},
traits::{GenerateRandom, FieldImpl}
traits::{FieldImpl, GenerateRandom},
};
use icicle_core::traits::ArkConvertible;
use ark_bn254::Fr as Bn254Fr;
use ark_bls12_377::Fr as Bls12377Fr;
use ark_bn254::Fr as Bn254Fr;
use ark_ff::FftField;
use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use ark_std::cmp::{Ord, Ordering};
@@ -45,37 +35,52 @@ fn main() {
println!("Running Icicle Examples: Rust NTT");
let log_size = args.size;
let size = 1 << log_size;
println!("---------------------- NTT size 2^{}={} ------------------------", log_size, size);
println!(
"---------------------- NTT size 2^{}={} ------------------------",
log_size, size
);
// Setting Bn254 points and scalars
println!("Generating random inputs on host for bn254...");
let scalars = HostOrDeviceSlice::Host(ScalarCfg::generate_random(size));
let mut ntt_results: HostOrDeviceSlice<'_, ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
// Setting bls12377 points and scalars
println!("Generating random inputs on host for bls12377...");
let scalars_bls12377 = HostOrDeviceSlice::Host(BLS12377ScalarCfg::generate_random(size));
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> = HostOrDeviceSlice::cuda_malloc(size).unwrap();
let mut ntt_results_bls12377: HostOrDeviceSlice<'_, BLS12377ScalarField> =
HostOrDeviceSlice::cuda_malloc(size).unwrap();
println!("Setting up bn254 Domain...");
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
let ctx = get_default_device_context();
let icicle_omega = <Bn254Fr as FftField>::get_root_of_unity(
size.try_into()
.unwrap(),
)
.unwrap();
let ctx = DeviceContext::default();
ScalarCfg::initialize_domain(ScalarField::from_ark(icicle_omega), &ctx).unwrap();
println!("Configuring bn254 NTT...");
let stream = CudaStream::create().unwrap();
let mut cfg = ntt::get_default_ntt_config::<ScalarField>();
cfg.ctx.stream = &stream;
let mut cfg = ntt::NTTConfig::default();
cfg.ctx
.stream = &stream;
cfg.is_async = true;
println!("Setting up bls12377 Domain...");
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(size.try_into().unwrap()).unwrap();
let icicle_omega = <Bls12377Fr as FftField>::get_root_of_unity(
size.try_into()
.unwrap(),
)
.unwrap();
// reusing ctx from above
BLS12377ScalarCfg::initialize_domain(BLS12377ScalarField::from_ark(icicle_omega), &ctx).unwrap();
println!("Configuring bls12377 NTT...");
let stream_bls12377 = CudaStream::create().unwrap();
let mut cfg_bls12377 = ntt::get_default_ntt_config::<BLS12377ScalarField>();
cfg_bls12377.ctx.stream = &stream_bls12377;
let mut cfg_bls12377 = ntt::NTTConfig::default();
cfg_bls12377
.ctx
.stream = &stream_bls12377;
cfg_bls12377.is_async = true;
println!("Executing bn254 NTT on device...");
@@ -83,14 +88,30 @@ fn main() {
let start = Instant::now();
ntt::ntt(&scalars, ntt::NTTDir::kForward, &cfg, &mut ntt_results).unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BN254 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
println!(
"ICICLE BN254 NTT on size 2^{log_size} took: {} μs",
start
.elapsed()
.as_micros()
);
println!("Executing bls12377 NTT on device...");
#[cfg(feature = "profile")]
let start = Instant::now();
ntt::ntt(&scalars_bls12377, ntt::NTTDir::kForward, &cfg_bls12377, &mut ntt_results_bls12377).unwrap();
ntt::ntt(
&scalars_bls12377,
ntt::NTTDir::kForward,
&cfg_bls12377,
&mut ntt_results_bls12377,
)
.unwrap();
#[cfg(feature = "profile")]
println!("ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs", start.elapsed().as_micros());
println!(
"ICICLE BLS12377 NTT on size 2^{log_size} took: {} μs",
start
.elapsed()
.as_micros()
);
println!("Moving results to host..");
stream
@@ -100,7 +121,7 @@ fn main() {
ntt_results
.copy_to_host(&mut host_bn254_results[..])
.unwrap();
stream_bls12377
.synchronize()
.unwrap();
@@ -108,25 +129,43 @@ fn main() {
ntt_results_bls12377
.copy_to_host(&mut host_bls12377_results[..])
.unwrap();
println!("Checking against arkworks...");
let mut ark_scalars: Vec<Bn254Fr> = scalars.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let mut ark_scalars: Vec<Bn254Fr> = scalars
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let bn254_domain = <Radix2EvaluationDomain<Bn254Fr> as EvaluationDomain<Bn254Fr>>::new(size).unwrap();
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377.as_slice().iter().map(|scalar| scalar.to_ark()).collect();
let mut ark_scalars_bls12377: Vec<Bls12377Fr> = scalars_bls12377
.as_slice()
.iter()
.map(|scalar| scalar.to_ark())
.collect();
let bls12_377_domain = <Radix2EvaluationDomain<Bls12377Fr> as EvaluationDomain<Bls12377Fr>>::new(size).unwrap();
#[cfg(feature = "profile")]
let start = Instant::now();
bn254_domain.fft_in_place(&mut ark_scalars);
#[cfg(feature = "profile")]
println!("Ark BN254 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BN254 NTT on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
#[cfg(feature = "profile")]
let start = Instant::now();
bls12_377_domain.fft_in_place(&mut ark_scalars_bls12377);
#[cfg(feature = "profile")]
println!("Ark BLS12377 NTT on size 2^{log_size} took: {} ms", start.elapsed().as_millis());
println!(
"Ark BLS12377 NTT on size 2^{log_size} took: {} ms",
start
.elapsed()
.as_millis()
);
host_bn254_results
.iter()
@@ -135,7 +174,7 @@ fn main() {
assert_eq!(ark_scalar.cmp(&icicle_scalar.to_ark()), Ordering::Equal);
});
println!("Bn254 NTT is correct");
host_bls12377_results
.iter()
.zip(ark_scalars_bls12377.iter())
@@ -144,7 +183,7 @@ fn main() {
});
println!("Bls12377 NTT is correct");
println!("Cleaning up bn254...");
stream
.destroy()

View File

@@ -14,7 +14,6 @@
#include "primitives/affine.cuh"
#include "primitives/field.cuh"
#include "primitives/projective.cuh"
#include "utils/cuda_utils.cuh"
#include "utils/error_handler.cuh"
#include "utils/mont.cuh"
#include "utils/utils.h"
@@ -240,7 +239,7 @@ namespace msm {
unsigned large_bucket_index = tid / threads_per_bucket;
unsigned bucket_segment_index = tid % threads_per_bucket;
if (tid >= nof_buckets_to_compute * threads_per_bucket) { return; }
if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // dont need
if ((single_bucket_indices[large_bucket_index] & ((1 << c) - 1)) == 0) { // don't need
return; // skip zero buckets
}
unsigned write_bucket_index = bucket_segment_index * nof_buckets_to_compute + large_bucket_index;

View File

@@ -8,7 +8,6 @@
#include "../../primitives/affine.cuh"
#include "../../primitives/field.cuh"
#include "../../primitives/projective.cuh"
#include "../../utils/cuda_utils.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/error_handler.cuh"

View File

@@ -9,6 +9,8 @@
#include "utils/utils.h"
#include "appUtils/ntt/ntt_impl.cuh"
#include <mutex>
namespace ntt {
namespace {
@@ -361,72 +363,88 @@ namespace ntt {
template <typename S>
class Domain
{
static inline int max_size = 0;
static inline int max_log_size = 0;
static inline S* twiddles = nullptr;
static inline std::unordered_map<S, int> coset_index = {};
// Mutex for protecting access to the domain/device container array
static inline std::mutex device_domain_mutex;
// The domain-per-device container - the assumption is that InitDomain is called once per device per program.
static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT
int max_size = 0;
int max_log_size = 0;
S* twiddles = nullptr;
std::unordered_map<S, int> coset_index = {};
S* internal_twiddles = nullptr; // required by mixed-radix NTT
S* basic_twiddles = nullptr; // required by mixed-radix NTT
public:
template <typename U>
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
template <typename U, typename E>
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
};
template <typename S>
static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
template <typename S>
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
{
CHK_INIT_IF_RETURN();
Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
// only generate twiddles if they haven't been generated yet
// please note that this is not thread-safe at all,
// but it's a singleton that is supposed to be initialized once per program lifetime
if (!Domain<S>::twiddles) {
// please note that this offers just basic thread-safety;
// it's assumed (but not enforced) to be a singleton that is
// initialized once per device per program lifetime
if (!domain.twiddles) {
// Mutex is automatically released when lock goes out of scope, even in case of exceptions
std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
// double check locking
if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
bool found_logn = false;
S omega = primitive_root;
unsigned omegas_count = S::get_omegas_count();
for (int i = 0; i < omegas_count; i++) {
omega = S::sqr(omega);
if (!found_logn) {
++Domain<S>::max_log_size;
++domain.max_log_size;
found_logn = omega == S::one();
if (found_logn) break;
}
}
Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
domain.max_size = (int)pow(2, domain.max_log_size);
if (omega != S::one()) {
throw IcicleError(
THROW_ICICLE_ERR(
IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
}
// allocate and calculate twiddles on GPU
// Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
// Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
CHK_IF_RETURN(cudaMallocManaged(&domain.twiddles, (domain.max_size + 1) * sizeof(S)));
CHK_IF_RETURN(generate_external_twiddles_generic(
primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
Domain<S>::max_log_size, ctx.stream));
primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
const bool is_map_only_powers_of_primitive_root = true;
if (is_map_only_powers_of_primitive_root) {
// populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
// etc.)
Domain<S>::coset_index[S::one()] = 0;
for (int i = 0; i < Domain<S>::max_log_size; ++i) {
domain.coset_index[S::one()] = 0;
for (int i = 0; i < domain.max_log_size; ++i) {
const int index = (int)pow(2, i);
Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
domain.coset_index[domain.twiddles[index]] = index;
}
} else {
// populate all values
for (int i = 0; i < Domain<S>::max_size; ++i) {
Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
for (int i = 0; i < domain.max_size; ++i) {
domain.coset_index[domain.twiddles[i]] = i;
}
}
}
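
On the Rust side, the per-device domain container above means InitDomain runs once per GPU. A minimal sketch of that pattern, assuming this commit's bn254 bindings and device API (root-of-unity computation omitted):

use icicle_bn254::curve::{ScalarCfg, ScalarField};
use icicle_core::ntt::NTT;
use icicle_cuda_runtime::device::{get_device_count, set_device};
use icicle_cuda_runtime::device_context::DeviceContext;

fn init_domain_on_all_devices(primitive_root: ScalarField) {
    for device_id in 0..get_device_count().unwrap() {
        // Bind this host thread to the device before touching it.
        set_device(device_id).unwrap();
        // Each call populates domains_for_devices[device_id] on the C++ side.
        let ctx = DeviceContext::default_for_device(device_id);
        ScalarCfg::initialize_domain(primitive_root, &ctx).unwrap();
    }
}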
@@ -527,7 +545,9 @@ namespace ntt {
{
CHK_INIT_IF_RETURN();
if (size > Domain<S>::max_size) {
Domain<S>& domain = domains_for_devices<S>[config.ctx.device_id];
if (size > domain.max_size) {
std::ostringstream oss;
oss << "NTT size=" << size
<< " is too large for the domain. Consider generating your domain with a higher order root of unity.\n";
@@ -565,7 +585,7 @@ namespace ntt {
S* coset = nullptr;
int coset_index = 0;
try {
coset_index = Domain<S>::coset_index.at(config.coset_gen);
coset_index = domain.coset_index.at(config.coset_gen);
} catch (...) {
// if coset index is not found in the subgroup, compute coset powers on CPU and move them to device
std::vector<S> h_coset;
@@ -584,12 +604,12 @@ namespace ntt {
if (is_radix2_algorithm) {
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, Domain<S>::twiddles, size, Domain<S>::max_size, batch_size, is_inverse, config.ordering,
coset, coset_index, stream));
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
} else {
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
Domain<S>::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream));
d_input, d_output, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, size, domain.max_log_size,
batch_size, is_inverse, config.ordering, coset, coset_index, stream));
}
if (!are_outputs_on_device)

View File

@@ -1,32 +0,0 @@
#pragma once
#include <cuda_runtime.h>
struct cuda_ctx {
int device_id;
cudaMemPool_t mempool;
cudaStream_t stream;
cuda_ctx(int gpu_id)
{
gpu_id = gpu_id;
cudaMemPoolProps pool_props;
pool_props.allocType = cudaMemAllocationTypePinned;
pool_props.handleTypes = cudaMemHandleTypePosixFileDescriptor;
pool_props.location.type = cudaMemLocationTypeDevice;
pool_props.location.id = device_id;
cudaMemPoolCreate(&mempool, &pool_props);
cudaStreamCreate(&stream);
}
void set_device() { cudaSetDevice(device_id); }
void sync_stream() { cudaStreamSynchronize(stream); }
void malloc(void* ptr, size_t bytesize) { cudaMallocFromPoolAsync(&ptr, bytesize, mempool, stream); }
void free(void* ptr) { cudaFreeAsync(ptr, stream); }
};
// -- Proposed Function Tops --------------------------------------------------
// ----------------------------------------------------------------------------

View File

@@ -6,6 +6,8 @@
namespace device_context {
constexpr std::size_t MAX_DEVICES = 32;
/**
* Properties of the device used in icicle functions.
*/
@@ -18,7 +20,7 @@ namespace device_context {
/**
* Return default device context that corresponds to using the default stream of the first GPU
*/
inline DeviceContext get_default_device_context()
inline DeviceContext get_default_device_context() // TODO: naming convention ?
{
static cudaStream_t default_stream = (cudaStream_t)0;
return DeviceContext{

View File

@@ -1,3 +1,4 @@
// TODO: remove this file; it seems to work without it
// based on https://leimao.github.io/blog/CUDA-Shared-Memory-Templated-Kernel/
// may be outdated, but it only worked like that

View File

@@ -2,21 +2,23 @@
name = "icicle-core"
version = "1.2.0"
edition = "2021"
authors = [ "Ingonyama" ]
authors = ["Ingonyama"]
description = "A library for GPU ZK acceleration by Ingonyama"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"
[dependencies]
icicle-cuda-runtime = { path = "../icicle-cuda-runtime" }
ark-ff = { version = "0.4.0", optional = true }
ark-ec = { version = "0.4.0", optional = true, features = [ "parallel" ] }
ark-ec = { version = "0.4.0", optional = true, features = ["parallel"] }
ark-poly = { version = "0.4.0", optional = true }
ark-std = { version = "0.4.0", optional = true }
rayon = "1.8.1"
[features]
default = []
arkworks = ["ark-ff", "ark-ec", "ark-poly", "ark-std"]

View File

@@ -284,7 +284,7 @@ macro_rules! impl_curve {
points.as_mut_ptr(),
points.len(),
is_into,
&get_default_device_context() as *const _ as *const DeviceContext,
&DeviceContext::default() as *const _ as *const DeviceContext,
)
}
}
@@ -298,7 +298,7 @@ macro_rules! impl_curve {
points.as_mut_ptr(),
points.len(),
is_into,
&get_default_device_context() as *const _ as *const DeviceContext,
&DeviceContext::default() as *const _ as *const DeviceContext,
)
}
}

View File

@@ -15,6 +15,9 @@ pub struct Field<const NUM_LIMBS: usize, F: FieldConfig> {
p: PhantomData<F>,
}
unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Send for Field<NUM_LIMBS, F> {}
unsafe impl<const NUM_LIMBS: usize, F: FieldConfig> Sync for Field<NUM_LIMBS, F> {}
impl<const NUM_LIMBS: usize, F: FieldConfig> Display for Field<NUM_LIMBS, F> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "0x")?;
@@ -165,7 +168,7 @@ macro_rules! impl_scalar_field {
impl_field!($num_limbs, $field_name, $field_cfg, $ark_equiv);
mod $field_prefix_ident {
use crate::curve::{get_default_device_context, $field_name, CudaError, DeviceContext, HostOrDeviceSlice};
use crate::curve::{$field_name, CudaError, DeviceContext, HostOrDeviceSlice};
extern "C" {
#[link_name = concat!($field_prefix, "GenerateScalars")]
@@ -189,7 +192,7 @@ macro_rules! impl_scalar_field {
scalars.as_mut_ptr(),
scalars.len(),
is_into,
&get_default_device_context() as *const _ as *const DeviceContext,
&DeviceContext::default() as *const _ as *const DeviceContext,
)
}
}

View File

@@ -1,6 +1,6 @@
use crate::curve::{Affine, Curve, Projective};
use crate::error::IcicleResult;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
#[cfg(feature = "arkworks")]
@@ -59,6 +59,33 @@ pub struct MSMConfig<'a> {
pub is_async: bool,
}
impl<'a> Default for MSMConfig<'a> {
fn default() -> Self {
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl<'a> MSMConfig<'a> {
pub fn default_for_device(device_id: usize) -> Self {
Self {
ctx: DeviceContext::default_for_device(device_id),
points_size: 0,
precompute_factor: 1,
c: 0,
bitsize: 0,
large_bucket_factor: 10,
batch_size: 1,
are_scalars_on_device: false,
are_scalars_montgomery_form: false,
are_points_on_device: false,
are_points_montgomery_form: false,
are_results_on_device: false,
is_big_triangle: false,
is_async: false,
}
}
}
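
With Default and default_for_device in place, call sites no longer go through the removed get_default_msm_config. A hedged usage sketch with the bn254 types used elsewhere in this commit (the caller is assumed to have bound the thread to device_id via set_device first, since cuda_malloc allocates on the current device):

use icicle_bn254::curve::{CurveCfg, G1Projective, ScalarCfg};
use icicle_core::curve::Curve;
use icicle_core::msm::{self, MSMConfig};
use icicle_core::traits::GenerateRandom;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn msm_on_device(device_id: usize) {
    let size = 1 << 10;
    let points = HostOrDeviceSlice::Host(CurveCfg::generate_random_affine_points(size));
    let scalars = HostOrDeviceSlice::Host(ScalarCfg::generate_random(size));
    let mut results: HostOrDeviceSlice<'_, G1Projective> = HostOrDeviceSlice::cuda_malloc(1).unwrap();
    // The device id travels inside cfg.ctx instead of a global default.
    let cfg = MSMConfig::default_for_device(device_id);
    msm::msm(&scalars, &points, &cfg, &mut results).unwrap();
}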
#[doc(hidden)]
pub trait MSM<C: Curve> {
fn msm_unchecked(
@@ -67,8 +94,6 @@ pub trait MSM<C: Curve> {
cfg: &MSMConfig,
results: &mut HostOrDeviceSlice<Projective<C>>,
) -> IcicleResult<()>;
fn get_default_msm_config() -> MSMConfig<'static>;
}
/// Computes the multi-scalar multiplication, or MSM: `s1*P1 + s2*P2 + ... + sn*Pn`, or a batch of several MSMs.
@@ -113,11 +138,6 @@ pub fn msm<C: Curve + MSM<C>>(
C::msm_unchecked(scalars, points, &local_cfg, results)
}
/// Returns [MSM config](MSMConfig) struct populated with default values.
pub fn get_default_msm_config<C: Curve + MSM<C>>() -> MSMConfig<'static> {
C::get_default_msm_config()
}
#[macro_export]
macro_rules! impl_msm {
(
@@ -161,10 +181,6 @@ macro_rules! impl_msm {
.wrap()
}
}
fn get_default_msm_config() -> MSMConfig<'static> {
unsafe { $curve_prefix_indent::default_msm_config() }
}
}
};
}

View File

@@ -1,8 +1,11 @@
use super::{get_default_msm_config, msm, MSM};
use crate::curve::{Affine, Curve, Projective};
use crate::msm::{msm, MSMConfig, MSM};
use crate::traits::{FieldImpl, GenerateRandom};
use icicle_cuda_runtime::device::{get_device_count, set_device};
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
#[cfg(feature = "arkworks")]
use crate::traits::ArkConvertible;
@@ -19,61 +22,73 @@ where
C::ScalarField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as ArkCurveConfig>::ScalarField>,
C::BaseField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as ArkCurveConfig>::BaseField>,
{
let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
for test_size in test_sizes {
let points = C::generate_random_affine_points(test_size);
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
let points_ark: Vec<_> = points
.iter()
.map(|x| x.to_ark())
.collect();
let scalars_ark: Vec<_> = scalars
.iter()
.map(|x| x.to_ark())
.collect();
// if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
// (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
let device_count = get_device_count().unwrap();
(0..device_count) // TODO: this is proto-loadbalancer
.into_par_iter()
.for_each(move |device_id| {
//TODO: currently supported multi-GPU workflow:
// 1) User starts a child host thread from the parent host thread
// 2) Child calls set_device once with the selected device_id (valid ids are 0..device_count; 0 is the default device)
// 3) Child performs all operations without changing the device on that thread
// 4) If necessary, child exports results to the parent host thread (see the sketch after this test)
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
let stream = CudaStream::create().unwrap();
scalars_d
.copy_from_host_async(&scalars_mont, &stream)
.unwrap();
set_device(device_id).unwrap();
let test_sizes = [4, 8, 16, 32, 64, 128, 256, 1000, 1 << 18];
let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap();
for test_size in test_sizes {
let points = C::generate_random_affine_points(test_size);
let scalars = <C::ScalarField as FieldImpl>::Config::generate_random(test_size);
let points_ark: Vec<_> = points
.iter()
.map(|x| x.to_ark())
.collect();
let scalars_ark: Vec<_> = scalars
.iter()
.map(|x| x.to_ark())
.collect();
// if we simply transmute arkworks types, we'll get scalars or points in Montgomery format
// (just beware the possible extra flag in affine point types, can't transmute ark Affine because of that)
let scalars_mont = unsafe { &*(&scalars_ark[..] as *const _ as *const [C::ScalarField]) };
let mut cfg = get_default_msm_config::<C>();
cfg.ctx
.stream = &stream;
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
// need to make sure that scalars_d weren't mutated by the previous call
let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
scalars_d
.copy_to_host_async(&mut scalars_mont_after, &stream)
.unwrap();
assert_eq!(scalars_mont, scalars_mont_after);
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size).unwrap();
let stream = CudaStream::create().unwrap();
scalars_d
.copy_from_host_async(&scalars_mont, &stream)
.unwrap();
let mut msm_host_result = vec![Projective::<C>::zero(); 1];
msm_results
.copy_to_host(&mut msm_host_result[..])
.unwrap();
stream
.synchronize()
.unwrap();
stream
.destroy()
.unwrap();
let mut cfg = MSMConfig::default_for_device(device_id);
cfg.ctx
.stream = &stream;
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
msm(&scalars_d, &HostOrDeviceSlice::on_host(points), &cfg, &mut msm_results).unwrap();
// need to make sure that scalars_d weren't mutated by the previous call
let mut scalars_mont_after = vec![C::ScalarField::zero(); test_size];
scalars_d
.copy_to_host_async(&mut scalars_mont_after, &stream)
.unwrap();
assert_eq!(scalars_mont, scalars_mont_after);
let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
.to_ark()
.into();
assert!(msm_res_affine.is_on_curve());
assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
}
let mut msm_host_result = vec![Projective::<C>::zero(); 1];
msm_results
.copy_to_host(&mut msm_host_result[..])
.unwrap();
stream
.synchronize()
.unwrap();
stream
.destroy()
.unwrap();
let msm_result_ark: ark_ec::models::short_weierstrass::Projective<C::ArkSWConfig> =
VariableBaseMSM::msm(&points_ark, &scalars_ark).unwrap();
let msm_res_affine: ark_ec::short_weierstrass::Affine<C::ArkSWConfig> = msm_host_result[0]
.to_ark()
.into();
assert!(msm_res_affine.is_on_curve());
assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
}
});
}
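
Distilled, the four-step workflow from the comments above looks like this (a sketch, not part of the diff; names come from this commit's icicle_cuda_runtime and the new rayon dependency, and the per-device body is schematic):

use icicle_cuda_runtime::device::{get_device_count, set_device};
use rayon::iter::{IntoParallelIterator, ParallelIterator};

fn on_each_device(work: impl Fn(usize) + Send + Sync) {
    let device_count = get_device_count().unwrap();
    (0..device_count)
        .into_par_iter() // 1) one host task per device
        .for_each(|device_id| {
            set_device(device_id).unwrap(); // 2) bind the thread once
            work(device_id); // 3) all CUDA work targets device_id;
                             // 4) the body copies results back to host
        });
}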
pub fn check_msm_batch<C: Curve + MSM<C>>()
@@ -104,7 +119,7 @@ where
.copy_from_host_async(&points_cloned, &stream)
.unwrap();
let mut cfg = get_default_msm_config::<C>();
let mut cfg = MSMConfig::default();
cfg.ctx
.stream = &stream;
cfg.is_async = true;
@@ -181,7 +196,7 @@ where
let mut msm_results = HostOrDeviceSlice::on_host(vec![Projective::<C>::zero(); batch_size]);
let mut cfg = get_default_msm_config::<C>();
let mut cfg = MSMConfig::default();
if test_size < test_threshold {
cfg.bitsize = 1;
}

View File

@@ -1,4 +1,4 @@
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use crate::{error::IcicleResult, traits::FieldImpl};
@@ -89,11 +89,16 @@ pub struct NTTConfig<'a, S> {
pub ntt_algorithm: NttAlgorithm,
}
impl<'a, S: FieldImpl> Default for NTTConfig<'a, S> {
fn default() -> Self {
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl<'a, S: FieldImpl> NTTConfig<'a, S> {
pub fn default_config() -> Self {
let ctx = get_default_device_context();
pub fn default_for_device(device_id: usize) -> Self {
NTTConfig {
ctx,
ctx: DeviceContext::default_for_device(device_id),
coset_gen: S::one(),
batch_size: 1,
ordering: Ordering::kNN,
@@ -114,7 +119,6 @@ pub trait NTT<F: FieldImpl> {
output: &mut HostOrDeviceSlice<F>,
) -> IcicleResult<()>;
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
fn get_default_ntt_config() -> NTTConfig<'static, F>;
}
/// Computes the NTT, or a batch of several NTTs.
@@ -169,15 +173,6 @@ where
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
}
/// Returns [NTT config](NTTConfig) struct populated with default values.
pub fn get_default_ntt_config<F>() -> NTTConfig<'static, F>
where
F: FieldImpl,
<F as FieldImpl>::Config: NTT<F>,
{
<<F as FieldImpl>::Config as NTT<F>>::get_default_ntt_config()
}
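
Migration for callers of the removed helper is mechanical; a minimal sketch with bn254's ScalarField, as used elsewhere in this commit:

use icicle_bn254::curve::ScalarField;
use icicle_core::ntt::NTTConfig;

fn configs() {
    // Before this commit: ntt::get_default_ntt_config::<ScalarField>()
    let cfg = NTTConfig::<ScalarField>::default(); // targets DEFAULT_DEVICE_ID
    // Or pin the config to a specific GPU:
    let cfg_gpu1 = NTTConfig::<ScalarField>::default_for_device(1);
    let _ = (cfg, cfg_gpu1);
}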
#[macro_export]
macro_rules! impl_ntt {
(
@@ -187,7 +182,7 @@ macro_rules! impl_ntt {
$field_config:ident
) => {
mod $field_prefix_ident {
use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir};
use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir, DEFAULT_DEVICE_ID};
extern "C" {
#[link_name = concat!($field_prefix, "NTTCuda")]
@@ -226,10 +221,6 @@ macro_rules! impl_ntt {
fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
}
fn get_default_ntt_config() -> NTTConfig<'static, $field> {
NTTConfig::<$field>::default_config()
}
}
};
}
@@ -244,31 +235,31 @@ macro_rules! impl_ntt_tests {
#[test]
fn test_ntt() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
check_ntt::<$field>()
}
#[test]
fn test_ntt_coset_from_subgroup() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
check_ntt_coset_from_subgroup::<$field>()
}
#[test]
fn test_ntt_arbitrary_coset() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
check_ntt_arbitrary_coset::<$field>()
}
#[test]
fn test_ntt_batch() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
check_ntt_batch::<$field>()
}
#[test]
fn test_ntt_device_async() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
// init_domain in this test is performed per device
check_ntt_device_async::<$field>()
}
};

View File

@@ -1,23 +1,27 @@
use ark_ff::{FftField, Field as ArkField, One};
use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
use ark_std::{ops::Neg, test_rng, UniformRand};
use icicle_cuda_runtime::device_context::get_default_device_context;
use icicle_cuda_runtime::device::get_device_count;
use icicle_cuda_runtime::device::set_device;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use crate::{
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
ntt::{initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
traits::{ArkConvertible, FieldImpl, GenerateRandom},
};
use super::NTTConfig;
use super::NTT;
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64)
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize)
where
F::ArkEquivalent: FftField,
<F as FieldImpl>::Config: NTT<F>,
{
let ctx = get_default_device_context();
let ctx = DeviceContext::default_for_device(device_id);
let ark_rou = F::ArkEquivalent::get_root_of_unity(max_size).unwrap();
initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
}
@@ -61,7 +65,7 @@ where
let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) };
let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec());
let mut config = get_default_ntt_config();
let mut config = NTTConfig::default();
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
@@ -107,7 +111,7 @@ where
.collect::<Vec<F::ArkEquivalent>>();
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
let mut config = get_default_ntt_config();
let mut config = NTTConfig::default();
config.ordering = Ordering::kNR;
config.ntt_algorithm = alg;
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
@@ -188,7 +192,7 @@ where
.map(|v| F::ArkEquivalent::from_random_bytes(&v.to_bytes_le()).unwrap())
.collect::<Vec<F::ArkEquivalent>>();
let mut config = get_default_ntt_config();
let mut config = NTTConfig::default();
config.coset_gen = F::from_ark(coset_gen);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ordering = Ordering::kNR;
@@ -229,7 +233,7 @@ where
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
let mut config = get_default_ntt_config();
let mut config = NTTConfig::default();
for batch_size in batch_sizes {
let scalars = HostOrDeviceSlice::on_host(F::Config::generate_random(test_size * batch_size));
@@ -279,55 +283,65 @@ where
F::ArkEquivalent: FftField,
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
{
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
let stream = CudaStream::create().unwrap();
let mut config = get_default_ntt_config();
for batch_size in batch_sizes {
let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
.iter()
.map(|x| x.to_ark())
.sum();
let mut scalars_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
scalars_d
.copy_from_host_async(&scalars_h, &stream)
.unwrap();
let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
let device_count = get_device_count().unwrap();
for coset_gen in coset_generators {
for ordering in [Ordering::kNN, Ordering::kRR] {
config.coset_gen = coset_gen;
config.ordering = ordering;
config.batch_size = batch_size as i32;
config.is_async = true;
config
.ctx
.stream = &stream;
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
(0..device_count)
.into_par_iter()
.for_each(move |device_id| {
set_device(device_id).unwrap();
init_domain::<F>(1 << 16, device_id); // init domain per device
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
let mut config = NTTConfig::default_for_device(device_id);
let stream = config
.ctx
.stream;
for batch_size in batch_sizes {
let scalars_h: Vec<F> = F::Config::generate_random(test_size * batch_size);
let sum_of_coeffs: F::ArkEquivalent = scalars_h[..test_size]
.iter()
.map(|x| x.to_ark())
.sum();
let mut scalars_d = HostOrDeviceSlice::cuda_malloc(test_size * batch_size).unwrap();
scalars_d
.copy_from_host(&scalars_h)
.unwrap();
let mut ntt_out_d = HostOrDeviceSlice::cuda_malloc_async(test_size * batch_size, &stream).unwrap();
for coset_gen in coset_generators {
for ordering in [Ordering::kNN, Ordering::kRR] {
config.coset_gen = coset_gen;
config.ordering = ordering;
config.batch_size = batch_size as i32;
config.is_async = true;
config
.ctx
.stream = &stream;
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
}
}
}
}
}
}
}
}
});
}

View File

@@ -2,7 +2,7 @@
pub mod tests;
use icicle_cuda_runtime::{
device_context::{get_default_device_context, DeviceContext},
device_context::{DeviceContext, DEFAULT_DEVICE_ID},
memory::HostOrDeviceSlice,
};
@@ -60,9 +60,14 @@ pub struct PoseidonConfig<'a> {
impl<'a> Default for PoseidonConfig<'a> {
fn default() -> Self {
let ctx = get_default_device_context();
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl<'a> PoseidonConfig<'a> {
pub fn default_for_device(device_id: usize) -> Self {
Self {
ctx,
ctx: DeviceContext::default_for_device(device_id),
are_inputs_on_device: false,
are_outputs_on_device: false,
input_is_a_state: false,

View File

@@ -1,5 +1,5 @@
use crate::traits::FieldImpl;
use icicle_cuda_runtime::device_context::get_default_device_context;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use std::io::Read;
@@ -15,7 +15,7 @@ pub fn init_poseidon<'a, F: FieldImpl>(arity: u32) -> PoseidonConstants<'a, F>
where
<F as FieldImpl>::Config: Poseidon<F>,
{
let ctx = get_default_device_context();
let ctx = DeviceContext::default();
load_optimized_poseidon_constants::<F>(arity, &ctx).unwrap()
}
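
init_poseidon above still targets the default device. A hypothetical per-device analogue under this commit's API (the function name is illustrative and not part of the commit; it assumes the same imports as the file above):

use icicle_cuda_runtime::device_context::DeviceContext;

pub fn init_poseidon_on_device<'a, F: FieldImpl>(arity: u32, device_id: usize) -> PoseidonConstants<'a, F>
where
    <F as FieldImpl>::Config: Poseidon<F>,
{
    // Same as init_poseidon, but the context is bound to device_id.
    let ctx = DeviceContext::default_for_device(device_id);
    load_optimized_poseidon_constants::<F>(arity, &ctx).unwrap()
}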
@@ -71,7 +71,7 @@ where
let full_rounds_half = 4;
let ctx = get_default_device_context();
let ctx = DeviceContext::default();
let cargo_manifest_dir = env!("CARGO_MANIFEST_DIR");
let constants_file = PathBuf::from(cargo_manifest_dir)
.join("tests")

View File

@@ -18,7 +18,9 @@ pub trait FieldConfig: Debug + PartialEq + Copy + Clone {
type ArkField: ArkField;
}
pub trait FieldImpl: Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> {
pub trait FieldImpl:
Display + Debug + PartialEq + Copy + Clone + Into<Self::Repr> + From<Self::Repr> + Send + Sync
{
#[doc(hidden)]
type Config: FieldConfig;
type Repr;

View File

@@ -1,5 +1,5 @@
use icicle_cuda_runtime::{
device_context::{get_default_device_context, DeviceContext},
device_context::{DeviceContext, DEFAULT_DEVICE_ID},
memory::HostOrDeviceSlice,
};
@@ -28,9 +28,13 @@ pub struct TreeBuilderConfig<'a> {
impl<'a> Default for TreeBuilderConfig<'a> {
fn default() -> Self {
let ctx = get_default_device_context();
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl<'a> TreeBuilderConfig<'a> {
fn default_for_device(device_id: usize) -> Self {
Self {
ctx,
ctx: DeviceContext::default_for_device(device_id),
keep_rows: 0,
are_inputs_on_device: false,
is_async: false,

View File

@@ -45,7 +45,9 @@ fn main() {
.must_use_type("cudaError")
// device management
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
.allowlist_function("cudaGetDevice")
.allowlist_function("cudaSetDevice")
.allowlist_function("cudaGetDeviceCount")
// error handling
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__ERROR.html
.allowlist_function("cudaGetLastError")
@@ -69,6 +71,7 @@ fn main() {
.allowlist_function("cudaMemcpyAsync")
.allowlist_function("cudaMemset")
.allowlist_function("cudaMemsetAsync")
.allowlist_function("cudaDeviceGetDefaultMemPool")
.rustified_enum("cudaMemcpyKind")
// Stream Ordered Memory Allocator
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html

View File

@@ -0,0 +1,18 @@
use crate::{
bindings::{cudaGetDevice, cudaGetDeviceCount, cudaSetDevice},
error::{CudaResult, CudaResultWrap},
};
pub fn set_device(device_id: usize) -> CudaResult<()> {
unsafe { cudaSetDevice(device_id as i32) }.wrap()
}
pub fn get_device_count() -> CudaResult<usize> {
let mut count = 0;
unsafe { cudaGetDeviceCount(&mut count) }.wrap_value(count as usize)
}
pub fn get_device() -> CudaResult<usize> {
let mut device_id = 0;
unsafe { cudaGetDevice(&mut device_id) }.wrap_value(device_id as usize)
}
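
A minimal usage sketch for these wrappers (same crate paths as above):

use icicle_cuda_runtime::device::{get_device, get_device_count, set_device};

fn visit_all_devices() {
    // Number of CUDA devices visible to this process.
    let count = get_device_count().unwrap();
    for device_id in 0..count {
        // Makes device_id current for this host thread; subsequent
        // CUDA calls on the thread target that device.
        set_device(device_id).unwrap();
        assert_eq!(get_device().unwrap(), device_id);
    }
}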

View File

@@ -1,27 +1,47 @@
use crate::memory::CudaMemPool;
use crate::stream::CudaStream;
/// Properties of the device used in icicle functions.
pub const DEFAULT_DEVICE_ID: usize = 0;
use crate::device::get_device;
/// Properties of the device used in Icicle functions.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct DeviceContext<'a> {
/// Stream to use. Default value: 0.
/// Stream to use. Default value: 0. //TODO: multiple streams per device ?
pub stream: &'a CudaStream, // Assuming the type is provided by a CUDA binding crate
/// Index of the currently used GPU. Default value: 0.
pub device_id: usize,
/// Mempool to use. Default value: 0.
/// Mempool to use. Default value: 0. //TODO: multiple mempools per device ?
pub mempool: CudaMemPool, // Assuming the type is provided by a CUDA binding crate
}
pub fn get_default_device_context() -> DeviceContext<'static> {
static default_stream: CudaStream = CudaStream {
handle: std::ptr::null_mut(),
};
DeviceContext {
stream: &default_stream,
device_id: 0,
mempool: 0,
impl Default for DeviceContext<'_> {
fn default() -> Self {
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl DeviceContext<'_> {
/// Returns a default context for the given device_id
pub fn default_for_device(device_id: usize) -> DeviceContext<'static> {
static default_stream: CudaStream = CudaStream {
handle: std::ptr::null_mut(),
};
DeviceContext {
stream: &default_stream,
device_id,
mempool: std::ptr::null_mut(),
}
}
}
pub fn check_device(device_id: i32) {
match device_id == get_device().unwrap() as i32 {
true => (),
false => panic!("Attempt to use on a different device"),
}
}
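
A short sketch of the two entry points (assuming the declarations above; the mempool stays null and the stream is the default one):

use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID};

fn contexts() {
    // Context for the default device (id 0).
    let ctx0 = DeviceContext::default();
    assert_eq!(ctx0.device_id, DEFAULT_DEVICE_ID);
    // Context pinned to a specific GPU, e.g. the second one.
    let ctx1 = DeviceContext::default_for_device(1);
    assert_eq!(ctx1.device_id, 1);
}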

View File

@@ -3,6 +3,7 @@
#[allow(dead_code)]
mod bindings;
pub mod device;
pub mod device_context;
pub mod error;
pub mod memory;

View File

@@ -1,4 +1,8 @@
use crate::bindings::{cudaFree, cudaMalloc, cudaMallocAsync, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
use crate::bindings::{
cudaFree, cudaMalloc, cudaMallocAsync, cudaMemPool_t, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind,
};
use crate::device::get_device;
use crate::device_context::check_device;
use crate::error::{CudaError, CudaResult, CudaResultWrap};
use crate::stream::CudaStream;
use std::mem::{size_of, MaybeUninit};
@@ -8,60 +12,69 @@ use std::slice::from_raw_parts_mut;
pub enum HostOrDeviceSlice<'a, T> {
Host(Vec<T>),
Device(&'a mut [T]),
Device(&'a mut [T], i32),
}
impl<'a, T> HostOrDeviceSlice<'a, T> {
// Function to get the device_id for Device variant
pub fn get_device_id(&self) -> Option<i32> {
match self {
HostOrDeviceSlice::Device(_, device_id) => Some(*device_id),
HostOrDeviceSlice::Host(_) => None,
}
}
pub fn len(&self) -> usize {
match self {
Self::Device(s) => s.len(),
Self::Device(s, _) => s.len(),
Self::Host(v) => v.len(),
}
}
pub fn is_empty(&self) -> bool {
match self {
Self::Device(s) => s.is_empty(),
Self::Device(s, _) => s.is_empty(),
Self::Host(v) => v.is_empty(),
}
}
pub fn is_on_device(&self) -> bool {
match self {
Self::Device(_) => true,
Self::Device(_, _) => true,
Self::Host(_) => false,
}
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
match self {
Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
Self::Host(v) => v.as_mut_slice(),
}
}
pub fn as_slice(&self) -> &[T] {
match self {
Self::Device(_) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
Self::Device(_, _) => panic!("Use copy_to_host and copy_to_host_async to move device data to a slice"),
Self::Host(v) => v.as_slice(),
}
}
pub fn as_ptr(&self) -> *const T {
match self {
Self::Device(s) => s.as_ptr(),
Self::Device(s, _) => s.as_ptr(),
Self::Host(v) => v.as_ptr(),
}
}
pub fn as_mut_ptr(&mut self) -> *mut T {
match self {
Self::Device(s) => s.as_mut_ptr(),
Self::Device(s, _) => s.as_mut_ptr(),
Self::Host(v) => v.as_mut_ptr(),
}
}
pub fn on_host(src: Vec<T>) -> Self {
//TODO: HostOrDeviceSlice on_host() with slice input without actually copying the data
Self::Host(src)
}
@@ -70,16 +83,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
.checked_mul(size_of::<T>())
.unwrap_or(0);
if size == 0 {
return Err(CudaError::cudaErrorMemoryAllocation);
return Err(CudaError::cudaErrorMemoryAllocation); //TODO: only CUDA backend should return CudaError
}
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
unsafe {
cudaMalloc(device_ptr.as_mut_ptr(), size).wrap()?;
Ok(Self::Device(from_raw_parts_mut(
device_ptr.assume_init() as *mut T,
count,
)))
Ok(Self::Device(
from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
get_device().unwrap() as i32,
))
}
}
@@ -94,16 +107,16 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
unsafe {
cudaMallocAsync(device_ptr.as_mut_ptr(), size, stream.handle as *mut _ as *mut _).wrap()?;
Ok(Self::Device(from_raw_parts_mut(
device_ptr.assume_init() as *mut T,
count,
)))
Ok(Self::Device(
from_raw_parts_mut(device_ptr.assume_init() as *mut T, count),
get_device().unwrap() as i32,
))
}
}
pub fn copy_from_host(&mut self, val: &[T]) -> CudaResult<()> {
match self {
Self::Device(_) => {}
Self::Device(_, device_id) => check_device(*device_id),
Self::Host(_) => panic!("Need device memory to copy into, and not host"),
};
assert!(
@@ -127,7 +140,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
pub fn copy_to_host(&self, val: &mut [T]) -> CudaResult<()> {
match self {
Self::Device(_) => {}
Self::Device(_, device_id) => check_device(*device_id),
Self::Host(_) => panic!("Need device memory to copy from, and not host"),
};
assert!(
@@ -151,7 +164,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
pub fn copy_from_host_async(&mut self, val: &[T], stream: &CudaStream) -> CudaResult<()> {
match self {
Self::Device(_) => {}
Self::Device(_, device_id) => check_device(*device_id),
Self::Host(_) => panic!("Need device memory to copy into, and not host"),
};
assert!(
@@ -176,7 +189,7 @@ impl<'a, T> HostOrDeviceSlice<'a, T> {
pub fn copy_to_host_async(&self, val: &mut [T], stream: &CudaStream) -> CudaResult<()> {
match self {
Self::Device(_) => {}
Self::Device(_, device_id) => check_device(*device_id),
Self::Host(_) => panic!("Need device memory to copy from, and not host"),
};
assert!(
@@ -209,7 +222,7 @@ macro_rules! impl_index {
fn index(&self, index: $t) -> &Self::Output {
match self {
Self::Device(s) => s.index(index),
Self::Device(s, _) => s.index(index),
Self::Host(v) => v.index(index),
}
}
@@ -219,7 +232,7 @@ macro_rules! impl_index {
{
fn index_mut(&mut self, index: $t) -> &mut Self::Output {
match self {
Self::Device(s) => s.index_mut(index),
Self::Device(s, _) => s.index_mut(index),
Self::Host(v) => v.index_mut(index),
}
}
@@ -239,7 +252,8 @@ impl_index! {
impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
fn drop(&mut self) {
match self {
Self::Device(s) => {
Self::Device(s, device_id) => {
check_device(*device_id);
if s.is_empty() {
return;
}
@@ -256,4 +270,4 @@ impl<'a, T> Drop for HostOrDeviceSlice<'a, T> {
}
#[allow(non_camel_case_types)]
pub type CudaMemPool = usize; // This is a placeholder, TODO: actually make this into a proper CUDA wrapper
pub type CudaMemPool = cudaMemPool_t;
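
The device id recorded in the Device variant makes cross-device misuse fail fast; a minimal sketch of the new invariant:

use icicle_cuda_runtime::device::set_device;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn device_tagged_alloc() {
    set_device(0).unwrap();
    let mut d: HostOrDeviceSlice<'_, u32> = HostOrDeviceSlice::cuda_malloc(16).unwrap();
    // The allocation remembers device 0; Host slices carry no id.
    assert_eq!(d.get_device_id(), Some(0));
    // copy_* and Drop call check_device, which panics if the calling
    // thread has since switched to a different device.
    d.copy_from_host(&[0u32; 16]).unwrap();
}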

View File

@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
use icicle_core::field::{Field, MontgomeryConvertibleField};
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -7,6 +7,7 @@ use icicle_core::impl_ntt;
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
use icicle_core::traits::IcicleResultWrap;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -17,6 +18,7 @@ impl_ntt!("bw6_761", bw6_761, BaseField, BaseCfg);
#[cfg(test)]
pub(crate) mod tests {
use crate::curve::ScalarField;
use crate::ntt::DEFAULT_DEVICE_ID;
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use std::sync::OnceLock;

View File

@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
use icicle_core::field::{Field, MontgomeryConvertibleField};
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -5,6 +5,7 @@ use icicle_core::impl_ntt;
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
use icicle_core::traits::IcicleResultWrap;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -13,6 +14,7 @@ impl_ntt!("bls12_381", bls12_381, ScalarField, ScalarCfg);
#[cfg(test)]
pub(crate) mod tests {
use crate::curve::ScalarField;
use crate::ntt::DEFAULT_DEVICE_ID;
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use std::sync::OnceLock;

View File

@@ -6,7 +6,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
use icicle_core::field::{Field, MontgomeryConvertibleField};
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
use icicle_core::{impl_curve, impl_field, impl_scalar_field};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -5,6 +5,7 @@ use icicle_core::impl_ntt;
use icicle_core::ntt::{NTTConfig, NTTDir, NTT};
use icicle_core::traits::IcicleResultWrap;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -13,6 +14,7 @@ impl_ntt!("bn254", bn254, ScalarField, ScalarCfg);
#[cfg(test)]
pub(crate) mod tests {
use crate::curve::ScalarField;
use crate::ntt::DEFAULT_DEVICE_ID;
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use std::sync::OnceLock;

View File

@@ -7,7 +7,7 @@ use icicle_core::curve::{Affine, Curve, Projective};
use icicle_core::field::Field;
use icicle_core::traits::FieldConfig;
use icicle_core::{impl_curve, impl_field};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -3,6 +3,7 @@ pub(crate) mod tests {
use crate::curve::ScalarField;
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use icicle_cuda_runtime::device_context::DEFAULT_DEVICE_ID;
use std::sync::OnceLock;
impl_ntt_tests!(ScalarField);