Mirror of https://github.com/pseXperiments/icicle.git
Synced 2026-01-09 15:37:58 -05:00
Update Rust apis (#262)
* fix memory error in single_stage_multi_reduction_kernel (#235)
* refactor
* refactor
* revert
* refactor: clang format
* Update icicle/appUtils/msm/msm.cu
* Added separate device context struct, returned lde
* wip - msm and eq
* added lde to cmake
* Montgomery param added in lde.cu mul function
* fixed on_device for ntt and lde
* CamelCase
* fixed msm_test, int unification, google guide
* wip - ntt crash debugging
* async MSM with a rust wrapper
* wip ntt tests with correctness
* hotfix for correctness > 2^9
* wip on device inout mixing with correctness
* cleanup
* preserving twiddles after first call
* fixed twiddles preserving
* formatting
* removed some printing
* disable ecntt temporarily
* format
* rust fmt
* exclude target from format
* passing ntt after merge
* hotfix for linking issue
* format
* format
* draft of pr comments + correctness restored
* wip refactor + format
* domain wip
* rust format
* Merged feature branch in and Rust MSM correctness
* rust build for correct curve
* Slowdown fixed by passing release flag to cmake
* WIP field and curve
* still wip field and curve
* field and curve in rust 1.0
* Refactored rust into several crates
* Arkworks is now an option, bn254 crate created
* Rust msm and ntt wip
* A version of rust msm done, cuda runtime wrapped
* refactor rust by creating a curve folder
* vec_ops instead of lde for now
* format

---------

Co-authored-by: ImmanuelSegol <3ditds@gmail.com>
Co-authored-by: Vitalii <vitalii@ingonyama.com>
.gitignore (vendored): 1 line added
@@ -15,3 +15,4 @@
 **/.DS_Store
 **/Cargo.lock
 **/icicle/build/
+**/wrappers/rust/icicle-cuda-runtime/src/bindings.rs
Cargo.toml: 49 lines removed (file deleted)
@@ -1,49 +0,0 @@
[package]
name = "icicle"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "An implementation of the Ingonyama CUDA Library"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"

[[bench]]
name = "ntt"
path = "benches/ntt.rs"
harness = false

[[bench]]
name = "msm"
path = "benches/msm.rs"
harness = false

[dependencies]
hex = "*"
ark-std = "0.3.0"
ark-ff = "0.3.0"
ark-poly = "0.3.0"
ark-ec = { version = "0.3.0", features = [ "parallel" ] }
ark-bls12-381 = "0.3.0"
ark-bls12-377 = "0.3.0"
ark-bn254 = "0.3.0"

serde = { version = "1.0", features = ["derive"] }
serde_derive = "1.0"
serde_cbor = "0.11.2"

rustacuda = "0.1"
rustacuda_core = "0.1"
rustacuda_derive = "0.1"

rand = "*" #TODO: move rand and ark dependencies to dev once random scalar/point generation is done "natively"

[build-dependencies]
cc = { version = "1.0", features = ["parallel"] }

[dev-dependencies]
"criterion" = "0.4.0"

[features]
default = ["bls12_381"]
bls12_381 = ["ark-bls12-381/curve"]
g2 = []
benches/msm.rs (file deleted)
@@ -1,54 +0,0 @@
extern crate criterion;

use criterion::{criterion_group, criterion_main, Criterion};

use icicle::test_bls12_381::{commit_batch_bls12_381, generate_random_points_bls12_381, set_up_scalars_bls12_381};
use icicle::utils::*;
#[cfg(feature = "g2")]
use icicle::{commit_batch_g2, field::ExtensionField};

use rustacuda::prelude::*;

const LOG_MSM_SIZES: [usize; 1] = [12];
const BATCH_SIZES: [usize; 2] = [128, 256];

fn bench_msm(c: &mut Criterion) {
    let mut group = c.benchmark_group("MSM");
    for log_msm_size in LOG_MSM_SIZES {
        for batch_size in BATCH_SIZES {
            let msm_size = 1 << log_msm_size;
            let (scalars, _, _) = set_up_scalars_bls12_381(msm_size, 0, false);
            let batch_scalars = vec![scalars; batch_size].concat();
            let mut d_scalars = DeviceBuffer::from_slice(&batch_scalars[..]).unwrap();

            let points = generate_random_points_bls12_381(msm_size, get_rng(None));
            let batch_points = vec![points; batch_size].concat();
            let mut d_points = DeviceBuffer::from_slice(&batch_points[..]).unwrap();

            #[cfg(feature = "g2")]
            let g2_points = generate_random_points::<ExtensionField>(msm_size, get_rng(None));
            #[cfg(feature = "g2")]
            let g2_batch_points = vec![g2_points; batch_size].concat();
            #[cfg(feature = "g2")]
            let mut d_g2_points = DeviceBuffer::from_slice(&g2_batch_points[..]).unwrap();

            group
                .sample_size(30)
                .bench_function(
                    &format!("MSM of size 2^{} in batch {}", log_msm_size, batch_size),
                    |b| b.iter(|| commit_batch_bls12_381(&mut d_points, &mut d_scalars, batch_size)),
                );

            #[cfg(feature = "g2")]
            group
                .sample_size(10)
                .bench_function(
                    &format!("G2 MSM of size 2^{} in batch {}", log_msm_size, batch_size),
                    |b| b.iter(|| commit_batch_g2(&mut d_g2_points, &mut d_scalars, batch_size)),
                );
        }
    }
}

criterion_group!(msm_benches, bench_msm);
criterion_main!(msm_benches);
benches/ntt.rs (file deleted)
@@ -1,85 +0,0 @@
extern crate criterion;

use criterion::{criterion_group, criterion_main, Criterion};

use icicle::test_bls12_381::*;

const LOG_NTT_SIZES: [usize; 3] = [20, 9, 10];
const BATCH_SIZES: [usize; 3] = [1, 512, 1024];

fn bench_ntt(c: &mut Criterion) {
    let mut group = c.benchmark_group("NTT");
    for log_ntt_size in LOG_NTT_SIZES {
        for batch_size in BATCH_SIZES {
            let ntt_size = 1 << log_ntt_size;

            if ntt_size * batch_size > 1 << 25 {
                continue;
            }

            let scalar_samples = 20;

            let (_, mut d_evals, mut d_domain) = set_up_scalars_bls12_381(ntt_size * batch_size, log_ntt_size, true);

            group
                .sample_size(scalar_samples)
                .bench_function(
                    &format!("Scalar NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| evaluate_scalars_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size)),
                );

            group
                .sample_size(scalar_samples)
                .bench_function(
                    &format!("Scalar iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| interpolate_scalars_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size)),
                );

            group
                .sample_size(scalar_samples)
                .bench_function(
                    &format!("Scalar inplace NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| ntt_inplace_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size, false, 0)),
                );

            group
                .sample_size(scalar_samples)
                .bench_function(
                    &format!("Scalar inplace iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| ntt_inplace_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size, true, 0)),
                );

            drop(d_evals);
            drop(d_domain);

            if ntt_size * batch_size > 1 << 18 {
                continue;
            }

            let point_samples = 10;

            let (_, mut d_points_evals, mut d_domain) =
                set_up_points_bls12_381(ntt_size * batch_size, log_ntt_size, true);

            group
                .sample_size(point_samples)
                .bench_function(
                    &format!("EC NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| interpolate_points_batch_bls12_381(&mut d_points_evals, &mut d_domain, batch_size)),
                );

            group
                .sample_size(point_samples)
                .bench_function(
                    &format!("EC iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
                    |b| b.iter(|| evaluate_points_batch_bls12_381(&mut d_points_evals, &mut d_domain, batch_size)),
                );

            drop(d_points_evals);
            drop(d_domain);
        }
    }
}

criterion_group!(ntt_benches, bench_ntt);
criterion_main!(ntt_benches);
build.rs: 31 lines removed (file deleted)
@@ -1,31 +0,0 @@
use std::env;

fn main() {
    //TODO: check cargo features selected
    //TODO: can conflict/duplicate with make ?

    println!("cargo:rerun-if-env-changed=CXXFLAGS");
    println!("cargo:rerun-if-changed=./icicle");

    let arch_type = env::var("ARCH_TYPE").unwrap_or(String::from("native"));
    let stream_type = env::var("DEFAULT_STREAM").unwrap_or(String::from("legacy"));

    let mut arch = String::from("-arch=");
    arch.push_str(&arch_type);
    let mut stream = String::from("-default-stream=");
    stream.push_str(&stream_type);

    let mut nvcc = cc::Build::new();

    println!("Compiling icicle library using arch: {}", &arch);

    if cfg!(feature = "g2") {
        nvcc.define("G2_DEFINED", None);
    }
    nvcc.cuda(true);
    nvcc.debug(false);
    nvcc.flag(&arch);
    nvcc.flag(&stream);
    nvcc.files(["./icicle/curves/index.cu"]);
    nvcc.compile("ingo_icicle"); //TODO: extension??
}
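How this removed build script was driven: the nvcc flags came from environment variables, so a build such as ARCH_TYPE=sm_86 DEFAULT_STREAM=per-thread cargo build --release --features g2 (a hypothetical invocation) would pass -arch=sm_86 -default-stream=per-thread to nvcc and define G2_DEFINED for the G2 code paths.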
@@ -1,156 +0,0 @@
use std::time::Instant;

use icicle::{curves::bls12_381::ScalarField_BLS12_381, test_bls12_381::*};
use rustacuda::prelude::DeviceBuffer;

const LOG_NTT_SIZES: [usize; 3] = [20, 10, 9];
const BATCH_SIZES: [usize; 3] = [1, 1 << 9, 1 << 10];

const MAX_POINTS_LOG2: usize = 18;
const MAX_SCALARS_LOG2: usize = 26;

fn bench_lde() {
    for log_ntt_size in LOG_NTT_SIZES {
        for batch_size in BATCH_SIZES {
            let ntt_size = 1 << log_ntt_size;

            fn ntt_scalars_batch_bls12_381(
                d_inout: &mut DeviceBuffer<ScalarField_BLS12_381>,
                d_twiddles: &mut DeviceBuffer<ScalarField_BLS12_381>,
                batch_size: usize,
            ) -> i32 {
                ntt_inplace_batch_bls12_381(d_inout, d_twiddles, batch_size, false, 0);
                0
            }

            fn intt_scalars_batch_bls12_381(
                d_inout: &mut DeviceBuffer<ScalarField_BLS12_381>,
                d_twiddles: &mut DeviceBuffer<ScalarField_BLS12_381>,
                batch_size: usize,
            ) -> i32 {
                ntt_inplace_batch_bls12_381(d_inout, d_twiddles, batch_size, true, 0);
                0
            }

            // copy
            bench_ntt_template(
                MAX_SCALARS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_scalars_bls12_381,
                evaluate_scalars_batch_bls12_381,
                "NTT",
                false,
                100,
            );

            bench_ntt_template(
                MAX_SCALARS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_scalars_bls12_381,
                interpolate_scalars_batch_bls12_381,
                "iNTT",
                true,
                100,
            );

            bench_ntt_template(
                MAX_POINTS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_points_bls12_381,
                evaluate_points_batch_bls12_381,
                "EC NTT",
                false,
                20,
            );

            bench_ntt_template(
                MAX_POINTS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_points_bls12_381,
                interpolate_points_batch_bls12_381,
                "EC iNTT",
                true,
                20,
            );

            // inplace
            bench_ntt_template(
                MAX_SCALARS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_scalars_bls12_381,
                ntt_scalars_batch_bls12_381,
                "NTT inplace",
                false,
                100,
            );

            bench_ntt_template(
                MAX_SCALARS_LOG2,
                ntt_size,
                batch_size,
                log_ntt_size,
                set_up_scalars_bls12_381,
                intt_scalars_batch_bls12_381,
                "iNTT inplace",
                true,
                100,
            );
        }
    }
}

fn bench_ntt_template<E, S, R>(
    log_max_size: usize,
    ntt_size: usize,
    batch_size: usize,
    log_ntt_size: usize,
    set_data: fn(test_size: usize, log_domain_size: usize, inverse: bool) -> (Vec<E>, DeviceBuffer<E>, DeviceBuffer<S>),
    bench_fn: fn(d_evaluations: &mut DeviceBuffer<E>, d_domain: &mut DeviceBuffer<S>, batch_size: usize) -> R,
    id: &str,
    inverse: bool,
    samples: usize,
) -> Option<(Vec<E>, R)> {
    let count = ntt_size * batch_size;

    let bench_id = format!("{} of size 2^{} in batch {}", id, log_ntt_size, batch_size);

    if count > 1 << log_max_size {
        println!("Bench size exceeded: {}", bench_id);
        return None;
    }

    println!("{}", bench_id);

    let (input, mut d_evals, mut d_domain) = set_data(ntt_size * batch_size, log_ntt_size, inverse);

    let first = bench_fn(&mut d_evals, &mut d_domain, batch_size);

    let start = Instant::now();
    for _ in 0..samples {
        bench_fn(&mut d_evals, &mut d_domain, batch_size);
    }
    let elapsed = start.elapsed();
    println!(
        "{} {:0?} us x {} = {:?}",
        bench_id,
        elapsed.as_micros() as f32 / (samples as f32),
        samples,
        elapsed
    );

    Some((input, first))
}

fn main() {
    bench_lde();
}
@@ -19,8 +19,8 @@ set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")


# when adding a new curve, append its name to the end of this list
set(SUPPORTED_CURVES bn254;bls12_381;bls12_377)
# when adding a new curve/field, append its name to the end of this list
set(SUPPORTED_CURVES bn254;bls12_381;bls12_377;bw6_671)

set(IS_CURVE_SUPPORTED FALSE)
set(I 0)
@@ -42,25 +42,48 @@ if (NOT BUILD_TESTS)

add_library(
  icicle
  utils/error_handler.cu
  utils/utils_kernels.cu
  utils/vec_ops.cu
  primitives/field.cu
  primitives/projective.cu
  appUtils/msm/msm.cu
  appUtils/ntt/ntt.cu
  appUtils/lde/lde.cu
)
#set_target_properties(icicle PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
#set_target_properties(icicle PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)

# # cuda_add_library(cudanet STATIC ${cu})
# target_link_libraries(icicle ${CUDA_LIBRARIES})
# target_link_libraries(icicle cudart cuda cublas curand)
# message(STATUS "include & link cuda")
# link_directories(/usr/local/cuda/lib64)
target_compile_options(icicle PRIVATE -c)
# find_package(CUDA QUIET REQUIRED)

add_custom_command(
  TARGET icicle
  POST_BUILD
  COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym msm_cuda=${CURVE}_msm_cuda ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
  COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
  COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
  COMMAND ${CMAKE_AR} ARGS -rcs ${LIBRARY_OUTPUT_DIRECTORY}/libingo_${CURVE}.a ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
  COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym MSMCuda=${CURVE}MSMCuda ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
  COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym GetDefaultMSMConfig=${CURVE}GetDefaultMSMConfig ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
  # COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
  # COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
  COMMAND ${CMAKE_AR} ARGS -rcs ${LIBRARY_OUTPUT_DIRECTORY}/libingo_${CURVE}.a
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/error_handler.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/vec_ops.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/field.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/projective.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/utils_kernels.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
    ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o

)
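The objcopy step is what makes per-curve linking work: taking CURVE=bn254 as an example, --redefine-sym MSMCuda=bn254MSMCuda and --redefine-sym GetDefaultMSMConfig=bn254GetDefaultMSMConfig rename the extern "C" entry points before the objects are archived into libingo_bn254.a, so libraries built for several curves can be linked into one binary without duplicate symbols.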
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o)
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o)
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/projective.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/utils_kernels.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/error_handler.cu.o)

else()

@@ -16,6 +16,7 @@
#include "../../primitives/projective.cuh"
#include "../../utils/cuda_utils.cuh"
#include "../../utils/error_handler.cuh"
#include "../../utils/mont.cuh"

namespace msm {

@@ -27,6 +28,12 @@ namespace msm {
// #define BIG_TRIANGLE
// #define SSM_SUM //WIP

template <typename S>
int get_optimal_c(int bitsize)
{
  return ceil(log2(bitsize)) - 4;
}
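As a quick check of this heuristic (it is used below for the batched path whenever config.c == 0): for a 254-bit scalar field, get_optimal_c returns ceil(log2(254)) - 4 = 8 - 4 = 4.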
template <typename P>
__global__ void single_stage_multi_reduction_kernel(
  P* v,
@@ -132,32 +139,14 @@ namespace msm {
      buckets_indices[current_index] =
        (msm_index << (c + bm_bitsize)) | (bm << c) |
        bucket_index; // the bucket module number and the msm number are appended at the msbs
      if (scalar == S::zero() || scalar == S::one() || bucket_index == 0)
        buckets_indices[current_index] = 0; // will be skipped
      point_indices[current_index] = tid; // the point index is saved for later
      if (scalar == S::zero() || bucket_index == 0) buckets_indices[current_index] = 0; // will be skipped
      point_indices[current_index] = tid; // the point index is saved for later
#endif
    }
  }
}

template <typename P, typename A, typename S>
__global__ void
add_ones_kernel(A* points, S* scalars, P* results, const unsigned msm_size, const unsigned run_length)
{
  unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
  const unsigned nof_threads = (msm_size + run_length - 1) / run_length; // 129256
  if (tid >= nof_threads) {
    results[tid] = P::zero();
    return;
  }
  const unsigned start_index = tid * run_length;
  P sum = P::zero();
  for (int i = start_index; i < min(start_index + run_length, msm_size); i++) {
    if (scalars[i] == S::one()) sum = sum + points[i];
  }
  results[tid] = sum;
}

template <typename S>
__global__ void
find_cutoff_kernel(unsigned* v, unsigned size, unsigned cutoff, unsigned run_length, unsigned* result)
{
@@ -174,6 +163,7 @@ namespace msm {
  if (tid == 0 && v[size - 1] > cutoff) { result[0] = size; }
}

template <typename S>
__global__ void
find_max_size(unsigned* bucket_sizes, unsigned* single_bucket_indices, unsigned c, unsigned* largest_bucket_size)
{
@@ -330,8 +320,8 @@ namespace msm {
// this kernel computes the final result using the double and add algorithm
// it is done by a single thread
template <typename P, typename S>
__global__ void final_accumulation_kernel(
  P* final_sums, P* ones_result, P* final_results, unsigned nof_msms, unsigned nof_bms, unsigned c, bool add_ones)
__global__ void
final_accumulation_kernel(P* final_sums, P* final_results, unsigned nof_msms, unsigned nof_bms, unsigned c)
{
  unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
  if (tid > nof_msms) return;
@@ -343,38 +333,53 @@ namespace msm {
      final_result = final_result + final_result;
    }
  }
  if (add_ones)
    final_results[tid] = final_result + final_sums[tid * nof_bms] + ones_result[0];
  else
    final_results[tid] = final_result + final_sums[tid * nof_bms];
  final_results[tid] = final_result + final_sums[tid * nof_bms];
}
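The recombination this single thread performs is the usual final step of the bucket method: given per-bucket-module sums S_0, ..., S_{m-1} and window size c, it evaluates S_0 + 2^c * S_1 + ... + 2^{(m-1)c} * S_{m-1} Horner-style, doubling the running result c times between additions (hence the final_result = final_result + final_result loop).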
// this function computes msm using the bucket method
template <typename S, typename P, typename A>
void bucket_method_msm(
  unsigned bitsize,
  unsigned c,
  int bitsize,
  int c,
  S* scalars,
  A* points,
  unsigned size,
  int size,
  P* final_result,
  bool on_device,
  bool big_triangle,
  unsigned large_bucket_factor,
  bool are_scalars_on_device,
  bool are_scalars_montgomery_form,
  bool are_points_on_device,
  bool are_points_montgomery_form,
  bool is_result_on_device,
  bool is_big_triangle,
  int large_bucket_factor,
  bool is_async,
  cudaStream_t stream)
{
  S* d_scalars;
  A* d_points;
  if (!on_device) {
    // copy scalars and points to gpu
  if (!are_scalars_on_device) {
    // copy scalars to gpu
    cudaMallocAsync(&d_scalars, sizeof(S) * size, stream);
    cudaMallocAsync(&d_points, sizeof(A) * size, stream);
    cudaMemcpyAsync(d_scalars, scalars, sizeof(S) * size, cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync(d_points, points, sizeof(A) * size, cudaMemcpyHostToDevice, stream);
  } else {
    d_scalars = scalars;
  }
  cudaStream_t stream_points;
  if (!are_points_on_device || are_points_montgomery_form) cudaStreamCreate(&stream_points);
  if (!are_points_on_device) {
    // copy points to gpu
    cudaMallocAsync(&d_points, sizeof(A) * size, stream_points);
    cudaMemcpyAsync(d_points, points, sizeof(A) * size, cudaMemcpyHostToDevice, stream_points);
  } else {
    d_points = points;
  }
  if (are_points_montgomery_form) mont::FromMontgomery(d_points, size, stream_points);
  if (are_scalars_montgomery_form) mont::FromMontgomery(d_scalars, size, stream);
  cudaEvent_t event_points_uploaded;
  if (!are_points_on_device || are_points_montgomery_form) {
    cudaEventCreateWithFlags(&event_points_uploaded, cudaEventDisableTiming);
    cudaEventRecord(event_points_uploaded, stream_points);
  }

  P* buckets;
  // compute number of bucket modules and number of buckets in each module
@@ -393,22 +398,6 @@ namespace msm {
  unsigned NUM_BLOCKS = (nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
  initialize_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, nof_buckets);

  // accumulate ones
  P* ones_results; // fix whole division, in last run in kernel too
  const unsigned nof_runs = msm_log_size > 10 ? (1 << (msm_log_size - 6)) : 16;
  const unsigned run_length = (size + nof_runs - 1) / nof_runs;
  cudaMallocAsync(&ones_results, sizeof(P) * nof_runs, stream);
  NUM_THREADS = min(1 << 8, nof_runs);
  NUM_BLOCKS = (nof_runs + NUM_THREADS - 1) / NUM_THREADS;
  add_ones_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(d_points, d_scalars, ones_results, size, run_length);

  for (int s = nof_runs >> 1; s > 0; s >>= 1) {
    NUM_THREADS = min(MAX_TH, s);
    NUM_BLOCKS = (s + NUM_THREADS - 1) / NUM_THREADS;
    single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
      ones_results, ones_results, s * 2, 0, 0, 0, s);
  }
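For scale: with size = 2^20 (so msm_log_size = 20) this picks nof_runs = 2^14 and run_length = 64, i.e. each add_ones_kernel thread serially adds up to 64 points whose scalar equals one, and the reduction loop then halves s from 2^13 down to 1 to tree-reduce the 2^14 partial sums into ones_results[0].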
unsigned* bucket_indices;
unsigned* point_indices;
cudaMallocAsync(&bucket_indices, sizeof(unsigned) * size * (nof_bms + 1), stream);
@@ -482,7 +471,7 @@ namespace msm {

  // if all points are 0 just return point 0
  if (h_nof_buckets_to_compute == 0) {
    if (!on_device)
    if (!is_result_on_device)
      final_result[0] = P::zero();
    else {
      P* h_final_result = (P*)malloc(sizeof(P));
@@ -533,7 +522,7 @@ namespace msm {
  unsigned cutoff_nof_runs = (h_nof_buckets_to_compute + cutoff_run_length - 1) / cutoff_run_length;
  NUM_THREADS = min(1 << 5, cutoff_nof_runs);
  NUM_BLOCKS = (cutoff_nof_runs + NUM_THREADS - 1) / NUM_THREADS;
  find_cutoff_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
  find_cutoff_kernel<S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
    sorted_bucket_sizes, h_nof_buckets_to_compute, bucket_th, cutoff_run_length, nof_large_buckets);

  unsigned h_nof_large_buckets;
@@ -541,7 +530,7 @@ namespace msm {

  unsigned* max_res;
  cudaMallocAsync(&max_res, sizeof(unsigned) * 2, stream);
  find_max_size<<<1, 1, 0, stream>>>(sorted_bucket_sizes, sorted_single_bucket_indices, c, max_res);
  find_max_size<S><<<1, 1, 0, stream>>>(sorted_bucket_sizes, sorted_single_bucket_indices, c, max_res);

  unsigned h_max_res[2];
  cudaMemcpyAsync(h_max_res, max_res, sizeof(unsigned) * 2, cudaMemcpyDeviceToHost, stream);
@@ -550,11 +539,19 @@ namespace msm {
  unsigned large_buckets_to_compute =
    h_nof_large_buckets > h_nof_zero_large_buckets ? h_nof_large_buckets - h_nof_zero_large_buckets : 0;

  cudaStream_t stream2;
  cudaStreamCreate(&stream2);
  P* large_buckets;
  if (!are_points_on_device || are_points_montgomery_form) {
    // by this point, points need to be already uploaded and un-Montgomeried
    cudaStreamWaitEvent(stream, event_points_uploaded);
    cudaStreamDestroy(stream_points);
  }

  cudaStream_t stream_large_buckets;
  cudaEvent_t event_large_buckets_accumulated;
  P* large_buckets;
  if (large_buckets_to_compute > 0 && bucket_th > 0) {
    cudaStreamCreate(&stream_large_buckets);
    cudaEventCreateWithFlags(&event_large_buckets_accumulated, cudaEventDisableTiming);

    unsigned threads_per_bucket =
      1 << (unsigned)ceil(log2((h_largest_bucket_size + bucket_th - 1) / bucket_th)); // global param
    unsigned max_bucket_size_run_length = (h_largest_bucket_size + threads_per_bucket - 1) / threads_per_bucket;
@@ -563,7 +560,7 @@ namespace msm {

    NUM_THREADS = min(1 << 8, total_large_buckets_size);
    NUM_BLOCKS = (total_large_buckets_size + NUM_THREADS - 1) / NUM_THREADS;
    accumulate_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
    accumulate_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
      large_buckets, sorted_bucket_offsets + h_nof_zero_large_buckets,
      sorted_bucket_sizes + h_nof_zero_large_buckets, sorted_single_bucket_indices + h_nof_zero_large_buckets,
      point_indices, d_points, nof_buckets, large_buckets_to_compute, c + bm_bitsize, c, threads_per_bucket,
@@ -573,17 +570,18 @@ namespace msm {
    for (int s = total_large_buckets_size >> 1; s > large_buckets_to_compute - 1; s >>= 1) {
      NUM_THREADS = min(MAX_TH, s);
      NUM_BLOCKS = (s + NUM_THREADS - 1) / NUM_THREADS;
      single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
      single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
        large_buckets, large_buckets, s * 2, 0, 0, 0, s);

      CHECK_LAST_CUDA_ERROR();
    }

    // distribute
    NUM_THREADS = min(MAX_TH, large_buckets_to_compute);
    NUM_BLOCKS = (large_buckets_to_compute + NUM_THREADS - 1) / NUM_THREADS;
    distribute_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
    distribute_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
      large_buckets, buckets, sorted_single_bucket_indices + h_nof_zero_large_buckets, large_buckets_to_compute);

    cudaEventRecord(event_large_buckets_accumulated, stream_large_buckets);
    cudaStreamDestroy(stream_large_buckets);
  } else {
    h_nof_large_buckets = 0;
  }
@@ -598,9 +596,11 @@ namespace msm {
    h_nof_buckets_to_compute - h_nof_large_buckets, c + bm_bitsize, c);
  }

  // all the large buckets need to be accumulated before the final summation
  cudaStreamSynchronize(stream2);
  cudaStreamDestroy(stream2);
  if (large_buckets_to_compute > 0 && bucket_th > 0) {
    // all the large buckets need to be accumulated before the final summation
    cudaStreamWaitEvent(stream, event_large_buckets_accumulated);
    cudaStreamDestroy(stream_large_buckets);
  }

#ifdef SSM_SUM
  // sum each bucket
@@ -618,10 +618,10 @@ namespace msm {
#endif

  P* d_final_result;
  if (!on_device) cudaMallocAsync(&d_final_result, sizeof(P), stream);
  if (!is_result_on_device) cudaMallocAsync(&d_final_result, sizeof(P), stream);

  P* final_results;
  if (big_triangle) {
  if (is_big_triangle) {
    cudaMallocAsync(&final_results, sizeof(P) * nof_bms, stream);
    // launch the bucket module sum kernel - a thread for each bucket module
    NUM_THREADS = nof_bms;
@@ -697,19 +697,17 @@ namespace msm {
  }

  // launch the double and add kernel, a single thread
  final_accumulation_kernel<P, S><<<1, 1, 0, stream>>>(
    final_results, ones_results, on_device ? final_result : d_final_result, 1, nof_bms, c, true);
  final_accumulation_kernel<P, S>
    <<<1, 1, 0, stream>>>(final_results, is_result_on_device ? final_result : d_final_result, 1, nof_bms, c);
  cudaFreeAsync(final_results, stream);
  cudaStreamSynchronize(stream);

  if (!on_device) cudaMemcpyAsync(final_result, d_final_result, sizeof(P), cudaMemcpyDeviceToHost, stream);
  if (!is_result_on_device)
    cudaMemcpyAsync(final_result, d_final_result, sizeof(P), cudaMemcpyDeviceToHost, stream);

  // free memory
  if (!on_device) {
    cudaFreeAsync(d_points, stream);
    cudaFreeAsync(d_scalars, stream);
    cudaFreeAsync(d_final_result, stream);
  }
  if (!are_scalars_on_device) cudaFreeAsync(d_scalars, stream);
  if (!are_points_on_device) cudaFreeAsync(d_points, stream);
  if (!is_result_on_device) cudaFreeAsync(d_final_result, stream);
  cudaFreeAsync(buckets, stream);
#ifndef PHASE1_TEST
  cudaFreeAsync(bucket_indices, stream);
@@ -725,9 +723,8 @@ namespace msm {
  cudaFreeAsync(nof_large_buckets, stream);
  cudaFreeAsync(max_res, stream);
  if (large_buckets_to_compute > 0 && bucket_th > 0) cudaFreeAsync(large_buckets, stream);
  cudaFreeAsync(ones_results, stream);

  cudaStreamSynchronize(stream);
  if (!is_async) cudaStreamSynchronize(stream);
}
// this function computes multiple msms using the bucket method
@@ -879,7 +876,7 @@ namespace msm {
  NUM_THREADS = 1 << 8;
  NUM_BLOCKS = (batch_size + NUM_THREADS - 1) / NUM_THREADS;
  final_accumulation_kernel<P, S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
    bm_sums, bm_sums, on_device ? final_results : d_final_results, batch_size, nof_bms, c, false);
    bm_sums, on_device ? final_results : d_final_results, batch_size, nof_bms, c);

  // copy final result to host
  if (!on_device)
@@ -907,26 +904,59 @@ namespace msm {

} // namespace

MSMConfig DefaultMSMConfig()
{
  device_context::DeviceContext ctx = {
    0,               // device_id
    (cudaStream_t)0, // stream
    0,               // mempool
  };
  MSMConfig config = {
    false, // are_scalars_on_device
    false, // are_scalars_montgomery_form
    0,     // points_size
    1,     // precompute_factor
    false, // are_points_on_device
    false, // are_points_montgomery_form
    1,     // batch_size
    false, // are_results_on_device
    0,     // c
    0,     // bitsize
    false, // is_big_triangle
    10,    // large_bucket_factor
    false, // is_async
    ctx,   // DeviceContext
  };
  return config;
}

template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig config, P* results)
{
  int bitsize = (config.bitsize == 0) ? S::NBITS : config.bitsize;
  // TODO: DmytroTym/HadarIngonyama - unify the implementation of the bucket method and the batched bucket method in
  // one function
  // TODO: DmytroTym/HadarIngonyama - parameters to be included into the implementation: on deviceness of points,
  // scalars and results, precompute factor, points size and device id
  if (config.batch_size == 1)
    bucket_method_msm(
      config.bitsize, config.c, scalars, points, msm_size, results, config.are_scalars_on_device, config.big_triangle,
      config.large_bucket_factor, config.ctx.stream);
      bitsize, 16, scalars, points, msm_size, results, config.are_scalars_on_device,
      config.are_scalars_montgomery_form, config.are_points_on_device, config.are_points_montgomery_form,
      config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.is_async,
      config.ctx.stream);
  else
    batched_bucket_method_msm(
      config.bitsize, config.c, scalars, points, config.batch_size, msm_size, results, config.are_scalars_on_device,
      config.ctx.stream);
      bitsize, (config.c == 0) ? get_optimal_c<S>(bitsize) : config.c, scalars, points, config.batch_size, msm_size,
      results, config.are_scalars_on_device, config.ctx.stream);
  return cudaSuccess;
}

/**
 * Extern version of [msm](@ref msm) function with the following values of template parameters
 * Extern version of [DefaultMSMConfig](@ref DefaultMSMConfig) function.
 * @return Default value of [MSMConfig](@ref MSMConfig).
 */
extern "C" MSMConfig GetDefaultMSMConfig() { return DefaultMSMConfig(); }

/**
 * Extern version of [MSM](@ref MSM) function with the following values of template parameters
 * (where the curve is given by `-DCURVE` env variable during build):
 * - `S` is the [scalar field](@ref scalar_t) of the curve;
 * - `A` is the [affine representation](@ref affine_t) of curve points;
@@ -947,7 +977,7 @@ namespace msm {
#if defined(G2_DEFINED)

/**
 * Extern version of [msm](@ref msm) function with the following values of template parameters
 * Extern version of [MSM](@ref MSM) function with the following values of template parameters
 * (where the curve is given by `-DCURVE` env variable during build):
 * - `S` is the [scalar field](@ref scalar_t) of the curve;
 * - `A` is the [affine representation](@ref g2_affine_t) of G2 curve points;

@@ -4,7 +4,13 @@

#include <cuda_runtime.h>

#include "../../curves/curve_config.cuh"
#include "../../primitives/affine.cuh"
#include "../../primitives/field.cuh"
#include "../../primitives/projective.cuh"
#include "../../utils/cuda_utils.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/error_handler.cuh"

/**
 * @namespace msm
@@ -41,31 +47,41 @@ namespace msm {
   * the MSM size). */
  int precompute_factor; /**< The number of extra points to pre-compute for each point. Larger values decrease the
                          * number of computations to make, on-line memory footprint, but increase the static
                          * memory footprint. Default value: 1 (i.e. don't pre-compute). */
                          * memory footprint. Default value: 1 (i.e. don't pre-compute). */
  bool are_points_on_device; /**< True if points are on device and false if they're on host. Default value: false. */
  bool are_points_montgomery_form; /**< True if coordinates of points are in Montgomery form and false otherwise.
                                    * Default value: true. */
  int batch_size; /**< The number of MSMs to compute. Default value: 1. */
  bool are_result_on_device; /**< True if the results should be on device and false if they should be on host. Default
                              * value: false. */
  int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
          * that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
          * footprint but also more parallelism and less computational complexity (up to a certain point).
          * Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
  int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a
                * different (better) upper bound is known, it should be reflected in this variable. Default
                * value: 0 (set to the bitsize of scalar field). */
  bool big_triangle; /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but also
                      * greatly decreases parallelism, so only suitable for large batches of MSMs. Default value:
                      * false. */
  bool are_results_on_device; /**< True if the results should be on device and false if they should be on host. If set
                               * to false, `is_async` won't take effect because a synchronization is needed to
                               * transfer results to the host. Default value: false. */
  int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
          * that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
          * footprint but also more parallelism and less computational complexity (up to a certain point).
          * Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
  int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a
                * different (better) upper bound is known, it should be reflected in this variable. Default value: 0
                * (set to the bitsize of scalar field). */
  bool is_big_triangle; /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but
                         * also greatly decreases parallelism, so only suitable for large batches of MSMs. Default
                         * value: false. */
  int large_bucket_factor; /**< Variable that controls how sensitive the algorithm is to the buckets that occur very
                            * frequently. Useful for efficient treatment of non-uniform distributions of scalars and
                            * "top windows" with few bits. Can be set to 0 to disable separate treatment of large
                            * buckets altogether. Default value: 10. */
  int is_async; /**< Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
                 * and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or
                 * `cudaDeviceSynchronize`. If set to false, the MSM function will block the current CPU thread. */
  device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See
                                      * [DeviceContext](@ref device_context::DeviceContext). */
                                      * [DeviceContext](@ref `device_context::DeviceContext`). */
};

/**
 * A function that returns the default value of [MSMConfig](@ref MSMConfig) for the [MSM](@ref MSM) function.
 * @return Default value of [MSMConfig](@ref MSMConfig).
 */
extern "C" MSMConfig DefaultMSMConfig();

/**
 * A function that computes MSM: \f$ MSM(s_i, P_i) = \sum_{i=1}^N s_i \cdot P_i \f$.
 * @param scalars Scalars \f$ s_i \f$. In case of batch MSM, the scalars from all MSMs are concatenated.
@@ -81,6 +97,15 @@
 * @tparam P Output type, which is typically a [projective
 * Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) point in our codebase.
 * @return `cudaSuccess` if the execution was successful and an error code otherwise.
 *
 * This function is asynchronous, and to sync it with the host you need to call `cudaDeviceSynchronize()`. To
 * synchronize with a different stream `stream1`, call `cudaStreamSynchronize(config.stream)` and
 * `cudaStreamSynchronize(stream1)`.
 *
 * **Note:** this function is still WIP and the following [MSMConfig](@ref MSMConfig) members do not yet have any
 * effect: `points_size` (it's always equal to the msm size currently), `precompute_factor` (always equals 1) and
 * `ctx.device_id` (0 device is always used). Also, it's currently better to use `batch_size=1` in most cases (except
 * when dealing with very many MSMs).
 */
template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig config, P* results);

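Taken together, the header above defines the whole new calling convention: obtain a default MSMConfig, adjust flags, and invoke the templated MSM. The following is a minimal host-side sketch, not code from this commit; it assumes the curve_config typedefs selected by -DCURVE at build time, host-resident non-Montgomery inputs allocated elsewhere, and the include path in the comment.

#include <cuda_runtime.h>

#include "appUtils/msm/msm.cuh" // include path is an assumption

// Minimal usage sketch: one blocking MSM over host data with default settings.
cudaError_t single_msm(
  curve_config::scalar_t* scalars, curve_config::affine_t* points, int msm_size, curve_config::projective_t* result)
{
  msm::MSMConfig config = msm::DefaultMSMConfig(); // host pointers, batch_size = 1, c and bitsize chosen automatically
  config.is_async = false;                         // blocking: the stream is synchronized before MSM returns
  return msm::MSM<curve_config::scalar_t, curve_config::affine_t, curve_config::projective_t>(
    scalars, points, msm_size, config, result);
}

For a fully device-resident pipeline, set are_scalars_on_device, are_points_on_device and are_results_on_device to true together with is_async, then synchronize config.ctx.stream only when the result is actually needed, as the example program below does.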
@@ -130,18 +130,14 @@ int main()
  test_scalar* scalars = new test_scalar[N];
  test_affine* points = new test_affine[N];

  for (unsigned i = 0; i < N; i++) {
    // scalars[i] = (i%msm_size < 10)? test_scalar::rand_host() : scalars[i-10];
    points[i] = (i % msm_size < 100) ? test_projective::to_affine(test_projective::rand_host()) : points[i - 100];
    scalars[i] = test_scalar::rand_host();
    // scalars[i] = i < N/2? test_scalar::rand_host() : test_scalar::one();
    // points[i] = test_projective::to_affine(test_projective::rand_host());
  }
  test_scalar::RandHostMany(scalars, N);
  test_projective::RandHostManyAffine(points, N);

  std::cout << "finished generating" << std::endl;

  // projective_t *short_res = (projective_t*)malloc(sizeof(projective_t));
  // test_projective *large_res = (test_projective*)malloc(sizeof(test_projective));
  test_projective large_res[batch_size * 2];
  test_projective large_res[batch_size];
  // test_projective batched_large_res[batch_size];
  // fake_point *large_res = (fake_point*)malloc(sizeof(fake_point));
  // fake_point batched_large_res[256];
@@ -175,44 +171,47 @@ int main()
    0, // mempool
  };
  msm::MSMConfig config = {
    false,              // scalars_on_device
    true,               // scalars_montgomery_form
    msm_size,           // points_size
    1,                  // precompute_factor
    false,              // points_on_device
    true,               // points_montgomery_form
    1,                  // batch_size
    false,              // result_on_device
    16,                 // c
    test_scalar::NBITS, // bitsize
    false,              // big_triangle
    10,                 // large_bucket_factor
    ctx,                // DeviceContext
    false, // are_scalars_on_device
    false, // are_scalars_montgomery_form
    0,     // points_size
    1,     // precompute_factor
    false, // are_points_on_device
    false, // are_points_montgomery_form
    1,     // batch_size
    true,  // are_results_on_device
    0,     // c
    0,     // bitsize
    false, // is_big_triangle
    10,    // large_bucket_factor
    true,  // is_async
    ctx,   // DeviceContext
  };

  auto begin1 = std::chrono::high_resolution_clock::now();
  msm::MSM<test_scalar, test_affine, test_projective>(scalars, points, msm_size, config, large_res);
  msm::MSM<test_scalar, test_affine, test_projective>(scalars, points, msm_size, config, large_res_d);
  cudaEvent_t msm_end_event;
  cudaEventCreate(&msm_end_event);
  auto end1 = std::chrono::high_resolution_clock::now();
  auto elapsed1 = std::chrono::duration_cast<std::chrono::nanoseconds>(end1 - begin1);
  printf("Big Triangle : %.3f seconds.\n", elapsed1.count() * 1e-9);
  config.big_triangle = true;
  printf("No Big Triangle : %.3f seconds.\n", elapsed1.count() * 1e-9);
  config.is_big_triangle = true;
  config.are_results_on_device = false;
  // std::cout<<test_projective::to_affine(large_res[0])<<std::endl;
  auto begin = std::chrono::high_resolution_clock::now();
  msm::MSM<test_scalar, test_affine, test_projective>(scalars_d, points_d, msm_size, config, large_res_d);
  msm::MSM<test_scalar, test_affine, test_projective>(scalars_d, points_d, msm_size, config, large_res);
  // test_reduce_triangle(scalars);
  // test_reduce_rectangle(scalars);
  // test_reduce_single(scalars);
  // test_reduce_var(scalars);
  auto end = std::chrono::high_resolution_clock::now();
  auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
  printf("On Device No Big Triangle: %.3f seconds.\n", elapsed.count() * 1e-9);
  printf("Big Triangle: %.3f seconds.\n", elapsed.count() * 1e-9);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);

  std::cout << test_projective::to_affine(large_res[0]) << std::endl;

  cudaMemcpy(&large_res[1], large_res_d, sizeof(test_projective), cudaMemcpyDeviceToHost);
  std::cout << test_projective::to_affine(large_res[1]) << std::endl;

  // reference_msm<test_affine, test_scalar, test_projective>(scalars, points, msm_size);

@@ -8,10 +8,10 @@ namespace ntt {

namespace {

const uint32_t MAX_NUM_THREADS = 1024;
const uint32_t MAX_NUM_THREADS = 512;
const uint32_t MAX_THREADS_BATCH = 512;          // TODO: allows 100% occupancy for scalar NTT for sm_86..sm_89
const uint32_t MAX_SHARED_MEM_ELEMENT_SIZE = 32; // TODO: occupancy calculator, hardcoded for sm_86..sm_89
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * 1024;
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * MAX_NUM_THREADS;

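Assuming the 32 here is an element size in bytes (a 256-bit field element), the shared-memory cap this computes drops from 32 * 1024 = 32768 bytes to 32 * 512 = 16384 bytes per block, consistent with halving MAX_NUM_THREADS.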
/**
 * Computes the twiddle factors.
@@ -21,13 +21,10 @@ namespace ntt {
 * @param omega multiplying factor.
 */
template <typename S>
__global__ void twiddle_factors_kernel(S* d_twiddles, uint32_t n_twiddles, S omega)
__global__ void twiddle_factors_kernel(S* d_twiddles, int n_twiddles, S omega)
{
  for (uint32_t i = 0; i < n_twiddles; i++) {
    d_twiddles[i] = S::zero();
  }
  d_twiddles[0] = S::one();
  for (uint32_t i = 0; i < n_twiddles - 1; i++) {
  for (int i = 0; i < n_twiddles - 1; i++) {
    d_twiddles[i + 1] = omega * d_twiddles[i];
  }
}
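Unrolling the recurrence d_twiddles[0] = 1, d_twiddles[i + 1] = omega * d_twiddles[i] gives d_twiddles[i] = omega^i, i.e. the {w^0 = 1, w^1, ..., w^(n-1)} ordering documented for the twiddles field of NTTConfig below; the loop is inherently sequential, so generation cost is linear in n_twiddles.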
@@ -61,7 +58,7 @@ namespace ntt {
  int number_of_threads = MAX_THREADS_BATCH;
  int number_of_blocks = (n * batch_size + number_of_threads - 1) / number_of_threads;
  reverse_order_kernel<<<number_of_blocks, number_of_threads, 0, stream>>>(arr, arr_reversed, n, logn, batch_size);
  cudaMemcpyAsync(arr, arr_reversed, n * batch_size * sizeof(E), cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(arr, arr_reversed, n * batch_size * sizeof(E), cudaMemcpyDefault, stream);
  cudaFreeAsync(arr_reversed, stream);
}

@@ -313,15 +310,15 @@ namespace ntt {
    }

    if (is_coset)
      utils_internal::batchVectorMult<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
      utils_internal::BatchMulKernel<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);

    num_threads = max(min(n / 2, MAX_NUM_THREADS), 1);
    num_blocks = (n * batch_size + num_threads - 1) / num_threads;
    utils_internal::template_normalize_kernel<E, S>
    utils_internal::NormalizeKernel<E, S>
      <<<num_blocks, num_threads, 0, stream>>>(d_inout, S::inv_log_size(logn), n * batch_size);
  } else {
    if (is_coset)
      utils_internal::batchVectorMult<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
      utils_internal::BatchMulKernel<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);

    for (int s = logn - 1; s >= logn_shmem; s--) // TODO: this loop also can be unrolled
    {
@@ -349,30 +346,44 @@ namespace ntt {
}

template <typename E, typename S>
cudaError_t NTT(E* input, int size, bool is_inverse, NTTConfig<S> config)
cudaError_t NTT(NTTConfig<E, S>* config)
{
  uint32_t logn = uint32_t(log(size) / log(2));
  uint32_t n_twiddles = size; // n_twiddles is set to 4096 as BLS12_381::scalar_t::omega() is of that order.
  size_t input_size = size * config.batch_size * sizeof(E);
  bool is_on_device = config.are_inputs_on_device;
  bool generate_twiddles = (config.twiddles == nullptr);
  cudaStream_t stream = config.ctx.stream;
  CHECK_LAST_CUDA_ERROR();

  cudaStream_t stream = config->ctx.stream;
  int size = config->size;
  int batch_size = config->batch_size;
  bool is_inverse = config->is_inverse;
  int n_twiddles = size;
  int logn = int(log(size) / log(2));
  int input_size_bytes = size * batch_size * sizeof(E);
  bool is_input_on_device = config->are_inputs_on_device;
  bool is_output_on_device = config->is_output_on_device;
  bool is_forward_twiddle_empty = config->twiddles == nullptr;
  bool is_inverse_twiddle_empty = config->inv_twiddles == nullptr;
  bool is_generating_twiddles = (is_forward_twiddle_empty && is_inverse_twiddle_empty) ||
                                (is_forward_twiddle_empty && !is_inverse) || (is_inverse_twiddle_empty && is_inverse);

  S* d_twiddles;
  if (generate_twiddles) {
  if (is_generating_twiddles) {
    cudaMallocAsync(&d_twiddles, n_twiddles * sizeof(S), stream);
    GenerateTwiddleFactors(d_twiddles, n_twiddles, is_inverse ? S::omega_inv(logn) : S::omega(logn), config.ctx);
    S omega = is_inverse ? S::omega_inv(logn) : S::omega(logn);
    GenerateTwiddleFactors(d_twiddles, n_twiddles, omega, config->ctx);
  } else {
    d_twiddles = is_inverse ? config->inv_twiddles : config->twiddles;
  }

  E* d_input;
  if (!is_on_device) {
    cudaMallocAsync(&d_input, input_size, stream);
    cudaMemcpyAsync(d_input, input, input_size, cudaMemcpyHostToDevice, stream);
  E* d_inout;
  if (is_input_on_device) {
    d_inout = config->inout;
  } else {
    cudaMallocAsync(&d_inout, input_size_bytes, stream);
    cudaMemcpyAsync(d_inout, config->inout, input_size_bytes, cudaMemcpyHostToDevice, stream);
  }

  bool reverse_input;
  bool reverse_output;
  switch (config.ordering) {
  switch (config->ordering) {
    case Ordering::kNN:
      reverse_input = is_inverse;
      reverse_output = !is_inverse;
@@ -390,33 +401,50 @@ namespace ntt {
      reverse_output = is_inverse;
      break;
  }
  CHECK_LAST_CUDA_ERROR();

  if (reverse_input) reverse_order_batch(d_inout, size, logn, config->batch_size, stream);
  CHECK_LAST_CUDA_ERROR();

  if (reverse_input) reverse_order_batch(is_on_device ? input : d_input, size, logn, config.batch_size, stream);
  ntt_inplace_batch_template(
    is_on_device ? input : d_input, generate_twiddles ? d_twiddles : config.twiddles, size, config.batch_size,
    is_inverse, config.is_coset, config.coset_gen, stream, false);
  if (reverse_output) reverse_order_batch(is_on_device ? input : d_input, size, logn, config.batch_size, stream);
    d_inout, d_twiddles, size, batch_size, is_inverse, config->is_coset, config->coset_gen, stream, false);
  CHECK_LAST_CUDA_ERROR();

  if (!is_on_device) {
    cudaMemcpyAsync(input, d_input, input_size, cudaMemcpyDeviceToHost, stream);
    cudaFreeAsync(d_input, stream);
  if (reverse_output) reverse_order_batch(d_inout, size, logn, batch_size, stream);
  CHECK_LAST_CUDA_ERROR();

  if (is_output_on_device) {
    // free(config->inout); // TODO: ? or callback?+
    config->inout = d_inout;
  } else {
    if (is_input_on_device) {
      E* h_output = (E*)malloc(input_size_bytes); // TODO: caller responsible for memory management
      cudaMemcpyAsync(h_output, d_inout, input_size_bytes, cudaMemcpyDeviceToHost, stream);
      config->inout = h_output;
      CHECK_LAST_CUDA_ERROR();
    } else {
      cudaMemcpyAsync(config->inout, d_inout, input_size_bytes, cudaMemcpyDeviceToHost, stream);
      CHECK_LAST_CUDA_ERROR();
    }
    cudaFreeAsync(d_inout, stream); // TODO: make it optional? so can be reused
  }
  CHECK_LAST_CUDA_ERROR();

  if (is_generating_twiddles && !config->is_preserving_twiddles) { cudaFreeAsync(d_twiddles, stream); }

  if (config->is_preserving_twiddles) {
    if (is_inverse)
      config->inv_twiddles = d_twiddles;
    else {
      config->twiddles = d_twiddles;
    }
  }

  if (generate_twiddles) { cudaFreeAsync(d_twiddles, stream); }

  cudaStreamSynchronize(stream);
  return cudaSuccess;
}

/**
 * Extern version of [GenerateTwiddleFactors](@ref GenerateTwiddleFactors) function with the template parameter
 * `S` being the [scalar field](@ref scalar_t) of the curve given by `-DCURVE` env variable during build.
 * @return `cudaSuccess` if the execution was successful and an error code otherwise.
 */
extern "C" cudaError_t GenerateTwiddleFactorsCuda(
  curve_config::scalar_t* d_twiddles, int n_twiddles, curve_config::scalar_t omega, device_context::DeviceContext ctx)
{
  return GenerateTwiddleFactors<curve_config::scalar_t>(d_twiddles, n_twiddles, omega, ctx);
  CHECK_LAST_CUDA_ERROR();

  return cudaSuccess;
}

/**
@@ -425,10 +453,43 @@ namespace ntt {
 * - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
 * @return `cudaSuccess` if the execution was successful and an error code otherwise.
 */
extern "C" cudaError_t
NTTCuda(curve_config::scalar_t* input, int size, bool is_inverse, NTTConfig<curve_config::scalar_t> config)
extern "C" cudaError_t NTTCuda(NTTConfig<curve_config::scalar_t, curve_config::scalar_t>* config)
{
  return NTT<curve_config::scalar_t, curve_config::scalar_t>(input, size, is_inverse, config);
  return NTT<curve_config::scalar_t, curve_config::scalar_t>(config);
}

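The NTT entry points now take a single mutable config pointer: the caller fills an NTTConfig, calls the function, and reads the possibly re-pointed inout afterwards. A minimal sketch under stated assumptions (zero-initialized struct with only fields visible in this diff set; host_data and ctx are supplied by the caller; the include path is an assumption):

#include <cuda_runtime.h>

#include "appUtils/ntt/ntt.cuh" // include path is an assumption

// One forward NTT of 2^10 elements, batch of 4; twiddles are generated on the fly and kept for reuse.
cudaError_t run_ntt(curve_config::scalar_t* host_data, device_context::DeviceContext ctx)
{
  ntt::NTTConfig<curve_config::scalar_t, curve_config::scalar_t> config = {};
  config.inout = host_data; // length must be size * batch_size
  config.size = 1 << 10;
  config.batch_size = 4;
  config.is_inverse = false;
  config.ordering = ntt::Ordering::kNN;
  config.twiddles = nullptr;            // nullptr: generated from S::omega(logn)
  config.is_preserving_twiddles = true; // stash the generated twiddles in config.twiddles for the next call
  config.ctx = ctx;
  return ntt::NTTCuda(&config);
}

With these host-side settings the transformed values land back in host_data, and a second call with the same config skips twiddle generation because config.twiddles is now populated.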
/**
|
||||
* Extern version of [ntt](@ref ntt) function with the following values of template parameters
|
||||
* (where the curve is given by `-DCURVE` env variable during build):
|
||||
* - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
|
||||
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
|
||||
*/
|
||||
template <typename E, typename S>
|
||||
cudaError_t NTTDefaultContext(NTTConfig<E, S>* config)
|
||||
{
|
||||
// TODO: if empty - create default
|
||||
cudaMemPool_t mempool;
|
||||
cudaDeviceGetDefaultMemPool(&mempool, config->ctx.device_id);
|
||||
|
||||
device_context::DeviceContext context = {
|
||||
config->ctx.device_id,
|
||||
0, // default stream
|
||||
mempool};
|
||||
|
||||
config->ctx = context;
|
||||
|
||||
return NTT<E, S>(config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extern version of [ntt](@ref ntt) function with the following values of template parameters
|
||||
* (where the curve is given by `-DCURVE` env variable during build):
|
||||
* - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
|
||||
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
|
||||
*/
|
||||
extern "C" cudaError_t NTTDefaultContextCuda(NTTConfig<curve_config::scalar_t, curve_config::scalar_t>* config)
|
||||
{
|
||||
return NTTDefaultContext(config);
|
||||
}

#if defined(ECNTT_DEFINED)
@@ -440,10 +501,9 @@ namespace ntt {
 * - `E` is the [scalar field](@ref scalar_t) of the curve;
 * @return `cudaSuccess` if the execution was successful and an error code otherwise.
 */
extern "C" cudaError_t
ECNTTCuda(curve_config::projective_t* input, int size, bool is_inverse, NTTConfig<curve_config::scalar_t> config)
extern "C" cudaError_t ECNTTCuda(NTTConfig<curve_config::projective_t, curve_config::scalar_t>* config)
{
  return NTT<curve_config::projective_t, curve_config::scalar_t>(input, size, is_inverse, config);
  return NTT<curve_config::projective_t, curve_config::scalar_t>(config);
}

#endif

@@ -4,7 +4,11 @@

#include <cuda_runtime.h>

#include "../../curves/curve_config.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/error_handler.cuh"
#include "../../utils/sharedmem.cuh"
#include "../../utils/utils_kernels.cuh"

/**
 * @namespace ntt
@@ -51,10 +55,14 @@ namespace ntt {
 * @struct NTTConfig
 * Struct that encodes NTT parameters to be passed into the [ntt](@ref ntt) function.
 */
template <typename S>
template <typename E, typename S>
struct NTTConfig {
  E* inout; /**< Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot
             * config.batch_size \f$. Note that if inputs are in Montgomery form, the outputs will be as well and
             * vice versa: non-Montgomery inputs produce non-Montgomery outputs.*/
  bool are_inputs_on_device; /**< True if inputs/outputs are on device and false if they're on host. Default value:
                              * false. */
  bool is_inverse; /**< True if the inverse NTT is computed, false for the forward NTT. Default value: false. */
  Ordering
    ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`. */
  Decimation
@@ -64,35 +72,40 @@
                 * If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
                 * `Decimation::kDIT` and if ordering is `Ordering::kNR` — to `Decimation::kDIF`. */
  Butterfly
    butterfly; /**< Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value:
                * `Butterfly::kCooleyTukey`.
                * __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
                * If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
                * `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`. */
  bool is_coset; /**< If false, NTT is computed on a subfield given by [twiddles](@ref twiddles). If true, NTT is
                  * computed on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen),
                  * so: \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\}
                  * \f$. Default value: false. */
  S* coset_gen; /**< The field element that generates a coset if [is_coset](@ref is_coset) is true.
                 * Otherwise should be set to `nullptr`. Default value: `nullptr`. */
  S* twiddles; /**< "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
                * This pointer is expected to live on device. The order is as follows:
                * \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
                * factors are generated online using the default generator (TODO: link to twiddle gen here) and
                * function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
  int batch_size; /**< The number of NTTs to compute. Default value: 1. */
    butterfly; /**< Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value:
                * `Butterfly::kCooleyTukey`.
                * __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
                * If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
                * `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`. */
  bool is_coset; /**< If false, NTT is computed on a subfield given by [twiddles](@ref twiddles). If true, NTT is
                  * computed on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen),
                  * so: \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\}
                  * \f$. Default value: false. */
  S* coset_gen; /**< The field element that generates a coset if [is_coset](@ref is_coset) is true.
                 * Otherwise should be set to `nullptr`. Default value: `nullptr`. */
  S* twiddles; /**< "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
                * This pointer is expected to live on device. The order is as follows:
                * \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
                * factors are generated online using the default generator (TODO: link to twiddle gen here) and
                * function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
  S* inv_twiddles; /**< "Inverse twiddle factors", (or "domain", or "roots of unity") on which the iNTT is evaluated.
                    * This pointer is expected to live on device. The order is as follows:
                    * \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
                    * factors are generated online using the default generator (TODO: link to twiddle gen here) and
                    * function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
  int size; /**< NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is
             * the size of 1 NTT. */
  int batch_size; /**< The number of NTTs to compute. Default value: 1. */
  bool is_preserving_twiddles; /**< If true, twiddle factors are preserved on device for subsequent use in config and
                                * not freed after calculation. Default value: false. */
  bool is_output_on_device; /**< If true, output is preserved on device for subsequent use in config and not freed
                             * after calculation. Default value: false. */
  device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See
                                      * [DeviceContext](@ref device_context::DeviceContext). */
};
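To make the twiddle-caching fields concrete, a hedged sketch of a forward/inverse round trip that relies on `is_preserving_twiddles`; the surrounding setup (device buffer, context, a filled `config`) is assumed:

// Hedged sketch: with is_preserving_twiddles set, the first call stores the
// generated forward twiddles back into config->twiddles, and the inverse
// pass fills and reuses config->inv_twiddles instead of regenerating.
cudaError_t ntt_round_trip(ntt::NTTConfig<curve_config::scalar_t, curve_config::scalar_t>& config)
{
  config.is_preserving_twiddles = true;
  config.is_inverse = false;
  cudaError_t err = ntt::NTT<curve_config::scalar_t, curve_config::scalar_t>(&config);
  if (err != cudaSuccess) return err;
  config.is_inverse = true; // inv_twiddles are generated and cached on this pass
  return ntt::NTT<curve_config::scalar_t, curve_config::scalar_t>(&config);
}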

/**
 * A function that computes NTT or iNTT in-place.
 * @param input Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot
 * config.batch_size \f$. Note that if inputs are in Montgomery form, the outputs will be as well and vice versa:
 * non-Montgomery inputs produce non-Montgomery outputs.
 * @param size NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is the
 * size of 1 NTT.
 * @param is_inverse If true, inverse NTT is computed, otherwise — regular forward NTT.
 * @param config [NTTConfig](@ref NTTConfig) used in this NTT.
 * @tparam E The type of inputs and outputs (i.e. coefficients \f$ \{p_i\} \f$ and values \f$ p(x) \f$). Must be a
 * group.
@@ -100,7 +113,7 @@ namespace ntt {
 * @return `cudaSuccess` if the execution was successful and an error code otherwise.
 */
template <typename E, typename S>
cudaError_t NTT(E* input, int size, bool is_inverse, NTTConfig<S> config);
cudaError_t NTT(NTTConfig<E, S>* config);

/**
 * Generates twiddles \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$ from root of unity \f$ \omega \f$ and

@@ -25,11 +25,10 @@ using namespace bls12_377;

namespace curve_config {

  typedef Field<fp_config> scalar_field_t;
  typedef scalar_field_t scalar_t;
  typedef Field<fp_config> scalar_t;
  typedef Field<fq_config> point_field_t;
  static constexpr point_field_t b = point_field_t{weierstrass_b};
  typedef Projective<point_field_t, scalar_field_t, b> projective_t;
  typedef Projective<point_field_t, scalar_t, b> projective_t;
  typedef Affine<point_field_t> affine_t;

#if defined(G2_DEFINED)

@@ -11,6 +11,16 @@ public:

  static HOST_DEVICE_INLINE Affine neg(const Affine& point) { return {point.x, FF::neg(point.y)}; }

  static HOST_DEVICE_INLINE Affine ToMontgomery(const Affine& point)
  {
    return {FF::ToMontgomery(point.x), FF::ToMontgomery(point.y)};
  }

  static HOST_DEVICE_INLINE Affine FromMontgomery(const Affine& point)
  {
    return {FF::FromMontgomery(point.x), FF::FromMontgomery(point.y)};
  }
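Since these helpers map each coordinate independently, a round trip should be the identity. A small hedged sketch, assuming the `affine_t` typedef from curve_config and an existing point `p`:

// Hedged sketch: ToMontgomery and FromMontgomery act per coordinate,
// so converting there and back yields the original affine point.
curve_config::affine_t p_mont = curve_config::affine_t::ToMontgomery(p);
curve_config::affine_t p_back = curve_config::affine_t::FromMontgomery(p_mont);
// p_back == p is expected to hold (operator== is defined just below).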

  friend HOST_DEVICE_INLINE bool operator==(const Affine& xs, const Affine& ys)
  {
    return (xs.x == ys.x) && (xs.y == ys.y);

6
icicle/primitives/field.cu
Normal file
@@ -0,0 +1,6 @@

#include "../curves/curve_config.cuh"
#include "field.cuh"

#define scalar_t curve_config::scalar_t

extern "C" void GenerateScalars(scalar_t* scalars, int size) { scalar_t::RandHostMany(scalars, size); }
@@ -69,10 +69,6 @@ public:

  static constexpr HOST_DEVICE_INLINE Field modulus() { return Field{CONFIG::modulus}; }

  static constexpr HOST_DEVICE_INLINE Field montgomery_r() { return Field{CONFIG::montgomery_r}; }

  static constexpr HOST_DEVICE_INLINE Field montgomery_r_inv() { return Field{CONFIG::montgomery_r_inv}; }

  // private:
  typedef storage<TLC> ff_storage;
  typedef storage<2 * TLC> ff_wide_storage;
@@ -703,6 +699,12 @@ public:
    return value;
  }

  static void RandHostMany(Field* out, int size)
  {
    for (int i = 0; i < size; i++)
      out[i] = rand_host();
  }

  template <unsigned REDUCTION_SIZE = 1>
  static constexpr HOST_DEVICE_INLINE Field sub_modulus(const Field& xs)
  {
@@ -718,7 +720,7 @@ public:
    hex_string << std::hex << std::setfill('0');

    for (int i = 0; i < TLC; i++) {
      hex_string << std::setw(8) << xs.limbs_storage.limbs[i];
      hex_string << std::setw(8) << xs.limbs_storage.limbs[TLC - i - 1];
    }

    os << "0x" << hex_string.str();
@@ -869,6 +871,13 @@ public:
    return xs * xs;
  }

  static constexpr HOST_DEVICE_INLINE Field ToMontgomery(const Field& xs) { return xs * Field{CONFIG::montgomery_r}; }

  static constexpr HOST_DEVICE_INLINE Field FromMontgomery(const Field& xs)
  {
    return xs * Field{CONFIG::montgomery_r_inv};
  }

  template <unsigned MODULUS_MULTIPLE = 1>
  static constexpr HOST_DEVICE_INLINE Field neg(const Field& xs)
  {

@@ -1,61 +1,54 @@

#include "../curves/bls12_377/curve_config.cuh"
#include "../curves/bls12_381/curve_config.cuh"
#include "../curves/bn254/curve_config.cuh"
#include "../curves/curve_config.cuh"
#include "projective.cuh"
#include <cuda.h>

extern "C" bool eq_bls12_381(BLS12_381::projective_t* point1, BLS12_381::projective_t* point2)
#define projective_t curve_config::projective_t // TODO: global to avoid lengthy texts
#define affine_t curve_config::affine_t
#define point_field_t curve_config::point_field_t

extern "C" bool Eq(projective_t* point1, projective_t* point2)
{
  return (*point1 == *point2) &&
    !((point1->x == BLS12_381::point_field_t::zero()) && (point1->y == BLS12_381::point_field_t::zero()) &&
      (point1->z == BLS12_381::point_field_t::zero())) &&
    !((point2->x == BLS12_381::point_field_t::zero()) && (point2->y == BLS12_381::point_field_t::zero()) &&
      (point2->z == BLS12_381::point_field_t::zero()));
    !((point1->x == point_field_t::zero()) && (point1->y == point_field_t::zero()) &&
      (point1->z == point_field_t::zero())) &&
    !((point2->x == point_field_t::zero()) && (point2->y == point_field_t::zero()) &&
      (point2->z == point_field_t::zero()));
}

extern "C" bool eq_bls12_377(BLS12_377::projective_t* point1, BLS12_377::projective_t* point2)
{
  return (*point1 == *point2) &&
    !((point1->x == BLS12_377::point_field_t::zero()) && (point1->y == BLS12_377::point_field_t::zero()) &&
      (point1->z == BLS12_377::point_field_t::zero())) &&
    !((point2->x == BLS12_377::point_field_t::zero()) && (point2->y == BLS12_377::point_field_t::zero()) &&
      (point2->z == BLS12_377::point_field_t::zero()));
}
extern "C" void ToAffine(projective_t* point, affine_t* point_out) { *point_out = projective_t::to_affine(*point); }

extern "C" bool eq_bn254(BN254::projective_t* point1, BN254::projective_t* point2)
{
  return (*point1 == *point2) &&
    !((point1->x == BN254::point_field_t::zero()) && (point1->y == BN254::point_field_t::zero()) &&
      (point1->z == BN254::point_field_t::zero())) &&
    !((point2->x == BN254::point_field_t::zero()) && (point2->y == BN254::point_field_t::zero()) &&
      (point2->z == BN254::point_field_t::zero()));
}
extern "C" void GenerateProjectivePoints(projective_t* points, int size) { projective_t::RandHostMany(points, size); }

extern "C" void GenerateAffinePoints(affine_t* points, int size) { projective_t::RandHostManyAffine(points, size); }

#if defined(G2_DEFINED)
extern "C" bool eq_g2_bls12_381(BLS12_381::g2_projective_t* point1, BLS12_381::g2_projective_t* point2)

#define g2_projective_t curve_config::g2_projective_t
#define g2_affine_t curve_config::g2_affine_t
#define g2_point_field_t curve_config::g2_point_field_t

extern "C" bool EqG2(g2_projective_t* point1, g2_projective_t* point2)
{
  return (*point1 == *point2) &&
    !((point1->x == BLS12_381::g2_point_field_t::zero()) && (point1->y == BLS12_381::g2_point_field_t::zero()) &&
      (point1->z == BLS12_381::g2_point_field_t::zero())) &&
    !((point2->x == BLS12_381::g2_point_field_t::zero()) && (point2->y == BLS12_381::g2_point_field_t::zero()) &&
      (point2->z == BLS12_381::g2_point_field_t::zero()));
    !((point1->x == g2_point_field_t::zero()) && (point1->y == g2_point_field_t::zero()) &&
      (point1->z == g2_point_field_t::zero())) &&
    !((point2->x == g2_point_field_t::zero()) && (point2->y == g2_point_field_t::zero()) &&
      (point2->z == g2_point_field_t::zero()));
}

extern "C" bool eq_g2_bls12_377(BLS12_377::g2_projective_t* point1, BLS12_377::g2_projective_t* point2)
extern "C" void ToAffineG2(g2_projective_t* point, affine_t* point_out)
{
  return (*point1 == *point2) &&
    !((point1->x == BLS12_377::g2_point_field_t::zero()) && (point1->y == BLS12_377::g2_point_field_t::zero()) &&
      (point1->z == BLS12_377::g2_point_field_t::zero())) &&
    !((point2->x == BLS12_377::g2_point_field_t::zero()) && (point2->y == BLS12_377::g2_point_field_t::zero()) &&
      (point2->z == BLS12_377::g2_point_field_t::zero()));
  *point_out = projective_t::to_affine(*point);
}

extern "C" bool eq_g2_bn254(BN254::g2_projective_t* point1, BN254::g2_projective_t* point2)
extern "C" void GenerateProjectivePointsG2(g2_projective_t* points, int size)
{
  return (*point1 == *point2) &&
    !((point1->x == BN254::g2_point_field_t::zero()) && (point1->y == BN254::g2_point_field_t::zero()) &&
      (point1->z == BN254::g2_point_field_t::zero())) &&
    !((point2->x == BN254::g2_point_field_t::zero()) && (point2->y == BN254::g2_point_field_t::zero()) &&
      (point2->z == BN254::g2_point_field_t::zero()));
  g2_projective_t::RandHostMany(points, size);
}
#endif

extern "C" void GenerateAffinePointsG2(g2_affine_t* points, int size)
{
  g2_projective_t::RandHostManyAffine(points, size);
}

#endif
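A hedged host-side sketch of the renamed, curve-agnostic externs above, using the curve_config typedefs; the buffer of two points is an illustrative assumption:

// Hedged sketch: generate a couple of projective points, compare them with
// Eq, and project one to affine with ToAffine.
curve_config::projective_t points[2];
GenerateProjectivePoints(points, 2);
bool same = Eq(&points[0], &points[1]); // almost surely false for random points
curve_config::affine_t a;
ToAffine(&points[0], &a);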

@@ -22,6 +22,16 @@ public:

  static HOST_DEVICE_INLINE Projective from_affine(const Affine<FF>& point) { return {point.x, point.y, FF::one()}; }

  static HOST_DEVICE_INLINE Projective ToMontgomery(const Projective& point)
  {
    return {FF::ToMontgomery(point.x), FF::ToMontgomery(point.y), FF::ToMontgomery(point.z)};
  }

  static HOST_DEVICE_INLINE Projective FromMontgomery(const Projective& point)
  {
    return {FF::FromMontgomery(point.x), FF::FromMontgomery(point.y), FF::FromMontgomery(point.z)};
  }

  static HOST_DEVICE_INLINE Projective generator() { return {FF::generator_x(), FF::generator_y(), FF::one()}; }

  static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; }
@@ -163,4 +173,16 @@ public:
    SCALAR_FF rand_scalar = SCALAR_FF::rand_host();
    return rand_scalar * generator();
  }

  static void RandHostMany(Projective* out, int size)
  {
    for (int i = 0; i < size; i++)
      out[i] = (i % size < 100) ? rand_host() : out[i - 100];
  }

  static void RandHostManyAffine(Affine<FF>* out, int size)
  {
    for (int i = 0; i < size; i++)
      out[i] = (i % size < 100) ? to_affine(rand_host()) : out[i - 100];
  }
};

28
icicle/utils/error_handler.cu
Normal file
@@ -0,0 +1,28 @@

#include <iostream>

template <typename T>
void check(T err, const char* const func, const char* const file, const int line)
{
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
  }
}

void checkLast(const char* const file, const int line)
{
  cudaError_t err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
  }
}

void syncDevice(const char* const file, const int line)
{
  cudaError_t err{cudaDeviceSynchronize()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
  }
}
@@ -1,32 +1,17 @@

#pragma once
#ifndef ERR_H
#define ERR_H

#include <iostream>

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
void check(T err, const char* const func, const char* const file, const int line)
{
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
  }
}
void check(T err, const char* const func, const char* const file, const int line);

#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(const char* const file, const int line)
{
  cudaError_t err{cudaGetLastError()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
  }
}
void checkLast(const char* const file, const int line);

#define CHECK_SYNC_DEVICE_ERROR() syncDevice(__FILE__, __LINE__)
void syncDevice(const char* const file, const int line)
{
  cudaError_t err{cudaDeviceSynchronize()};
  if (err != cudaSuccess) {
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << std::endl;
  }
}
void syncDevice(const char* const file, const int line);

#endif
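A hedged sketch of how the three macros fit around a kernel launch; the kernel name and launch parameters are placeholders, not part of this commit:

// Hedged sketch: wrap an API call with CHECK_CUDA_ERROR; use
// CHECK_LAST_CUDA_ERROR after a launch (launches don't return cudaError_t)
// and CHECK_SYNC_DEVICE_ERROR to surface asynchronous failures.
CHECK_CUDA_ERROR(cudaMemcpyAsync(d_buf, h_buf, bytes, cudaMemcpyHostToDevice, stream));
SomeKernel<<<blocks, threads, 0, stream>>>(d_buf, n); // placeholder kernel
CHECK_LAST_CUDA_ERROR();
CHECK_SYNC_DEVICE_ERROR();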

@@ -2,23 +2,26 @@
#ifndef MONT_H
#define MONT_H

#include "utils_kernels.cuh"

namespace mont {

  namespace {

#define MAX_THREADS_PER_BLOCK 256

    // TODO (DmytroTym): do valid conversion for point types too
    template <typename E>
    int convert_montgomery(E* d_inout, int n, bool is_into, cudaStream_t stream)
    template <typename E, bool is_into>
    __global__ void MontgomeryKernel(E* inout, int n)
    {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid < n) { inout[tid] = is_into ? E::ToMontgomery(inout[tid]) : E::FromMontgomery(inout[tid]); }
    }

    template <typename E, bool is_into>
    int ConvertMontgomery(E* d_inout, int n, cudaStream_t stream)
    {
      // Set the grid and block dimensions
      int num_threads = MAX_THREADS_PER_BLOCK;
      int num_blocks = (n + num_threads - 1) / num_threads;
      E mont = is_into ? E::montgomery_r() : E::montgomery_r_inv();
      utils_internal::template_normalize_kernel<<<num_blocks, num_threads, 0, stream>>>(d_inout, mont, n);
      MontgomeryKernel<E, is_into><<<num_blocks, num_threads, 0, stream>>>(d_inout, n);

      return 0; // TODO: void with proper error handling
    }
@@ -26,15 +29,15 @@ namespace mont {
  } // namespace

  template <typename E>
  int to_montgomery(E* d_inout, int n, cudaStream_t stream)
  int ToMontgomery(E* d_inout, int n, cudaStream_t stream)
  {
    return convert_montgomery(d_inout, n, true, stream);
    return ConvertMontgomery<E, true>(d_inout, n, stream);
  }

  template <typename E>
  int from_montgomery(E* d_inout, int n, cudaStream_t stream)
  int FromMontgomery(E* d_inout, int n, cudaStream_t stream)
  {
    return convert_montgomery(d_inout, n, false, stream);
    return ConvertMontgomery<E, false>(d_inout, n, stream);
  }

} // namespace mont
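A minimal sketch of the public wrappers, assuming `d_scalars` already lives on device and `stream` is an existing CUDA stream:

// Hedged sketch: move n device-resident field elements into Montgomery
// form and back on the given stream. Both calls are asynchronous.
mont::ToMontgomery(d_scalars, n, stream);
// ... launch kernels that expect Montgomery form ...
mont::FromMontgomery(d_scalars, n, stream);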

@@ -1,16 +1,23 @@
#include "utils_kernels.cuh"

namespace utils_internal {
  // TODO: weird linking issue - only works in headers
  // template <typename E, typename S>
  // __global__ void NormalizeKernel(E* arr, S scalar, unsigned n)
  // {
  //   int tid = blockIdx.x * blockDim.x + threadIdx.x;
  //   if (tid < n) { arr[tid] = scalar * arr[tid]; }
  // }

  template <typename E, typename S>
  __global__ void template_normalize_kernel(E* arr, S scalar, uint32_t n)
  __global__ void NormalizeKernel(E* arr, S scalar, int n)
  {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) { arr[tid] = scalar * arr[tid]; }
  }

  template <typename E, typename S>
  __global__ void batchVectorMult(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size)
  __global__ void BatchMulKernel(E* element_vec, S* scalar_vec, int n_scalars, int batch_size)
  {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < n_scalars * batch_size) {

@@ -4,17 +4,22 @@

namespace utils_internal {

  /**
   * Multiply the elements of an input array by a scalar in-place. Used for normalization in iNTT.
   * @param arr input array.
   * @param n size of arr.
   * @param scalar scalar of type S by which every element of arr is multiplied.
   */
  template <typename E, typename S>
  __global__ void template_normalize_kernel(E* arr, S scalar, uint32_t n);
  __global__ void NormalizeKernel(E* arr, S scalar, unsigned n)
  {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n) { arr[tid] = scalar * arr[tid]; }
  }

  template <typename E, typename S>
  __global__ void batchVectorMult(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size);
  __global__ void BatchMulKernel(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size)
  {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < n_scalars * batch_size) {
      int scalar_id = tid % n_scalars;
      element_vec[tid] = scalar_vec[scalar_id] * element_vec[tid];
    }
  }

} // namespace utils_internal
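A hedged launch sketch for BatchMulKernel; the grid sizing mirrors the ceiling-division convention used elsewhere in this diff, and `d_elements`, `d_scalars`, and `stream` are assumed to exist:

// Hedged sketch: scale batch_size vectors of n_scalars elements each by the
// same scalar vector, one thread per element across the whole batch.
int total = n_scalars * batch_size;
int threads = 256;
int blocks = (total + threads - 1) / threads;
utils_internal::BatchMulKernel<<<blocks, threads, 0, stream>>>(d_elements, d_scalars, n_scalars, batch_size);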

@@ -1,34 +1,32 @@
#include "lde.cuh"

#include <cuda.h>
#include <stdexcept>

#include "../../curves/curve_config.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/mont.cuh"
#include "../curves/curve_config.cuh"
#include "device_context.cuh"
#include "mont.cuh"

namespace lde {
namespace vec_ops {

  namespace {

#define MAX_THREADS_PER_BLOCK 256

    template <typename E, typename S>
    __global__ void mul_kernel(S* scalar_vec, E* element_vec, int n, E* result)
    __global__ void MulKernel(S* scalar_vec, E* element_vec, int n, E* result)
    {
      int tid = blockDim.x * blockIdx.x + threadIdx.x;
      if (tid < n) { result[tid] = scalar_vec[tid] * element_vec[tid]; }
    }

    template <typename E>
    __global__ void add_kernel(E* element_vec1, E* element_vec2, int n, E* result)
    __global__ void AddKernel(E* element_vec1, E* element_vec2, int n, E* result)
    {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid < n) { result[tid] = element_vec1[tid] + element_vec2[tid]; }
    }

    template <typename E>
    __global__ void sub_kernel(E* element_vec1, E* element_vec2, int n, E* result)
    __global__ void SubKernel(E* element_vec1, E* element_vec2, int n, E* result)
    {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid < n) { result[tid] = element_vec1[tid] - element_vec2[tid]; }
@@ -58,9 +56,9 @@ namespace lde {
    }

    // Call the kernel to perform element-wise modular multiplication
    mul_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
    MulKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
      is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);
    if (is_montgomery) mont::from_montgomery(is_on_device ? result : d_result, n, ctx.stream);
    if (is_montgomery) mont::FromMontgomery(is_on_device ? result : d_result, n, ctx.stream);

    if (!is_on_device) {
      cudaMemcpyAsync(result, d_result, n * sizeof(E), cudaMemcpyDeviceToHost, ctx.stream);
@@ -93,7 +91,7 @@ namespace lde {
    }

    // Call the kernel to perform element-wise addition
    add_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
    AddKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
      is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);

    if (!is_on_device) {
@@ -127,7 +125,7 @@ namespace lde {
    }

    // Call the kernel to perform element-wise subtraction
    sub_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
    SubKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
      is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);

    if (!is_on_device) {
@@ -190,4 +188,4 @@ namespace lde {
    return Sub<curve_config::scalar_t>(vec_a, vec_b, n, is_on_device, ctx, result);
  }

} // namespace lde
} // namespace vec_ops
@@ -2,15 +2,13 @@
#ifndef LDE_H
#define LDE_H

#include "../../utils/device_context.cuh"
#include "device_context.cuh"

/**
 * @namespace lde
 * LDE (stands for low degree extension) contains [NTT](@ref ntt)-based methods for translating between coefficient and
 * evaluation domains of polynomials. It also contains methods for element-wise manipulation of vectors, which is useful
 * for working with polynomials in evaluation domain.
 * @namespace vec_ops
 * This namespace contains methods for performing element-wise arithmetic operations on vectors.
 */
namespace lde {
namespace vec_ops {

/**
 * A function that multiplies two vectors element-wise.
@@ -62,6 +60,6 @@ namespace lde {
template <typename E>
cudaError_t Sub(E* vec_a, E* vec_b, int n, bool is_on_device, device_context::DeviceContext ctx, E* result);

} // namespace lde
} // namespace vec_ops

#endif
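A hedged end-to-end sketch of the host-pointer path (`is_on_device = false`), under which the implementation stages inputs on device and copies the result back. The DeviceContext initializer order (device id, stream, mempool) mirrors the one constructed in NTTDefaultContext earlier in this diff, but is still an assumption here:

// Hedged sketch: element-wise a - b over host vectors via vec_ops::Sub.
#include <vector>
std::vector<curve_config::scalar_t> a(n), b(n), out(n);
// ... fill a and b ...
cudaMemPool_t mempool;
cudaDeviceGetDefaultMemPool(&mempool, 0);
device_context::DeviceContext ctx = {0, 0, mempool}; // device 0, default stream
cudaError_t err = vec_ops::Sub<curve_config::scalar_t>(a.data(), b.data(), (int)n, /*is_on_device=*/false, ctx, out.data());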

@@ -1,329 +0,0 @@
use std::ffi::c_uint;
use ark_CURVE_NAME_L::{Fq as Fq_CURVE_NAME_U, Fr as Fr_CURVE_NAME_U, G1Affine as G1Affine_CURVE_NAME_U, G1Projective as G1Projective_CURVE_NAME_U};
use ark_ec::AffineCurve;
use ark_ff::{BigInteger_limbs_q, BigInteger_limbs_p, PrimeField};
use std::mem::transmute;
use ark_ff::Field;
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;

#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_CURVE_NAME_U<const NUM_LIMBS: usize> {
    pub s: [u32; NUM_LIMBS],
}

unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_CURVE_NAME_U<NUM_LIMBS> {}

impl<const NUM_LIMBS: usize> Default for Field_CURVE_NAME_U<NUM_LIMBS> {
    fn default() -> Self {
        Field_CURVE_NAME_U::zero()
    }
}

impl<const NUM_LIMBS: usize> Field_CURVE_NAME_U<NUM_LIMBS> {
    pub fn zero() -> Self {
        Field_CURVE_NAME_U {
            s: [0u32; NUM_LIMBS],
        }
    }

    pub fn one() -> Self {
        let mut s = [0u32; NUM_LIMBS];
        s[0] = 1;
        Field_CURVE_NAME_U { s }
    }

    fn to_bytes_le(&self) -> Vec<u8> {
        self.s
            .iter()
            .map(|s| s.to_le_bytes().to_vec())
            .flatten()
            .collect::<Vec<_>>()
    }
}

pub const BASE_LIMBS_CURVE_NAME_U: usize = limbs_q;
pub const SCALAR_LIMBS_CURVE_NAME_U: usize = limbs_p;

pub type BaseField_CURVE_NAME_U = Field_CURVE_NAME_U<BASE_LIMBS_CURVE_NAME_U>;
pub type ScalarField_CURVE_NAME_U = Field_CURVE_NAME_U<SCALAR_LIMBS_CURVE_NAME_U>;

fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
    match val.len() {
        n if n < NUM_LIMBS => {
            let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
            padded[..val.len()].copy_from_slice(&val);
            padded
        }
        n if n == NUM_LIMBS => val.try_into().unwrap(),
        _ => panic!("slice has too many elements"),
    }
}

//
impl BaseField_CURVE_NAME_U {
    pub fn limbs(&self) -> [u32; BASE_LIMBS_CURVE_NAME_U] {
        self.s
    }

    pub fn from_limbs(value: &[u32]) -> Self {
        Self {
            s: get_fixed_limbs(value),
        }
    }

    pub fn to_ark(&self) -> BigInteger_limbs_q {
        BigInteger_limbs_q::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger_limbs_q) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }
}
//

impl ScalarField_CURVE_NAME_U {
    pub fn limbs(&self) -> [u32; SCALAR_LIMBS_CURVE_NAME_U] {
        self.s
    }

    pub fn to_ark(&self) -> BigInteger_limbs_p {
        BigInteger_limbs_p::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger_limbs_p) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }

    pub fn to_ark_transmute(&self) -> BigInteger_limbs_p {
        unsafe { transmute(*self) }
    }

    pub fn from_ark_transmute(v: BigInteger_limbs_p) -> ScalarField_CURVE_NAME_U {
        unsafe { transmute(v) }
    }
}

#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_CURVE_NAME_U {
    pub x: BaseField_CURVE_NAME_U,
    pub y: BaseField_CURVE_NAME_U,
    pub z: BaseField_CURVE_NAME_U,
}

impl Default for Point_CURVE_NAME_U {
    fn default() -> Self {
        Point_CURVE_NAME_U::zero()
    }
}

impl Point_CURVE_NAME_U {
    pub fn zero() -> Self {
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::zero(),
            y: BaseField_CURVE_NAME_U::one(),
            z: BaseField_CURVE_NAME_U::zero(),
        }
    }

    pub fn infinity() -> Self {
        Self::zero()
    }

    pub fn to_ark(&self) -> G1Projective_CURVE_NAME_U {
        //TODO: generic conversion
        self.to_ark_affine().into_projective()
    }

    pub fn to_ark_affine(&self) -> G1Affine_CURVE_NAME_U {
        //TODO: generic conversion
        use std::ops::Mul;
        let proj_x_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.x.to_bytes_le());
        let proj_y_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.y.to_bytes_le());
        let proj_z_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.z.to_bytes_le());
        let inverse_z = proj_z_field.inverse().unwrap();
        let aff_x = proj_x_field.mul(inverse_z);
        let aff_y = proj_y_field.mul(inverse_z);
        G1Affine_CURVE_NAME_U::new(aff_x, aff_y, false)
    }

    pub fn from_ark(ark: G1Projective_CURVE_NAME_U) -> Point_CURVE_NAME_U {
        let z_inv = ark.z.inverse().unwrap();
        let z_invsq = z_inv * z_inv;
        let z_invq3 = z_invsq * z_inv;
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark((ark.x * z_invsq).into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark((ark.y * z_invq3).into_repr()),
            z: BaseField_CURVE_NAME_U::one(),
        }
    }
}

extern "C" {
    fn eq_CURVE_NAME_L(point1: *const Point_CURVE_NAME_U, point2: *const Point_CURVE_NAME_U) -> c_uint;
}

impl PartialEq for Point_CURVE_NAME_U {
    fn eq(&self, other: &Self) -> bool {
        unsafe { eq_CURVE_NAME_L(self, other) != 0 }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_CURVE_NAME_U {
    pub x: BaseField_CURVE_NAME_U,
    pub y: BaseField_CURVE_NAME_U,
}

impl Default for PointAffineNoInfinity_CURVE_NAME_U {
    fn default() -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::zero(),
            y: BaseField_CURVE_NAME_U::zero(),
        }
    }
}

impl PointAffineNoInfinity_CURVE_NAME_U {
    // TODO: generics
    ///From u32 limbs x,y
    pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(x),
            },
            y: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(y),
            },
        }
    }

    pub fn limbs(&self) -> Vec<u32> {
        [self.x.limbs(), self.y.limbs()].concat()
    }

    pub fn to_projective(&self) -> Point_CURVE_NAME_U {
        Point_CURVE_NAME_U {
            x: self.x,
            y: self.y,
            z: BaseField_CURVE_NAME_U::one(),
        }
    }

    pub fn to_ark(&self) -> G1Affine_CURVE_NAME_U {
        G1Affine_CURVE_NAME_U::new(Fq_CURVE_NAME_U::new(self.x.to_ark()), Fq_CURVE_NAME_U::new(self.y.to_ark()), false)
    }

    pub fn to_ark_repr(&self) -> G1Affine_CURVE_NAME_U {
        G1Affine_CURVE_NAME_U::new(
            Fq_CURVE_NAME_U::from_repr(self.x.to_ark()).unwrap(),
            Fq_CURVE_NAME_U::from_repr(self.y.to_ark()).unwrap(),
            false,
        )
    }

    pub fn from_ark(p: &G1Affine_CURVE_NAME_U) -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark(p.x.into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark(p.y.into_repr()),
        }
    }
}

impl Point_CURVE_NAME_U {
    // TODO: generics

    pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(x),
            },
            y: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(y),
            },
            z: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(z),
            },
        }
    }

    pub fn from_xy_limbs(value: &[u32]) -> Point_CURVE_NAME_U {
        let l = value.len();
        assert_eq!(l, 3 * BASE_LIMBS_CURVE_NAME_U, "length must be 3 * {}", BASE_LIMBS_CURVE_NAME_U);
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: value[..BASE_LIMBS_CURVE_NAME_U].try_into().unwrap(),
            },
            y: BaseField_CURVE_NAME_U {
                s: value[BASE_LIMBS_CURVE_NAME_U..BASE_LIMBS_CURVE_NAME_U * 2].try_into().unwrap(),
            },
            z: BaseField_CURVE_NAME_U {
                s: value[BASE_LIMBS_CURVE_NAME_U * 2..].try_into().unwrap(),
            },
        }
    }

    pub fn to_affine(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
        let ark_affine = self.to_ark_affine();
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark(ark_affine.x.into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark(ark_affine.y.into_repr()),
        }
    }

    pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: self.x,
            y: self.y,
        }
    }
}

impl ScalarField_CURVE_NAME_U {
    pub fn from_limbs(value: &[u32]) -> ScalarField_CURVE_NAME_U {
        ScalarField_CURVE_NAME_U {
            s: get_fixed_limbs(value),
        }
    }
}

#[cfg(test)]
mod tests {
    use ark_CURVE_NAME_L::{Fr as Fr_CURVE_NAME_U};

    use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}, curves::CURVE_NAME_L::{Point_CURVE_NAME_U, ScalarField_CURVE_NAME_U}};

    #[test]
    fn test_ark_scalar_convert() {
        let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
        let scalar = ScalarField_CURVE_NAME_U::from_limbs(&limbs);
        assert_eq!(
            scalar.to_ark(),
            scalar.to_ark_transmute(),
            "{:08X?} {:08X?}",
            scalar.to_ark(),
            scalar.to_ark_transmute()
        )
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_point_equality() {
        let left = Point_CURVE_NAME_U::zero();
        let right = Point_CURVE_NAME_U::zero();
        assert_eq!(left, right);
        let right = Point_CURVE_NAME_U::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
        assert_eq!(left, right);
        let right = Point_CURVE_NAME_U::from_limbs(
            &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            &[0; 12],
            &[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        );
        assert!(left != right);
    }
}
@@ -1,309 +0,0 @@
use std::ffi::c_uint;
use ark_CURVE_NAME_L::{Fq as Fq_CURVE_NAME_U, Fr as Fr_CURVE_NAME_U, G1Affine as G1Affine_CURVE_NAME_U, G1Projective as G1Projective_CURVE_NAME_U};
use ark_ec::AffineCurve;
use ark_ff::{BigInteger_limbs_p, PrimeField};
use std::mem::transmute;
use ark_ff::Field;
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;

#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_CURVE_NAME_U<const NUM_LIMBS: usize> {
    pub s: [u32; NUM_LIMBS],
}

unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_CURVE_NAME_U<NUM_LIMBS> {}

impl<const NUM_LIMBS: usize> Default for Field_CURVE_NAME_U<NUM_LIMBS> {
    fn default() -> Self {
        Field_CURVE_NAME_U::zero()
    }
}

impl<const NUM_LIMBS: usize> Field_CURVE_NAME_U<NUM_LIMBS> {
    pub fn zero() -> Self {
        Field_CURVE_NAME_U {
            s: [0u32; NUM_LIMBS],
        }
    }

    pub fn one() -> Self {
        let mut s = [0u32; NUM_LIMBS];
        s[0] = 1;
        Field_CURVE_NAME_U { s }
    }

    fn to_bytes_le(&self) -> Vec<u8> {
        self.s
            .iter()
            .map(|s| s.to_le_bytes().to_vec())
            .flatten()
            .collect::<Vec<_>>()
    }
}

pub const BASE_LIMBS_CURVE_NAME_U: usize = limbs_p;
pub const SCALAR_LIMBS_CURVE_NAME_U: usize = limbs_p;

pub type BaseField_CURVE_NAME_U = Field_CURVE_NAME_U<BASE_LIMBS_CURVE_NAME_U>;
pub type ScalarField_CURVE_NAME_U = Field_CURVE_NAME_U<SCALAR_LIMBS_CURVE_NAME_U>;

fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
    match val.len() {
        n if n < NUM_LIMBS => {
            let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
            padded[..val.len()].copy_from_slice(&val);
            padded
        }
        n if n == NUM_LIMBS => val.try_into().unwrap(),
        _ => panic!("slice has too many elements"),
    }
}

impl ScalarField_CURVE_NAME_U {
    pub fn limbs(&self) -> [u32; SCALAR_LIMBS_CURVE_NAME_U] {
        self.s
    }

    pub fn to_ark(&self) -> BigInteger_limbs_p {
        BigInteger_limbs_p::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger_limbs_p) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }

    pub fn to_ark_transmute(&self) -> BigInteger_limbs_p {
        unsafe { transmute(*self) }
    }

    pub fn from_ark_transmute(v: BigInteger_limbs_p) -> ScalarField_CURVE_NAME_U {
        unsafe { transmute(v) }
    }
}

#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_CURVE_NAME_U {
    pub x: BaseField_CURVE_NAME_U,
    pub y: BaseField_CURVE_NAME_U,
    pub z: BaseField_CURVE_NAME_U,
}

impl Default for Point_CURVE_NAME_U {
    fn default() -> Self {
        Point_CURVE_NAME_U::zero()
    }
}

impl Point_CURVE_NAME_U {
    pub fn zero() -> Self {
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::zero(),
            y: BaseField_CURVE_NAME_U::one(),
            z: BaseField_CURVE_NAME_U::zero(),
        }
    }

    pub fn infinity() -> Self {
        Self::zero()
    }

    pub fn to_ark(&self) -> G1Projective_CURVE_NAME_U {
        //TODO: generic conversion
        self.to_ark_affine().into_projective()
    }

    pub fn to_ark_affine(&self) -> G1Affine_CURVE_NAME_U {
        //TODO: generic conversion
        use ark_ff::Field;
        use std::ops::Mul;
        let proj_x_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.x.to_bytes_le());
        let proj_y_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.y.to_bytes_le());
        let proj_z_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.z.to_bytes_le());
        let inverse_z = proj_z_field.inverse().unwrap();
        let aff_x = proj_x_field.mul(inverse_z);
        let aff_y = proj_y_field.mul(inverse_z);
        G1Affine_CURVE_NAME_U::new(aff_x, aff_y, false)
    }

    pub fn from_ark(ark: G1Projective_CURVE_NAME_U) -> Point_CURVE_NAME_U {
        use ark_ff::Field;
        let z_inv = ark.z.inverse().unwrap();
        let z_invsq = z_inv * z_inv;
        let z_invq3 = z_invsq * z_inv;
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark((ark.x * z_invsq).into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark((ark.y * z_invq3).into_repr()),
            z: BaseField_CURVE_NAME_U::one(),
        }
    }
}

extern "C" {
    fn eq_CURVE_NAME_L(point1: *const Point_CURVE_NAME_U, point2: *const Point_CURVE_NAME_U) -> c_uint;
}

impl PartialEq for Point_CURVE_NAME_U {
    fn eq(&self, other: &Self) -> bool {
        unsafe { eq_CURVE_NAME_L(self, other) != 0 }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_CURVE_NAME_U {
    pub x: BaseField_CURVE_NAME_U,
    pub y: BaseField_CURVE_NAME_U,
}

impl Default for PointAffineNoInfinity_CURVE_NAME_U {
    fn default() -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::zero(),
            y: BaseField_CURVE_NAME_U::zero(),
        }
    }
}

impl PointAffineNoInfinity_CURVE_NAME_U {
    // TODO: generics
    ///From u32 limbs x,y
    pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(x),
            },
            y: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(y),
            },
        }
    }

    pub fn limbs(&self) -> Vec<u32> {
        [self.x.limbs(), self.y.limbs()].concat()
    }

    pub fn to_projective(&self) -> Point_CURVE_NAME_U {
        Point_CURVE_NAME_U {
            x: self.x,
            y: self.y,
            z: BaseField_CURVE_NAME_U::one(),
        }
    }

    pub fn to_ark(&self) -> G1Affine_CURVE_NAME_U {
        G1Affine_CURVE_NAME_U::new(Fq_CURVE_NAME_U::new(self.x.to_ark()), Fq_CURVE_NAME_U::new(self.y.to_ark()), false)
    }

    pub fn to_ark_repr(&self) -> G1Affine_CURVE_NAME_U {
        G1Affine_CURVE_NAME_U::new(
            Fq_CURVE_NAME_U::from_repr(self.x.to_ark()).unwrap(),
            Fq_CURVE_NAME_U::from_repr(self.y.to_ark()).unwrap(),
            false,
        )
    }

    pub fn from_ark(p: &G1Affine_CURVE_NAME_U) -> Self {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark(p.x.into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark(p.y.into_repr()),
        }
    }
}

impl Point_CURVE_NAME_U {
    // TODO: generics

    pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(x),
            },
            y: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(y),
            },
            z: BaseField_CURVE_NAME_U {
                s: get_fixed_limbs(z),
            },
        }
    }

    pub fn from_xy_limbs(value: &[u32]) -> Point_CURVE_NAME_U {
        let l = value.len();
        assert_eq!(l, 3 * BASE_LIMBS_CURVE_NAME_U, "length must be 3 * {}", BASE_LIMBS_CURVE_NAME_U);
        Point_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U {
                s: value[..BASE_LIMBS_CURVE_NAME_U].try_into().unwrap(),
            },
            y: BaseField_CURVE_NAME_U {
                s: value[BASE_LIMBS_CURVE_NAME_U..BASE_LIMBS_CURVE_NAME_U * 2].try_into().unwrap(),
            },
            z: BaseField_CURVE_NAME_U {
                s: value[BASE_LIMBS_CURVE_NAME_U * 2..].try_into().unwrap(),
            },
        }
    }

    pub fn to_affine(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
        let ark_affine = self.to_ark_affine();
        PointAffineNoInfinity_CURVE_NAME_U {
            x: BaseField_CURVE_NAME_U::from_ark(ark_affine.x.into_repr()),
            y: BaseField_CURVE_NAME_U::from_ark(ark_affine.y.into_repr()),
        }
    }

    pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
        PointAffineNoInfinity_CURVE_NAME_U {
            x: self.x,
            y: self.y,
        }
    }
}

impl ScalarField_CURVE_NAME_U {
    pub fn from_limbs(value: &[u32]) -> ScalarField_CURVE_NAME_U {
        ScalarField_CURVE_NAME_U {
            s: get_fixed_limbs(value),
        }
    }
}

#[cfg(test)]
mod tests {
    use ark_CURVE_NAME_L::{Fr as Fr_CURVE_NAME_U};

    use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}, curves::CURVE_NAME_L::{Point_CURVE_NAME_U, ScalarField_CURVE_NAME_U}};

    #[test]
    fn test_ark_scalar_convert() {
        let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
        let scalar = ScalarField_CURVE_NAME_U::from_limbs(&limbs);
        assert_eq!(
            scalar.to_ark(),
            scalar.to_ark_transmute(),
            "{:08X?} {:08X?}",
            scalar.to_ark(),
            scalar.to_ark_transmute()
        )
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_point_equality() {
        let left = Point_CURVE_NAME_U::zero();
        let right = Point_CURVE_NAME_U::zero();
        assert_eq!(left, right);
        let right = Point_CURVE_NAME_U::from_limbs(&[0; 8], &[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8]);
        assert_eq!(left, right);
        let right = Point_CURVE_NAME_U::from_limbs(
            &[2, 0, 0, 0, 0, 0, 0, 0],
            &[0; 8],
            &[1, 0, 0, 0, 0, 0, 0, 0],
        );
        assert!(left != right);
    }
}
File diff suppressed because it is too large
@@ -1,385 +0,0 @@
|
||||
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
|
||||
use ark_bls12_377::{Fq as Fq_BLS12_377, G1Affine as G1Affine_BLS12_377, G1Projective as G1Projective_BLS12_377};
|
||||
use ark_ec::AffineCurve;
|
||||
use ark_ff::Field;
|
||||
use ark_ff::{BigInteger256, BigInteger384, PrimeField};
|
||||
use rustacuda_core::DeviceCopy;
|
||||
use rustacuda_derive::DeviceCopy;
|
||||
use std::ffi::c_uint;
|
||||
use std::mem::transmute;
|
||||
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
#[repr(C)]
|
||||
pub struct Field_BLS12_377<const NUM_LIMBS: usize> {
|
||||
pub s: [u32; NUM_LIMBS],
|
||||
}
|
||||
|
||||
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BLS12_377<NUM_LIMBS> {}
|
||||
|
||||
impl<const NUM_LIMBS: usize> Default for Field_BLS12_377<NUM_LIMBS> {
|
||||
fn default() -> Self {
|
||||
Field_BLS12_377::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const NUM_LIMBS: usize> Field_BLS12_377<NUM_LIMBS> {
|
||||
pub fn zero() -> Self {
|
||||
Field_BLS12_377 { s: [0u32; NUM_LIMBS] }
|
||||
}
|
||||
|
||||
pub fn one() -> Self {
|
||||
let mut s = [0u32; NUM_LIMBS];
|
||||
s[0] = 1;
|
||||
Field_BLS12_377 { s }
|
||||
}
|
||||
|
||||
fn to_bytes_le(&self) -> Vec<u8> {
|
||||
self.s
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.to_le_bytes()
|
||||
.to_vec()
|
||||
})
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
pub const BASE_LIMBS_BLS12_377: usize = 12;
|
||||
pub const SCALAR_LIMBS_BLS12_377: usize = 8;
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
pub type BaseField_BLS12_377 = Field_BLS12_377<BASE_LIMBS_BLS12_377>;
|
||||
#[allow(non_camel_case_types)]
|
||||
pub type ScalarField_BLS12_377 = Field_BLS12_377<SCALAR_LIMBS_BLS12_377>;
|
||||
|
||||
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
|
||||
match val.len() {
|
||||
n if n < NUM_LIMBS => {
|
||||
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
|
||||
padded[..val.len()].copy_from_slice(&val);
|
||||
padded
|
||||
}
|
||||
n if n == NUM_LIMBS => val
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
_ => panic!("slice has too many elements"),
|
||||
}
|
||||
}
|
||||
|
||||
impl BaseField_BLS12_377 {
|
||||
pub fn limbs(&self) -> [u32; BASE_LIMBS_BLS12_377] {
|
||||
self.s
|
||||
}
|
||||
|
||||
pub fn from_limbs(value: &[u32]) -> Self {
|
||||
Self {
|
||||
s: get_fixed_limbs(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_ark(&self) -> BigInteger384 {
|
||||
BigInteger384::new(
|
||||
u32_vec_to_u64_vec(&self.limbs())
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_ark(ark: BigInteger384) -> Self {
|
||||
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarField_BLS12_377 {
|
||||
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BLS12_377] {
|
||||
self.s
|
||||
}
|
||||
|
||||
pub fn to_ark(&self) -> BigInteger256 {
|
||||
BigInteger256::new(
|
||||
u32_vec_to_u64_vec(&self.limbs())
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_ark(ark: BigInteger256) -> Self {
|
||||
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
|
||||
}
|
||||
|
||||
pub fn to_ark_transmute(&self) -> BigInteger256 {
|
||||
unsafe { transmute(*self) }
|
||||
}
|
||||
|
||||
pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BLS12_377 {
|
||||
unsafe { transmute(v) }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, DeviceCopy)]
|
||||
#[repr(C)]
|
||||
pub struct Point_BLS12_377 {
|
||||
pub x: BaseField_BLS12_377,
|
||||
pub y: BaseField_BLS12_377,
|
||||
pub z: BaseField_BLS12_377,
|
||||
}
|
||||
|
||||
impl Default for Point_BLS12_377 {
|
||||
fn default() -> Self {
|
||||
Point_BLS12_377::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl Point_BLS12_377 {
|
||||
pub fn zero() -> Self {
|
||||
Point_BLS12_377 {
|
||||
x: BaseField_BLS12_377::zero(),
|
||||
y: BaseField_BLS12_377::one(),
|
||||
z: BaseField_BLS12_377::zero(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn infinity() -> Self {
|
||||
Self::zero()
|
||||
}
|
||||
|
||||
pub fn to_ark(&self) -> G1Projective_BLS12_377 {
|
||||
//TODO: generic conversion
|
||||
self.to_ark_affine()
|
||||
.into_projective()
|
||||
}
|
||||
|
||||
    pub fn to_ark_affine(&self) -> G1Affine_BLS12_377 {
        //TODO: generic conversion
        use std::ops::Mul;
        let proj_x_field = Fq_BLS12_377::from_le_bytes_mod_order(&self.x.to_bytes_le());
        let proj_y_field = Fq_BLS12_377::from_le_bytes_mod_order(&self.y.to_bytes_le());
        let proj_z_field = Fq_BLS12_377::from_le_bytes_mod_order(&self.z.to_bytes_le());
        let inverse_z = proj_z_field.inverse().unwrap();
        let aff_x = proj_x_field.mul(inverse_z);
        let aff_y = proj_y_field.mul(inverse_z);
        G1Affine_BLS12_377::new(aff_x, aff_y, false)
    }

    pub fn from_ark(ark: G1Projective_BLS12_377) -> Point_BLS12_377 {
        let z_inv = ark.z.inverse().unwrap();
        let z_invsq = z_inv * z_inv;
        let z_invq3 = z_invsq * z_inv;
        Point_BLS12_377 {
            x: BaseField_BLS12_377::from_ark((ark.x * z_invsq).into_repr()),
            y: BaseField_BLS12_377::from_ark((ark.y * z_invq3).into_repr()),
            z: BaseField_BLS12_377::one(),
        }
    }
}

extern "C" {
    fn eq_bls12_377(point1: *const Point_BLS12_377, point2: *const Point_BLS12_377) -> c_uint;
}

impl PartialEq for Point_BLS12_377 {
    fn eq(&self, other: &Self) -> bool {
        unsafe { eq_bls12_377(self, other) != 0 }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BLS12_377 {
    pub x: BaseField_BLS12_377,
    pub y: BaseField_BLS12_377,
}

impl Default for PointAffineNoInfinity_BLS12_377 {
    fn default() -> Self {
        PointAffineNoInfinity_BLS12_377 {
            x: BaseField_BLS12_377::zero(),
            y: BaseField_BLS12_377::zero(),
        }
    }
}

impl PointAffineNoInfinity_BLS12_377 {
    // TODO: generics
    /// From u32 limbs x, y
    pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
        PointAffineNoInfinity_BLS12_377 {
            x: BaseField_BLS12_377 { s: get_fixed_limbs(x) },
            y: BaseField_BLS12_377 { s: get_fixed_limbs(y) },
        }
    }

    pub fn limbs(&self) -> Vec<u32> {
        [self.x.limbs(), self.y.limbs()].concat()
    }

    pub fn to_projective(&self) -> Point_BLS12_377 {
        Point_BLS12_377 {
            x: self.x,
            y: self.y,
            z: BaseField_BLS12_377::one(),
        }
    }

    pub fn to_ark(&self) -> G1Affine_BLS12_377 {
        G1Affine_BLS12_377::new(Fq_BLS12_377::new(self.x.to_ark()), Fq_BLS12_377::new(self.y.to_ark()), false)
    }

    pub fn to_ark_repr(&self) -> G1Affine_BLS12_377 {
        G1Affine_BLS12_377::new(
            Fq_BLS12_377::from_repr(self.x.to_ark()).unwrap(),
            Fq_BLS12_377::from_repr(self.y.to_ark()).unwrap(),
            false,
        )
    }

    pub fn from_ark(p: &G1Affine_BLS12_377) -> Self {
        PointAffineNoInfinity_BLS12_377 {
            x: BaseField_BLS12_377::from_ark(p.x.into_repr()),
            y: BaseField_BLS12_377::from_ark(p.y.into_repr()),
        }
    }
}

impl Point_BLS12_377 {
    // TODO: generics

    pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Point_BLS12_377 {
            x: BaseField_BLS12_377 { s: get_fixed_limbs(x) },
            y: BaseField_BLS12_377 { s: get_fixed_limbs(y) },
            z: BaseField_BLS12_377 { s: get_fixed_limbs(z) },
        }
    }

    pub fn from_xy_limbs(value: &[u32]) -> Point_BLS12_377 {
        let l = value.len();
        assert_eq!(l, 3 * BASE_LIMBS_BLS12_377, "length must be 3 * {}", BASE_LIMBS_BLS12_377);
        Point_BLS12_377 {
            x: BaseField_BLS12_377 {
                s: value[..BASE_LIMBS_BLS12_377].try_into().unwrap(),
            },
            y: BaseField_BLS12_377 {
                s: value[BASE_LIMBS_BLS12_377..BASE_LIMBS_BLS12_377 * 2].try_into().unwrap(),
            },
            z: BaseField_BLS12_377 {
                s: value[BASE_LIMBS_BLS12_377 * 2..].try_into().unwrap(),
            },
        }
    }

    pub fn to_affine(&self) -> PointAffineNoInfinity_BLS12_377 {
        let ark_affine = self.to_ark_affine();
        PointAffineNoInfinity_BLS12_377 {
            x: BaseField_BLS12_377::from_ark(ark_affine.x.into_repr()),
            y: BaseField_BLS12_377::from_ark(ark_affine.y.into_repr()),
        }
    }

    pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BLS12_377 {
        PointAffineNoInfinity_BLS12_377 { x: self.x, y: self.y }
    }
}

impl ScalarField_BLS12_377 {
    pub fn from_limbs(value: &[u32]) -> ScalarField_BLS12_377 {
        ScalarField_BLS12_377 {
            s: get_fixed_limbs(value),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::curves::bls12_377::{Point_BLS12_377, ScalarField_BLS12_377};

    #[test]
    fn test_ark_scalar_convert() {
        let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
        let scalar = ScalarField_BLS12_377::from_limbs(&limbs);
        assert_eq!(
            scalar.to_ark(),
            scalar.to_ark_transmute(),
            "{:08X?} {:08X?}",
            scalar.to_ark(),
            scalar.to_ark_transmute()
        )
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_point_equality() {
        let left = Point_BLS12_377::zero();
        let right = Point_BLS12_377::zero();
        assert_eq!(left, right);
        let right = Point_BLS12_377::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
        assert_eq!(left, right);
        let right = Point_BLS12_377::from_limbs(
            &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            &[0; 12],
            &[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        );
        assert!(left != right);
    }
}
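A note on the two conversions above (and their BLS12-381/BN254 twins below): arkworks' G1Projective uses Jacobian coordinates, where (X, Y, Z) represents the affine point (X/Z^2, Y/Z^3), while icicle's Point types use homogeneous projective coordinates, where (X, Y, Z) represents (X/Z, Y/Z). `from_ark` therefore normalizes to Z = 1 by computing x = X * (Z^-1)^2 and y = Y * (Z^-1)^3, and `to_ark_affine` divides both coordinates by Z. Both paths call `.inverse().unwrap()`, which panics when Z = 0, i.e. for the point at infinity.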
@@ -1,407 +0,0 @@
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
use ark_bls12_381::{Fq as Fq_BLS12_381, G1Affine as G1Affine_BLS12_381, G1Projective as G1Projective_BLS12_381};
use ark_ec::AffineCurve;
use ark_ff::Field;
use ark_ff::{BigInteger256, BigInteger384, PrimeField};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
use serde::{Deserialize, Serialize};
use std::ffi::c_uint;
use std::mem::transmute;

#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_BLS12_381<const NUM_LIMBS: usize> {
    pub s: [u32; NUM_LIMBS],
}

unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BLS12_381<NUM_LIMBS> {}

impl<const NUM_LIMBS: usize> Default for Field_BLS12_381<NUM_LIMBS> {
    fn default() -> Self {
        Field_BLS12_381::zero()
    }
}

impl<const NUM_LIMBS: usize> Field_BLS12_381<NUM_LIMBS> {
    pub fn zero() -> Self {
        Field_BLS12_381 { s: [0u32; NUM_LIMBS] }
    }

    pub fn one() -> Self {
        let mut s = [0u32; NUM_LIMBS];
        s[0] = 1;
        Field_BLS12_381 { s }
    }

    fn to_bytes_le(&self) -> Vec<u8> {
        self.s
            .iter()
            .flat_map(|s| s.to_le_bytes().to_vec())
            .collect::<Vec<_>>()
    }
}

pub const BASE_LIMBS_BLS12_381: usize = 12;
pub const SCALAR_LIMBS_BLS12_381: usize = 8;

#[allow(non_camel_case_types)]
pub type BaseField_BLS12_381 = Field_BLS12_381<BASE_LIMBS_BLS12_381>;
#[allow(non_camel_case_types)]
pub type ScalarField_BLS12_381 = Field_BLS12_381<SCALAR_LIMBS_BLS12_381>;

impl Serialize for ScalarField_BLS12_381 {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.s.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for ScalarField_BLS12_381 {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = <[u32; SCALAR_LIMBS_BLS12_381] as Deserialize<'de>>::deserialize(deserializer)?;
        Ok(ScalarField_BLS12_381 { s })
    }
}

fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
    match val.len() {
        n if n < NUM_LIMBS => {
            let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
            padded[..val.len()].copy_from_slice(val);
            padded
        }
        n if n == NUM_LIMBS => val.try_into().unwrap(),
        _ => panic!("slice has too many elements"),
    }
}
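// A quick illustration of the padding behavior above (hypothetical inputs),
// e.g. with NUM_LIMBS = 8:
//   get_fixed_limbs::<8>(&[1, 2])  -> [1, 2, 0, 0, 0, 0, 0, 0]   (zero-padded)
//   get_fixed_limbs::<8>(&[0; 8])  -> passed through unchanged
//   get_fixed_limbs::<8>(&[0; 9])  -> panics: "slice has too many elements"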

impl BaseField_BLS12_381 {
    pub fn limbs(&self) -> [u32; BASE_LIMBS_BLS12_381] {
        self.s
    }

    pub fn from_limbs(value: &[u32]) -> Self {
        Self {
            s: get_fixed_limbs(value),
        }
    }

    pub fn to_ark(&self) -> BigInteger384 {
        BigInteger384::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger384) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }
}

impl ScalarField_BLS12_381 {
    pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BLS12_381] {
        self.s
    }

    pub fn to_ark(&self) -> BigInteger256 {
        BigInteger256::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger256) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }

    pub fn to_ark_transmute(&self) -> BigInteger256 {
        unsafe { transmute(*self) }
    }

    pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BLS12_381 {
        unsafe { transmute(v) }
    }
}

#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_BLS12_381 {
    pub x: BaseField_BLS12_381,
    pub y: BaseField_BLS12_381,
    pub z: BaseField_BLS12_381,
}

impl Default for Point_BLS12_381 {
    fn default() -> Self {
        Point_BLS12_381::zero()
    }
}

impl Point_BLS12_381 {
    pub fn zero() -> Self {
        Point_BLS12_381 {
            x: BaseField_BLS12_381::zero(),
            y: BaseField_BLS12_381::one(),
            z: BaseField_BLS12_381::zero(),
        }
    }

    pub fn infinity() -> Self {
        Self::zero()
    }

    pub fn to_ark(&self) -> G1Projective_BLS12_381 {
        //TODO: generic conversion
        self.to_ark_affine().into_projective()
    }

    pub fn to_ark_affine(&self) -> G1Affine_BLS12_381 {
        //TODO: generic conversion
        use std::ops::Mul;
        let proj_x_field = Fq_BLS12_381::from_le_bytes_mod_order(&self.x.to_bytes_le());
        let proj_y_field = Fq_BLS12_381::from_le_bytes_mod_order(&self.y.to_bytes_le());
        let proj_z_field = Fq_BLS12_381::from_le_bytes_mod_order(&self.z.to_bytes_le());
        let inverse_z = proj_z_field.inverse().unwrap();
        let aff_x = proj_x_field.mul(inverse_z);
        let aff_y = proj_y_field.mul(inverse_z);
        G1Affine_BLS12_381::new(aff_x, aff_y, false)
    }

    pub fn from_ark(ark: G1Projective_BLS12_381) -> Point_BLS12_381 {
        let z_inv = ark.z.inverse().unwrap();
        let z_invsq = z_inv * z_inv;
        let z_invq3 = z_invsq * z_inv;
        Point_BLS12_381 {
            x: BaseField_BLS12_381::from_ark((ark.x * z_invsq).into_repr()),
            y: BaseField_BLS12_381::from_ark((ark.y * z_invq3).into_repr()),
            z: BaseField_BLS12_381::one(),
        }
    }
}

extern "C" {
    fn eq_bls12_381(point1: *const Point_BLS12_381, point2: *const Point_BLS12_381) -> c_uint;
}

impl PartialEq for Point_BLS12_381 {
    fn eq(&self, other: &Self) -> bool {
        unsafe { eq_bls12_381(self, other) != 0 }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BLS12_381 {
    pub x: BaseField_BLS12_381,
    pub y: BaseField_BLS12_381,
}

impl Default for PointAffineNoInfinity_BLS12_381 {
    fn default() -> Self {
        PointAffineNoInfinity_BLS12_381 {
            x: BaseField_BLS12_381::zero(),
            y: BaseField_BLS12_381::zero(),
        }
    }
}

impl PointAffineNoInfinity_BLS12_381 {
    // TODO: generics
    /// From u32 limbs x, y
    pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
        PointAffineNoInfinity_BLS12_381 {
            x: BaseField_BLS12_381 { s: get_fixed_limbs(x) },
            y: BaseField_BLS12_381 { s: get_fixed_limbs(y) },
        }
    }

    pub fn limbs(&self) -> Vec<u32> {
        [self.x.limbs(), self.y.limbs()].concat()
    }

    pub fn to_projective(&self) -> Point_BLS12_381 {
        Point_BLS12_381 {
            x: self.x,
            y: self.y,
            z: BaseField_BLS12_381::one(),
        }
    }

    pub fn to_ark(&self) -> G1Affine_BLS12_381 {
        G1Affine_BLS12_381::new(Fq_BLS12_381::new(self.x.to_ark()), Fq_BLS12_381::new(self.y.to_ark()), false)
    }

    pub fn to_ark_repr(&self) -> G1Affine_BLS12_381 {
        G1Affine_BLS12_381::new(
            Fq_BLS12_381::from_repr(self.x.to_ark()).unwrap(),
            Fq_BLS12_381::from_repr(self.y.to_ark()).unwrap(),
            false,
        )
    }

    pub fn from_ark(p: &G1Affine_BLS12_381) -> Self {
        PointAffineNoInfinity_BLS12_381 {
            x: BaseField_BLS12_381::from_ark(p.x.into_repr()),
            y: BaseField_BLS12_381::from_ark(p.y.into_repr()),
        }
    }
}

impl Point_BLS12_381 {
    // TODO: generics

    pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Point_BLS12_381 {
            x: BaseField_BLS12_381 { s: get_fixed_limbs(x) },
            y: BaseField_BLS12_381 { s: get_fixed_limbs(y) },
            z: BaseField_BLS12_381 { s: get_fixed_limbs(z) },
        }
    }

    pub fn from_xy_limbs(value: &[u32]) -> Point_BLS12_381 {
        let l = value.len();
        assert_eq!(l, 3 * BASE_LIMBS_BLS12_381, "length must be 3 * {}", BASE_LIMBS_BLS12_381);
        Point_BLS12_381 {
            x: BaseField_BLS12_381 {
                s: value[..BASE_LIMBS_BLS12_381].try_into().unwrap(),
            },
            y: BaseField_BLS12_381 {
                s: value[BASE_LIMBS_BLS12_381..BASE_LIMBS_BLS12_381 * 2].try_into().unwrap(),
            },
            z: BaseField_BLS12_381 {
                s: value[BASE_LIMBS_BLS12_381 * 2..].try_into().unwrap(),
            },
        }
    }

    pub fn to_affine(&self) -> PointAffineNoInfinity_BLS12_381 {
        let ark_affine = self.to_ark_affine();
        PointAffineNoInfinity_BLS12_381 {
            x: BaseField_BLS12_381::from_ark(ark_affine.x.into_repr()),
            y: BaseField_BLS12_381::from_ark(ark_affine.y.into_repr()),
        }
    }

    pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BLS12_381 {
        PointAffineNoInfinity_BLS12_381 { x: self.x, y: self.y }
    }
}

impl ScalarField_BLS12_381 {
    pub fn from_limbs(value: &[u32]) -> ScalarField_BLS12_381 {
        ScalarField_BLS12_381 {
            s: get_fixed_limbs(value),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::curves::bls12_381::{Point_BLS12_381, ScalarField_BLS12_381};

    #[test]
    fn test_ark_scalar_convert() {
        let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
        let scalar = ScalarField_BLS12_381::from_limbs(&limbs);
        assert_eq!(
            scalar.to_ark(),
            scalar.to_ark_transmute(),
            "{:08X?} {:08X?}",
            scalar.to_ark(),
            scalar.to_ark_transmute()
        )
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_point_equality() {
        let left = Point_BLS12_381::zero();
        let right = Point_BLS12_381::zero();
        assert_eq!(left, right);
        let right = Point_BLS12_381::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
        assert_eq!(left, right);
        let right = Point_BLS12_381::from_limbs(
            &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            &[0; 12],
            &[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        );
        assert!(left != right);
    }
}
@@ -1,353 +0,0 @@
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
use ark_bn254::{Fq as Fq_BN254, G1Affine as G1Affine_BN254, G1Projective as G1Projective_BN254};
use ark_ec::AffineCurve;
use ark_ff::Field;
use ark_ff::{BigInteger256, PrimeField};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
use std::ffi::c_uint;
use std::mem::transmute;

#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_BN254<const NUM_LIMBS: usize> {
    pub s: [u32; NUM_LIMBS],
}

unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BN254<NUM_LIMBS> {}

impl<const NUM_LIMBS: usize> Default for Field_BN254<NUM_LIMBS> {
    fn default() -> Self {
        Field_BN254::zero()
    }
}

impl<const NUM_LIMBS: usize> Field_BN254<NUM_LIMBS> {
    pub fn zero() -> Self {
        Field_BN254 { s: [0u32; NUM_LIMBS] }
    }

    pub fn one() -> Self {
        let mut s = [0u32; NUM_LIMBS];
        s[0] = 1;
        Field_BN254 { s }
    }

    fn to_bytes_le(&self) -> Vec<u8> {
        self.s
            .iter()
            .flat_map(|s| s.to_le_bytes().to_vec())
            .collect::<Vec<_>>()
    }
}

pub const BASE_LIMBS_BN254: usize = 8;
pub const SCALAR_LIMBS_BN254: usize = 8;

#[allow(non_camel_case_types)]
pub type BaseField_BN254 = Field_BN254<BASE_LIMBS_BN254>;
#[allow(non_camel_case_types)]
pub type ScalarField_BN254 = Field_BN254<SCALAR_LIMBS_BN254>;

fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
    match val.len() {
        n if n < NUM_LIMBS => {
            let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
            padded[..val.len()].copy_from_slice(val);
            padded
        }
        n if n == NUM_LIMBS => val.try_into().unwrap(),
        _ => panic!("slice has too many elements"),
    }
}

impl ScalarField_BN254 {
    pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BN254] {
        self.s
    }

    pub fn to_ark(&self) -> BigInteger256 {
        BigInteger256::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
    }

    pub fn from_ark(ark: BigInteger256) -> Self {
        Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
    }

    pub fn to_ark_transmute(&self) -> BigInteger256 {
        unsafe { transmute(*self) }
    }

    pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BN254 {
        unsafe { transmute(v) }
    }
}

#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_BN254 {
    pub x: BaseField_BN254,
    pub y: BaseField_BN254,
    pub z: BaseField_BN254,
}

impl Default for Point_BN254 {
    fn default() -> Self {
        Point_BN254::zero()
    }
}

impl Point_BN254 {
    pub fn zero() -> Self {
        Point_BN254 {
            x: BaseField_BN254::zero(),
            y: BaseField_BN254::one(),
            z: BaseField_BN254::zero(),
        }
    }

    pub fn infinity() -> Self {
        Self::zero()
    }

    pub fn to_ark(&self) -> G1Projective_BN254 {
        //TODO: generic conversion
        self.to_ark_affine().into_projective()
    }

    pub fn to_ark_affine(&self) -> G1Affine_BN254 {
        //TODO: generic conversion
        use std::ops::Mul;
        let proj_x_field = Fq_BN254::from_le_bytes_mod_order(&self.x.to_bytes_le());
        let proj_y_field = Fq_BN254::from_le_bytes_mod_order(&self.y.to_bytes_le());
        let proj_z_field = Fq_BN254::from_le_bytes_mod_order(&self.z.to_bytes_le());
        let inverse_z = proj_z_field.inverse().unwrap();
        let aff_x = proj_x_field.mul(inverse_z);
        let aff_y = proj_y_field.mul(inverse_z);
        G1Affine_BN254::new(aff_x, aff_y, false)
    }

    pub fn from_ark(ark: G1Projective_BN254) -> Point_BN254 {
        let z_inv = ark.z.inverse().unwrap();
        let z_invsq = z_inv * z_inv;
        let z_invq3 = z_invsq * z_inv;
        Point_BN254 {
            x: BaseField_BN254::from_ark((ark.x * z_invsq).into_repr()),
            y: BaseField_BN254::from_ark((ark.y * z_invq3).into_repr()),
            z: BaseField_BN254::one(),
        }
    }
}

extern "C" {
    fn eq_bn254(point1: *const Point_BN254, point2: *const Point_BN254) -> c_uint;
}

impl PartialEq for Point_BN254 {
    fn eq(&self, other: &Self) -> bool {
        unsafe { eq_bn254(self, other) != 0 }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BN254 {
    pub x: BaseField_BN254,
    pub y: BaseField_BN254,
}

impl Default for PointAffineNoInfinity_BN254 {
    fn default() -> Self {
        PointAffineNoInfinity_BN254 {
            x: BaseField_BN254::zero(),
            y: BaseField_BN254::zero(),
        }
    }
}

impl PointAffineNoInfinity_BN254 {
    // TODO: generics
    /// From u32 limbs x, y
    pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
        PointAffineNoInfinity_BN254 {
            x: BaseField_BN254 { s: get_fixed_limbs(x) },
            y: BaseField_BN254 { s: get_fixed_limbs(y) },
        }
    }

    pub fn limbs(&self) -> Vec<u32> {
        [self.x.limbs(), self.y.limbs()].concat()
    }

    pub fn to_projective(&self) -> Point_BN254 {
        Point_BN254 {
            x: self.x,
            y: self.y,
            z: BaseField_BN254::one(),
        }
    }

    pub fn to_ark(&self) -> G1Affine_BN254 {
        G1Affine_BN254::new(Fq_BN254::new(self.x.to_ark()), Fq_BN254::new(self.y.to_ark()), false)
    }

    pub fn to_ark_repr(&self) -> G1Affine_BN254 {
        G1Affine_BN254::new(
            Fq_BN254::from_repr(self.x.to_ark()).unwrap(),
            Fq_BN254::from_repr(self.y.to_ark()).unwrap(),
            false,
        )
    }

    pub fn from_ark(p: &G1Affine_BN254) -> Self {
        PointAffineNoInfinity_BN254 {
            x: BaseField_BN254::from_ark(p.x.into_repr()),
            y: BaseField_BN254::from_ark(p.y.into_repr()),
        }
    }
}

impl Point_BN254 {
    // TODO: generics

    pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Point_BN254 {
            x: BaseField_BN254 { s: get_fixed_limbs(x) },
            y: BaseField_BN254 { s: get_fixed_limbs(y) },
            z: BaseField_BN254 { s: get_fixed_limbs(z) },
        }
    }

    pub fn from_xy_limbs(value: &[u32]) -> Point_BN254 {
        let l = value.len();
        assert_eq!(l, 3 * BASE_LIMBS_BN254, "length must be 3 * {}", BASE_LIMBS_BN254);
        Point_BN254 {
            x: BaseField_BN254 {
                s: value[..BASE_LIMBS_BN254].try_into().unwrap(),
            },
            y: BaseField_BN254 {
                s: value[BASE_LIMBS_BN254..BASE_LIMBS_BN254 * 2].try_into().unwrap(),
            },
            z: BaseField_BN254 {
                s: value[BASE_LIMBS_BN254 * 2..].try_into().unwrap(),
            },
        }
    }

    pub fn to_affine(&self) -> PointAffineNoInfinity_BN254 {
        let ark_affine = self.to_ark_affine();
        PointAffineNoInfinity_BN254 {
            x: BaseField_BN254::from_ark(ark_affine.x.into_repr()),
            y: BaseField_BN254::from_ark(ark_affine.y.into_repr()),
        }
    }

    pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BN254 {
        PointAffineNoInfinity_BN254 { x: self.x, y: self.y }
    }
}

impl ScalarField_BN254 {
    pub fn from_limbs(value: &[u32]) -> ScalarField_BN254 {
        ScalarField_BN254 {
            s: get_fixed_limbs(value),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::curves::bn254::{Point_BN254, ScalarField_BN254};

    #[test]
    fn test_ark_scalar_convert() {
        let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
        let scalar = ScalarField_BN254::from_limbs(&limbs);
        assert_eq!(
            scalar.to_ark(),
            scalar.to_ark_transmute(),
            "{:08X?} {:08X?}",
            scalar.to_ark(),
            scalar.to_ark_transmute()
        )
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_point_equality() {
        let left = Point_BN254::zero();
        let right = Point_BN254::zero();
        assert_eq!(left, right);
        let right = Point_BN254::from_limbs(&[0; 8], &[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8]);
        assert_eq!(left, right);
        let right = Point_BN254::from_limbs(&[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8], &[1, 0, 0, 0, 0, 0, 0, 0]);
        assert!(left != right);
    }
}
@@ -1,3 +0,0 @@
pub mod bls12_377;
pub mod bls12_381;
pub mod bn254;
@@ -1,5 +0,0 @@
pub mod curves;
pub mod test_bls12_377;
pub mod test_bls12_381;
pub mod test_bn254;
pub mod utils;
File diff suppressed because it is too large
File diff suppressed because it is too large
1751
src/test_bn254.rs
File diff suppressed because it is too large
99
src/utils.rs
@@ -1,99 +0,0 @@
use rand::rngs::StdRng;
use rand::RngCore;
use rand::SeedableRng;

pub fn from_limbs<T>(limbs: Vec<u32>, chunk_size: usize, f: fn(&[u32]) -> T) -> Vec<T> {
    limbs
        .chunks(chunk_size)
        .map(|lmbs| f(lmbs))
        .collect::<Vec<T>>()
}

pub fn u32_vec_to_u64_vec(arr_u32: &[u32]) -> Vec<u64> {
    let len = arr_u32.len() / 2;
    let mut arr_u64 = vec![0u64; len];

    for i in 0..len {
        arr_u64[i] = u64::from(arr_u32[i * 2]) | (u64::from(arr_u32[i * 2 + 1]) << 32);
    }

    arr_u64
}

pub fn u64_vec_to_u32_vec(arr_u64: &[u64]) -> Vec<u32> {
    let len = arr_u64.len() * 2;
    let mut arr_u32 = vec![0u32; len];

    for i in 0..arr_u64.len() {
        arr_u32[i * 2] = arr_u64[i] as u32;
        arr_u32[i * 2 + 1] = (arr_u64[i] >> 32) as u32;
    }

    arr_u32
}
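// Round-trip property of the two helpers above: for any even-length slice
// `v: &[u32]`, u64_vec_to_u32_vec(&u32_vec_to_u64_vec(v)) == v. Limbs are
// packed little-endian, low u32 first, as the tests at the bottom of this
// file verify.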
pub fn get_rng(seed: Option<u64>) -> Box<dyn RngCore> {
    //TODO: this func is universal
    let rng: Box<dyn RngCore> = match seed {
        Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
        None => Box::new(rand::thread_rng()),
    };
    rng
}

#[cfg(test)]
mod tests {
    use ark_ff::BigInteger256;

    use crate::curves::bls12_381::ScalarField_BLS12_381 as ScalarField;

    use super::*;

    #[test]
    fn test_u32_vec_to_u64_vec() {
        let arr_u32 = [1, 0x0fffffff, 3, 0x2fffffff, 5, 0x4fffffff, 7, 0x6fffffff];

        let s = ScalarField::from_ark_transmute(BigInteger256::new(
            u32_vec_to_u64_vec(&arr_u32).try_into().unwrap(),
        ))
        .limbs();

        assert_eq!(arr_u32.to_vec(), s);

        let arr_u64_expected = [
            0x0FFFFFFF00000001,
            0x2FFFFFFF00000003,
            0x4FFFFFFF00000005,
            0x6FFFFFFF00000007,
        ];

        assert_eq!(
            u32_vec_to_u64_vec(&arr_u32),
            arr_u64_expected,
            "{:016X?}",
            u32_vec_to_u64_vec(&arr_u32)
        );
    }

    #[test]
    fn test_u64_vec_to_u32_vec() {
        let arr_u64 = [
            0x2FFFFFFF00000001,
            0x4FFFFFFF00000003,
            0x6FFFFFFF00000005,
            0x8FFFFFFF00000007,
        ];

        let arr_u32_expected = [1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7, 0x8fffffff];

        assert_eq!(
            u64_vec_to_u32_vec(&arr_u64),
            arr_u32_expected,
            "{:016X?}",
            u64_vec_to_u32_vec(&arr_u64)
        );
    }
}
3
wrappers/rust/Cargo.toml
Normal file
@@ -0,0 +1,3 @@
[workspace]
resolver = "2"
members = ["icicle-cuda-runtime", "icicle-core", "icicle-curves/icicle-bn254"]
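A brief build note (an assumption, not stated in this diff): with the workspace manifest above, the wrappers build from wrappers/rust with a plain `cargo build --release`, and the arkworks interop declared in icicle-core's manifest below is opt-in, e.g. `cargo build -p icicle-core --features arkworks`.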
36
wrappers/rust/icicle-core/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
[package]
name = "icicle-core"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "A library for GPU ZK acceleration by Ingonyama"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"

[dependencies]
icicle-cuda-runtime = { path = "../icicle-cuda-runtime" }

ark-ff = { version = "0.4.0", optional = true }
ark-ec = { version = "0.4.0", optional = true, features = [ "parallel" ] }

# [build-dependencies]
# cc = { version = "1.0", features = ["parallel"] }
# cmake = "*"
# bindgen = "*"
# libc = "*" #TODO: move libc dependencies to build

# [dev-dependencies]
# rustacuda = "0.1"
# rustacuda_core = "0.1"
# rustacuda_derive = "0.1"

# "criterion" = "*"

[features]
default = []
arkworks = ["ark-ff", "ark-ec"]
# TODO: impl G2 and EC NTT
g2 = []
ec_ntt = []
193
wrappers/rust/icicle-core/src/curve.rs
Normal file
@@ -0,0 +1,193 @@
use crate::field::{Field, FieldConfig};
#[cfg(feature = "arkworks")]
use crate::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_ec::models::CurveConfig as ArkCurveConfig;
#[cfg(feature = "arkworks")]
use ark_ec::short_weierstrass::{Affine as ArkAffine, Projective as ArkProjective, SWCurveConfig};
use std::ffi::{c_uint, c_void};
use std::marker::PhantomData;

pub trait CurveConfig: PartialEq + Copy + Clone {
    fn eq_proj(point1: *const c_void, point2: *const c_void) -> c_uint;
    fn to_affine(point: *const c_void, point_aff: *mut c_void);

    #[cfg(feature = "arkworks")]
    type ArkSWConfig: SWCurveConfig;
}

#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct Projective<T, C: CurveConfig> {
    pub x: T,
    pub y: T,
    pub z: T,
    p: PhantomData<C>,
}

#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(C)]
pub struct Affine<T, C: CurveConfig> {
    pub x: T,
    pub y: T,
    p: PhantomData<C>,
}

impl<const NUM_LIMBS: usize, F, C> Affine<Field<NUM_LIMBS, F>, C>
where
    F: FieldConfig,
    C: CurveConfig,
{
    // While this is not a true zero point and not even a valid point, it's still useful
    // both as a handy default and as a representation of zero points in other codebases
    pub fn zero() -> Self {
        Affine {
            x: Field::<NUM_LIMBS, F>::zero(),
            y: Field::<NUM_LIMBS, F>::zero(),
            p: PhantomData,
        }
    }

    pub fn set_limbs(x: &[u32], y: &[u32]) -> Self {
        Affine {
            x: Field::<NUM_LIMBS, F>::set_limbs(x),
            y: Field::<NUM_LIMBS, F>::set_limbs(y),
            p: PhantomData,
        }
    }

    pub fn to_projective(&self) -> Projective<Field<NUM_LIMBS, F>, C> {
        Projective {
            x: self.x,
            y: self.y,
            z: Field::<NUM_LIMBS, F>::one(),
            p: PhantomData,
        }
    }
}

impl<const NUM_LIMBS: usize, F, C> From<Affine<Field<NUM_LIMBS, F>, C>> for Projective<Field<NUM_LIMBS, F>, C>
where
    F: FieldConfig,
    C: CurveConfig,
{
    fn from(item: Affine<Field<NUM_LIMBS, F>, C>) -> Self {
        Self {
            x: item.x,
            y: item.y,
            z: Field::<NUM_LIMBS, F>::one(),
            p: PhantomData,
        }
    }
}

impl<const NUM_LIMBS: usize, F, C> Projective<Field<NUM_LIMBS, F>, C>
where
    F: FieldConfig,
    C: CurveConfig,
{
    pub fn zero() -> Self {
        Projective {
            x: Field::<NUM_LIMBS, F>::zero(),
            y: Field::<NUM_LIMBS, F>::one(),
            z: Field::<NUM_LIMBS, F>::zero(),
            p: PhantomData,
        }
    }

    pub fn set_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
        Projective {
            x: Field::<NUM_LIMBS, F>::set_limbs(x),
            y: Field::<NUM_LIMBS, F>::set_limbs(y),
            z: Field::<NUM_LIMBS, F>::set_limbs(z),
            p: PhantomData,
        }
    }
}

impl<const NUM_LIMBS: usize, F, C> PartialEq for Projective<Field<NUM_LIMBS, F>, C>
where
    F: FieldConfig,
    C: CurveConfig,
{
    fn eq(&self, other: &Self) -> bool {
        C::eq_proj(self as *const _ as *const c_void, other as *const _ as *const c_void) != 0
    }
}

impl<const NUM_LIMBS: usize, F, C> From<Projective<Field<NUM_LIMBS, F>, C>> for Affine<Field<NUM_LIMBS, F>, C>
where
    F: FieldConfig,
    C: CurveConfig,
{
    fn from(item: Projective<Field<NUM_LIMBS, F>, C>) -> Self {
        let mut aff = Self::zero();
        C::to_affine(&item as *const _ as *const c_void, &mut aff as *mut _ as *mut c_void);
        aff
    }
}

#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F, C> ArkConvertible for Affine<Field<NUM_LIMBS, F>, C>
where
    C: CurveConfig,
    F: FieldConfig<ArkField = <<C as CurveConfig>::ArkSWConfig as ArkCurveConfig>::BaseField>,
{
    type ArkEquivalent = ArkAffine<C::ArkSWConfig>;

    fn to_ark(&self) -> Self::ArkEquivalent {
        let proj_x = self.x.to_ark();
        let proj_y = self.y.to_ark();
        Self::ArkEquivalent::new_unchecked(proj_x, proj_y)
    }

    fn from_ark(ark: Self::ArkEquivalent) -> Self {
        Self {
            x: Field::<NUM_LIMBS, F>::from_ark(ark.x),
            y: Field::<NUM_LIMBS, F>::from_ark(ark.y),
            p: PhantomData,
        }
    }
}

#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F, C> ArkConvertible for Projective<Field<NUM_LIMBS, F>, C>
where
    C: CurveConfig,
    F: FieldConfig<ArkField = <<C as CurveConfig>::ArkSWConfig as ArkCurveConfig>::BaseField>,
{
    type ArkEquivalent = ArkProjective<C::ArkSWConfig>;

    fn to_ark(&self) -> Self::ArkEquivalent {
        let proj_x = self.x.to_ark();
        let proj_y = self.y.to_ark();
        let proj_z = self.z.to_ark();

        // conversion between the homogeneous projective coordinates used in icicle
        // and the Jacobian coordinates used in arkworks
        let proj_x = proj_x * proj_z;
        let proj_y = proj_y * proj_z * proj_z;
        Self::ArkEquivalent::new_unchecked(proj_x, proj_y, proj_z)
    }

    fn from_ark(ark: Self::ArkEquivalent) -> Self {
        // conversion between the Jacobian coordinates used in arkworks
        // and the homogeneous projective coordinates used in icicle
        let proj_x = ark.x * ark.z;
        let proj_z = ark.z * ark.z * ark.z;
        Self {
            x: Field::<NUM_LIMBS, F>::from_ark(proj_x),
            y: Field::<NUM_LIMBS, F>::from_ark(ark.y),
            z: Field::<NUM_LIMBS, F>::from_ark(proj_z),
            p: PhantomData,
        }
    }
}
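A usage sketch for the generic types above, assuming hypothetical `C: CurveConfig` and `F: FieldConfig` implementors (a concrete pair ships in the icicle-bn254 crate listed in the workspace members):

    // Affine -> Projective is pure Rust (it just sets Z = 1), while
    // Projective -> Affine round-trips through C::to_affine, i.e. the CUDA library:
    // let a = Affine::<Field<8, F>, C>::set_limbs(&x_limbs, &y_limbs);
    // let p: Projective<Field<8, F>, C> = a.into();
    // let back: Affine<Field<8, F>, C> = p.into();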
98
wrappers/rust/icicle-core/src/field.rs
Normal file
@@ -0,0 +1,98 @@
#[cfg(feature = "arkworks")]
use crate::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_ff::{BigInteger, PrimeField};
use std::marker::PhantomData;

#[cfg(feature = "arkworks")]
pub trait FieldConfig: PartialEq + Copy + Clone {
    type ArkField: PrimeField;
}
#[cfg(not(feature = "arkworks"))]
pub trait FieldConfig: PartialEq + Copy + Clone {}

#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field<const NUM_LIMBS: usize, F: FieldConfig> {
    limbs: [u32; NUM_LIMBS],
    p: PhantomData<F>,
}

pub(crate) fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
    match val.len() {
        n if n < NUM_LIMBS => {
            let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
            padded[..val.len()].copy_from_slice(val);
            padded
        }
        n if n == NUM_LIMBS => val.try_into().unwrap(),
        _ => panic!("slice has too many elements"),
    }
}

impl<const NUM_LIMBS: usize, F: FieldConfig> Field<NUM_LIMBS, F> {
    pub fn get_limbs(&self) -> [u32; NUM_LIMBS] {
        self.limbs
    }

    pub fn set_limbs(value: &[u32]) -> Self {
        Self {
            limbs: get_fixed_limbs(value),
            p: PhantomData,
        }
    }

    pub fn to_bytes_le(&self) -> Vec<u8> {
        self.limbs
            .iter()
            .flat_map(|limb| limb.to_le_bytes().to_vec())
            .collect::<Vec<_>>()
    }

    pub fn from_bytes_le(bytes: &[u8]) -> Self {
        let limbs = bytes
            .chunks(4)
            .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
            .collect::<Vec<_>>();
        Self::set_limbs(&limbs)
    }

    pub fn zero() -> Self {
        Field {
            limbs: [0u32; NUM_LIMBS],
            p: PhantomData,
        }
    }

    pub fn one() -> Self {
        let mut limbs = [0u32; NUM_LIMBS];
        limbs[0] = 1;
        Field { limbs, p: PhantomData }
    }
}

#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F: FieldConfig> ArkConvertible for Field<NUM_LIMBS, F> {
    type ArkEquivalent = F::ArkField;

    fn to_ark(&self) -> Self::ArkEquivalent {
        F::ArkField::from_le_bytes_mod_order(&self.to_bytes_le())
    }

    fn from_ark(ark: Self::ArkEquivalent) -> Self {
        let ark_bigint: <Self::ArkEquivalent as PrimeField>::BigInt = ark.into();
        Self::from_bytes_le(&ark_bigint.to_bytes_le())
    }
}
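// A minimal round-trip sketch for the byte helpers above, assuming the
// `arkworks` feature is off so FieldConfig has no associated type;
// `DemoConfig` is a hypothetical marker type, not part of this commit.
#[derive(Debug, PartialEq, Copy, Clone)]
struct DemoConfig;
impl FieldConfig for DemoConfig {}

fn demo_round_trip() {
    let f = Field::<2, DemoConfig>::set_limbs(&[0xdead_beef, 0x1]);
    // Limbs are serialized little-endian, low limb first:
    assert_eq!(f.to_bytes_le(), vec![0xef, 0xbe, 0xad, 0xde, 0x01, 0x00, 0x00, 0x00]);
    assert_eq!(Field::<2, DemoConfig>::from_bytes_le(&f.to_bytes_le()), f);
}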
5
wrappers/rust/icicle-core/src/lib.rs
Normal file
@@ -0,0 +1,5 @@
pub mod curve;
pub mod field;
pub mod msm;
pub mod ntt;
pub mod traits;
98
wrappers/rust/icicle-core/src/msm/mod.rs
Normal file
@@ -0,0 +1,98 @@
use icicle_cuda_runtime::device_context::DeviceContext;

/*
/**
 * @struct MSMConfig
 * Struct that encodes MSM parameters to be passed into the [msm](@ref msm) function.
 */
struct MSMConfig {
  bool are_scalars_on_device;       /**< True if scalars are on device and false if they're on host. Default value: false. */
  bool are_scalars_montgomery_form; /**< True if scalars are in Montgomery form and false otherwise. Default value: true. */
  int points_size;          /**< Number of points in the MSM. If a batch of MSMs needs to be computed, this should be a number
                             *   of different points. So, if each MSM re-uses the same set of points, this variable is set equal
                             *   to the MSM size. And if every MSM uses a distinct set of points, it should be set to the product of
                             *   MSM size and [batch_size](@ref batch_size). Default value: 0 (meaning it's equal to the MSM size). */
  int precompute_factor;    /**< The number of extra points to pre-compute for each point. Larger values decrease the number of computations
                             *   to make, on-line memory footprint, but increase the static memory footprint. Default value: 1 (i.e. don't pre-compute). */
  bool are_points_on_device;       /**< True if points are on device and false if they're on host. Default value: false. */
  bool are_points_montgomery_form; /**< True if coordinates of points are in Montgomery form and false otherwise. Default value: true. */
  int batch_size;           /**< The number of MSMs to compute. Default value: 1. */
  bool are_results_on_device; /**< True if the results should be on device and false if they should be on host. If set to false,
                             *   `is_async` won't take effect because a synchronization is needed to transfer results to the host. Default value: false. */
  int c;                    /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
                             *   that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
                             *   footprint but also more parallelism and less computational complexity (up to a certain point).
                             *   Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
  int bitsize;              /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a different
                             *   (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field). */
  bool is_big_triangle;     /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but also greatly
                             *   decreases parallelism, so only suitable for large batches of MSMs. Default value: false. */
  int large_bucket_factor;  /**< Variable that controls how sensitive the algorithm is to the buckets that occur very frequently.
                             *   Useful for efficient treatment of non-uniform distributions of scalars and "top windows" with few bits.
                             *   Can be set to 0 to disable separate treatment of large buckets altogether. Default value: 10. */
  int is_async;             /**< Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
                             *   and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`.
                             *   If set to false, the MSM function will block the current CPU thread. */
  device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See [DeviceContext](@ref `device_context::DeviceContext`). */
};
*/

/// Struct that encodes MSM parameters to be passed into the `msm` function.
#[repr(C)]
pub struct MSMConfig<'a> {
    /// True if scalars are on device and false if they're on host. Default value: false.
    pub are_scalars_on_device: bool,

    /// True if scalars are in Montgomery form and false otherwise. Default value: true.
    pub are_scalars_montgomery_form: bool,

    /// Number of points in the MSM. If a batch of MSMs needs to be computed, this should be a number
    /// of different points. So, if each MSM re-uses the same set of points, this variable is set equal
    /// to the MSM size. And if every MSM uses a distinct set of points, it should be set to the product of
    /// MSM size and batch_size. Default value: 0 (meaning it's equal to the MSM size).
    pub points_size: usize, // Note: the C++ struct above uses `int` here; the Rust side uses `usize`

    /// The number of extra points to pre-compute for each point. Larger values decrease the number of computations
    /// to make, on-line memory footprint, but increase the static memory footprint. Default value: 1 (i.e. don't pre-compute).
    pub precompute_factor: usize,

    /// True if points are on device and false if they're on host. Default value: false.
    pub are_points_on_device: bool,

    /// True if coordinates of points are in Montgomery form and false otherwise. Default value: true.
    pub are_points_montgomery_form: bool,

    /// The number of MSMs to compute. Default value: 1.
    pub batch_size: usize,

    /// True if the results should be on device and false if they should be on host. If set to false,
    /// `is_async` won't take effect because a synchronization is needed to transfer results to the host. Default value: false.
    pub are_results_on_device: bool,

    /// `c` value, or "window bitsize" which is the main parameter of the "bucket method"
    /// that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
    /// footprint but also more parallelism and less computational complexity (up to a certain point).
    /// Default value: 0 (the optimal value of `c` is chosen automatically).
    pub c: usize,

    /// Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a different
    /// (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field).
    pub bitsize: usize,

    /// Whether to do "bucket accumulation" serially. Decreases computational complexity, but also greatly
    /// decreases parallelism, so only suitable for large batches of MSMs. Default value: false.
    pub is_big_triangle: bool,

    /// Variable that controls how sensitive the algorithm is to the buckets that occur very frequently.
    /// Useful for efficient treatment of non-uniform distributions of scalars and "top windows" with few bits.
    /// Can be set to 0 to disable separate treatment of large buckets altogether. Default value: 10.
    pub large_bucket_factor: usize,

    /// Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
    /// and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`.
    /// If set to `false`, the MSM function will block the current CPU thread.
    pub is_async: bool,

    /// Details related to the device such as its id and stream id.
    pub ctx: DeviceContext<'a>,
}
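// A sketch of populating MSMConfig with the documented defaults. It relies on
// get_default_device_context() from icicle-cuda-runtime (shown later in this
// commit); the choice of `is_async: false` (blocking) is an assumption, since
// no default is documented for that field.
pub fn default_msm_config<'a>() -> MSMConfig<'a> {
    MSMConfig {
        are_scalars_on_device: false,
        are_scalars_montgomery_form: true,
        points_size: 0,
        precompute_factor: 1,
        are_points_on_device: false,
        are_points_montgomery_form: true,
        batch_size: 1,
        are_results_on_device: false,
        c: 0,
        bitsize: 0,
        is_big_triangle: false,
        large_bucket_factor: 10,
        is_async: false,
        ctx: icicle_cuda_runtime::device_context::get_default_device_context(),
    }
}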
104
wrappers/rust/icicle-core/src/ntt/mod.rs
Normal file
@@ -0,0 +1,104 @@
use icicle_cuda_runtime::device_context::DeviceContext;
use std::os::raw::c_int;

/**
 * @enum Ordering
 * How to order inputs and outputs of the NTT:
 * - kNN: inputs and outputs are natural-order (example of natural ordering: \f$ \{a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7\} \f$).
 * - kNR: inputs are natural-order and outputs are bit-reversed-order (example of bit-reversed ordering: \f$ \{a_0, a_4, a_2, a_6, a_1, a_5, a_3, a_7\} \f$).
 * - kRN: inputs are bit-reversed-order and outputs are natural-order.
 * - kRR: inputs and outputs are bit-reversed-order.
 */
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Ordering {
    kNN,
    kNR,
    kRN,
    kRR,
}

/**
 * @enum Decimation
 * Decimation of the NTT algorithm:
 * - kDIT: decimation in time.
 * - kDIF: decimation in frequency.
 */
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Decimation {
    kDIT,
    kDIF,
}

/**
 * @enum Butterfly
 * [Butterfly](https://en.wikipedia.org/wiki/Butterfly_diagram) used in the NTT algorithm (i.e. what happens to each pair of inputs on every iteration):
 * - kCooleyTukey: Cooley-Tukey butterfly.
 * - kGentlemanSande: Gentleman-Sande butterfly.
 */
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Butterfly {
    kCooleyTukey,
    kGentlemanSande,
}

/**
 * @struct NTTConfig
 * Struct that encodes NTT parameters to be passed into the [ntt](@ref ntt) function.
 */
#[repr(C)]
#[derive(Debug)]
pub struct NTTConfigCuda<'a, E, S> {
    /// Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot config.batch_size \f$.
    /// Note that if inputs are in Montgomery form, the outputs will be as well and vice versa: non-Montgomery inputs produce non-Montgomery outputs.
    pub inout: *mut E,
    /// True if inputs/outputs are on device and false if they're on host. Default value: false.
    pub is_input_on_device: bool,
    /// True to compute the inverse NTT (iNTT). Default value: false.
    pub is_inverse: bool,
    /// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
    pub ordering: Ordering,
    /// Decimation of the algorithm, see [Decimation](@ref Decimation). Default value: `Decimation::kDIT`.
    /// __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
    /// If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
    /// `Decimation::kDIT` and if ordering is `Ordering::kNR` — to `Decimation::kDIF`.
    pub decimation: Decimation,
    /// Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value: `Butterfly::kCooleyTukey`.
    /// __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
    /// If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
    /// `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`.
    pub butterfly: Butterfly,
    /// If false, NTT is computed on a subgroup given by [twiddles](@ref twiddles). If true, NTT is computed
    /// on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen), so:
    /// \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\} \f$. Default value: false.
    pub is_coset: bool,
    /// The field element that generates a coset if [is_coset](@ref is_coset) is true.
    /// Otherwise should be set to `nullptr`. Default value: `nullptr`.
    pub coset_gen: *const S,
    /// "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
    /// This pointer is expected to live on device. The order is as follows:
    /// \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle factors
    /// are generated online using the default generator (TODO: link to twiddle gen here) and function
    /// [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`.
    pub twiddles: *const S,
    /// "Inverse twiddle factors", (or "domain", or "roots of unity") on which the iNTT is evaluated.
    /// This pointer is expected to live on device. The order is as follows:
    /// \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle factors
    /// are generated online using the default generator (TODO: link to twiddle gen here) and function
    /// [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`.
    pub inv_twiddles: *const S,
    /// NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is the size of 1 NTT.
    pub size: c_int,
    /// The number of NTTs to compute. Default value: 1.
    pub batch_size: c_int,
    /// If true, twiddle factors are preserved on device for subsequent use in config and not freed after calculation. Default value: false.
    pub is_preserving_twiddles: bool,
    /// If true, output is preserved on device for subsequent use in config and not freed after calculation. Default value: false.
    pub is_output_on_device: bool,
    /// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
    pub ctx: DeviceContext<'a>,
}
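// For reference, a small helper (not part of this commit) that reproduces the
// bit-reversed ordering described in the Ordering docs above; for log_n = 3 it
// yields [0, 4, 2, 6, 1, 5, 3, 7]. Assumes 1 <= log_n <= 63.
fn bit_reverse_indices(log_n: u32) -> Vec<usize> {
    (0..1usize << log_n)
        .map(|i| ((i as u64).reverse_bits() >> (64 - log_n)) as usize)
        .collect()
}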
7
wrappers/rust/icicle-core/src/traits.rs
Normal file
@@ -0,0 +1,7 @@
#[cfg(feature = "arkworks")]
pub trait ArkConvertible {
    type ArkEquivalent;

    fn to_ark(&self) -> Self::ArkEquivalent;
    fn from_ark(ark: Self::ArkEquivalent) -> Self;
}
14
wrappers/rust/icicle-cuda-runtime/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
[package]
name = "icicle-cuda-runtime"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "Ingonyama's Rust wrapper of CUDA runtime"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"

[dependencies]
bitflags = "2.4"

[build-dependencies]
bindgen = "*"
82
wrappers/rust/icicle-cuda-runtime/build.rs
Normal file
@@ -0,0 +1,82 @@
// Based on https://github.com/matter-labs/z-prize-msm-gpu/blob/main/bellman-cuda-rust/cudart-sys/build.rs
use std::fs;
use std::path::PathBuf;

fn cuda_include_path() -> &'static str {
    #[cfg(target_os = "windows")]
    {
        concat!(env!("CUDA_PATH"), "/include")
    }

    #[cfg(target_os = "linux")]
    {
        "/usr/local/cuda/include"
    }
}

fn cuda_lib_path() -> &'static str {
    #[cfg(target_os = "windows")]
    {
        concat!(env!("CUDA_PATH"), "/lib/x64")
    }

    #[cfg(target_os = "linux")]
    {
        "/usr/local/cuda/lib64"
    }
}

fn main() {
    let cuda_runtime_api_path = PathBuf::from(cuda_include_path())
        .join("cuda_runtime_api.h")
        .to_string_lossy()
        .to_string();
    println!("cargo:rustc-link-search=native={}", cuda_lib_path());
    println!("cargo:rustc-link-lib=cudart");
    println!("cargo:rerun-if-changed={}", cuda_runtime_api_path);

    let bindings = bindgen::Builder::default()
        .header(cuda_runtime_api_path)
        .size_t_is_usize(true)
        .generate_comments(false)
        .layout_tests(false)
        .allowlist_type("cudaError")
        .rustified_enum("cudaError")
        .must_use_type("cudaError")
        // device management
        // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
        .allowlist_function("cudaSetDevice")
        // error handling
        // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__ERROR.html
        .allowlist_function("cudaGetLastError")
        // stream management
        // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
        .allowlist_function("cudaStreamCreate")
        .allowlist_var("cudaStreamDefault")
        .allowlist_var("cudaStreamNonBlocking")
        .allowlist_function("cudaStreamCreateWithFlags")
        .allowlist_function("cudaStreamDestroy")
        .allowlist_function("cudaStreamQuery")
        .allowlist_function("cudaStreamSynchronize")
        .allowlist_var("cudaEventWaitDefault")
        .allowlist_var("cudaEventWaitExternal")
        .allowlist_function("cudaStreamWaitEvent")
        // memory management
        // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html
        .allowlist_function("cudaFree")
        .allowlist_function("cudaMalloc")
        .allowlist_function("cudaMemcpy")
        .allowlist_function("cudaMemcpyAsync")
        .allowlist_function("cudaMemset")
        .allowlist_function("cudaMemsetAsync")
        .rustified_enum("cudaMemcpyKind")
        // Stream Ordered Memory Allocator
        // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html
        .allowlist_function("cudaFreeAsync")
        .allowlist_function("cudaMallocAsync")
        //
        .generate()
        .expect("Unable to generate bindings");

    fs::write(PathBuf::from("src").join("bindings.rs"), bindings.to_string()).expect("Couldn't write bindings!");
}
27
wrappers/rust/icicle-cuda-runtime/src/device_context.rs
Normal file
@@ -0,0 +1,27 @@
use crate::memory::CudaMemPool;
use crate::stream::CudaStream;

/// Properties of the device used in icicle functions.
#[repr(C)]
#[derive(Debug)]
pub struct DeviceContext<'a> {
    /// Index of the currently used GPU. Default value: 0.
    pub device_id: usize,

    /// Stream to use. Default value: 0.
    pub stream: &'a CudaStream, // Assuming the type is provided by a CUDA binding crate

    /// Mempool to use. Default value: 0.
    pub mempool: CudaMemPool, // Assuming the type is provided by a CUDA binding crate
}

pub fn get_default_device_context() -> DeviceContext<'static> {
    static default_stream: CudaStream = CudaStream {
        handle: std::ptr::null_mut(),
    };
    DeviceContext {
        device_id: 0,
        stream: &default_stream,
        mempool: 0,
    }
}
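// Usage sketch: get_default_device_context() targets device 0 with the default
// (null) CUDA stream and mempool handle 0; the MSMConfig and NTTConfigCuda
// structs in icicle-core take this context by value:
//     let ctx = get_default_device_context();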
35
wrappers/rust/icicle-cuda-runtime/src/error.rs
Normal file
@@ -0,0 +1,35 @@
use crate::bindings::{cudaError, cudaGetLastError};
use std::mem::MaybeUninit;

pub type CudaError = cudaError;

pub type CudaResult<T> = Result<T, CudaError>;

pub trait CudaResultWrap {
    fn wrap(self) -> CudaResult<()>;
    fn wrap_value<T>(self, value: T) -> CudaResult<T>;
    fn wrap_maybe_uninit<T>(self, value: MaybeUninit<T>) -> CudaResult<T>;
}

impl CudaResultWrap for CudaError {
    fn wrap(self) -> CudaResult<()> {
        self.wrap_value(())
    }

    fn wrap_value<T>(self, value: T) -> CudaResult<T> {
        if self == CudaError::cudaSuccess {
            Ok(value)
        } else {
            Err(self)
        }
    }

    fn wrap_maybe_uninit<T>(self, value: MaybeUninit<T>) -> CudaResult<T> {
        self.wrap_value(value)
            .map(|x| unsafe { x.assume_init() })
    }
}

pub fn get_last_error() -> CudaError {
    unsafe { cudaGetLastError() }
}
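// Example of the intended calling pattern (a sketch): raw cudaError return
// values from the generated bindings are converted to CudaResult with `wrap`
// and propagated with `?`. cudaSetDevice is among the functions allowlisted
// in build.rs.
fn set_device_checked(device_id: i32) -> CudaResult<()> {
    unsafe { crate::bindings::cudaSetDevice(device_id) }.wrap()
}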
8
wrappers/rust/icicle-cuda-runtime/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]

mod bindings;
pub mod device_context;
pub mod error;
pub mod memory;
pub mod stream;
156
wrappers/rust/icicle-cuda-runtime/src/memory.rs
Normal file
@@ -0,0 +1,156 @@
use crate::bindings::{cudaMalloc, cudaMallocAsync, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
use crate::error::{CudaError, CudaResult, CudaResultWrap};
use crate::stream::CudaStream;
use std::mem::{size_of, MaybeUninit};
use std::os::raw::c_void;
use std::slice;

/// Fixed-size device-side slice.
#[derive(Debug)]
#[repr(C)]
pub struct DeviceSlice<'a, T>(&'a mut [T]);

impl<'a, T> DeviceSlice<'a, T> {
    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn as_slice(&mut self) -> &mut [T] {
        self.0
    }

    pub fn as_ptr(&self) -> *const T {
        self.0.as_ptr()
    }

    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.0.as_mut_ptr()
    }

    pub fn cuda_malloc(count: usize) -> CudaResult<Self> {
        let size = count.checked_mul(size_of::<T>()).unwrap_or(0);
        if size == 0 {
            return Err(CudaError::cudaErrorMemoryAllocation);
        }

        let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
        unsafe {
            cudaMalloc(device_ptr.as_mut_ptr(), size).wrap()?;
            Ok(DeviceSlice(slice::from_raw_parts_mut(
                device_ptr.assume_init() as *mut T,
                count,
            )))
        }
    }

    pub fn cuda_malloc_async(count: usize, stream: &mut CudaStream) -> CudaResult<Self> {
        let size = count.checked_mul(size_of::<T>()).unwrap_or(0);
        if size == 0 {
            return Err(CudaError::cudaErrorMemoryAllocation);
        }

        let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
        unsafe {
            // Pass the raw stream handle; casting `&mut CudaStream` itself would
            // hand CUDA a pointer to the wrapper struct, not the stream.
            cudaMallocAsync(device_ptr.as_mut_ptr(), size, stream.handle).wrap()?;
            Ok(DeviceSlice(slice::from_raw_parts_mut(
                device_ptr.assume_init() as *mut T,
                count,
            )))
        }
    }

    pub fn copy_from_host(&mut self, val: &[T]) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cudaMemcpy(
                    self.as_mut_ptr() as *mut c_void,
                    val.as_ptr() as *const c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyHostToDevice,
                )
                .wrap()?
            }
        }
        Ok(())
    }

    pub fn copy_to_host(&self, val: &mut [T]) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cudaMemcpy(
                    val.as_mut_ptr() as *mut c_void,
                    self.as_ptr() as *const c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyDeviceToHost,
                )
                .wrap()?
            }
        }
        Ok(())
    }

    pub fn copy_from_host_async(&mut self, val: &[T], stream: &mut CudaStream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cudaMemcpyAsync(
                    self.as_mut_ptr() as *mut c_void,
                    val.as_ptr() as *const c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyHostToDevice,
                    stream.handle, // raw handle, see cuda_malloc_async
                )
                .wrap()?
            }
        }
        Ok(())
    }

    pub fn copy_to_host_async(&self, val: &mut [T], stream: &mut CudaStream) -> CudaResult<()> {
        assert!(
            self.len() == val.len(),
            "destination and source slices have different lengths"
        );
        let size = size_of::<T>() * self.len();
        if size != 0 {
            unsafe {
                cudaMemcpyAsync(
                    val.as_mut_ptr() as *mut c_void,
                    self.as_ptr() as *const c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyDeviceToHost,
                    stream.handle, // raw handle, see cuda_malloc_async
                )
                .wrap()?
            }
        }
        Ok(())
    }
}

#[allow(non_camel_case_types)]
pub type CudaMemPool = usize; // This is a placeholder, TODO: actually make this into a proper CUDA wrapper
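A host-to-device round trip using the slice type above, assuming a CUDA device is present. Note that DeviceSlice has no Drop impl yet, so the allocation is leaked unless it is freed by other means; the function name is illustrative.

use icicle_cuda_runtime::memory::DeviceSlice;

fn roundtrip_sketch() {
    let host_in = vec![1u32, 2, 3, 4];
    let mut device = DeviceSlice::cuda_malloc(host_in.len()).unwrap();
    device.copy_from_host(&host_in).unwrap();

    let mut host_out = vec![0u32; host_in.len()];
    device.copy_to_host(&mut host_out).unwrap();
    assert_eq!(host_in, host_out);
}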
wrappers/rust/icicle-cuda-runtime/src/stream.rs (new file, +74)
@@ -0,0 +1,74 @@
use crate::bindings::{
    cudaStreamCreate, cudaStreamDefault, cudaStreamDestroy, cudaStreamNonBlocking, cudaStreamSynchronize, cudaStream_t,
};
use crate::error::{CudaResult, CudaResultWrap};
use bitflags::bitflags;
use std::mem::{forget, MaybeUninit};

#[repr(transparent)]
#[derive(Debug)]
pub struct CudaStream {
    pub(crate) handle: cudaStream_t,
}

unsafe impl Sync for CudaStream {}

bitflags! {
    pub struct CudaStreamCreateFlags: u32 {
        const DEFAULT = cudaStreamDefault;
        const NON_BLOCKING = cudaStreamNonBlocking;
    }
}

impl CudaStream {
    pub(crate) fn from_handle(handle: cudaStream_t) -> Self {
        Self { handle }
    }

    pub fn create() -> CudaResult<Self> {
        let mut handle = MaybeUninit::<cudaStream_t>::uninit();
        unsafe {
            cudaStreamCreate(handle.as_mut_ptr())
                .wrap_maybe_uninit(handle)
                .map(CudaStream::from_handle)
        }
    }

    pub fn destroy(self) -> CudaResult<()> {
        let handle = self.handle;
        forget(self); // skip Drop so the handle is not destroyed twice
        if handle.is_null() {
            Ok(())
        } else {
            unsafe { cudaStreamDestroy(handle).wrap() }
        }
    }

    pub fn synchronize(&self) -> CudaResult<()> {
        unsafe { cudaStreamSynchronize(self.handle).wrap() }
    }
}

impl Default for CudaStream {
    fn default() -> Self {
        Self {
            handle: std::ptr::null_mut(),
        }
    }
}

impl Drop for CudaStream {
    fn drop(&mut self) {
        let handle = self.handle;
        if handle.is_null() {
            return;
        }
        let _ = unsafe { cudaStreamDestroy(handle) };
    }
}

impl From<&CudaStream> for cudaStream_t {
    fn from(stream: &CudaStream) -> Self {
        stream.handle
    }
}
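A sketch of the asynchronous path, combining the stream with the async allocation and copy from memory.rs; it assumes a CUDA device, and the function name is hypothetical:

use icicle_cuda_runtime::memory::DeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;

fn async_sketch() {
    let mut stream = CudaStream::create().unwrap();
    let mut buf = DeviceSlice::<u64>::cuda_malloc_async(1024, &mut stream).unwrap();

    let data = vec![7u64; 1024];
    buf.copy_from_host_async(&data, &mut stream).unwrap();

    // Work is only guaranteed complete after the stream is synchronized.
    stream.synchronize().unwrap();
    stream.destroy().unwrap();
}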
wrappers/rust/icicle-curves/icicle-bn254/Cargo.toml (new file, +29)
@@ -0,0 +1,29 @@
[package]
name = "icicle-bn254"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "Rust wrapper for the CUDA implementation of the BN254 pairing-friendly elliptic curve by Ingonyama"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"

[dependencies]
icicle-core = { path = "../../icicle-core" }
icicle-cuda-runtime = { path = "../../icicle-cuda-runtime" }
ark-bn254 = { version = "0.4.0", optional = true }

[build-dependencies]
cmake = "*"

[dev-dependencies]
ark-bn254 = "0.4.0"
ark-std = "0.4.0"
ark-ff = "0.4.0"
ark-ec = "0.4.0"
ark-poly = "0.4.0"
icicle-core = { path = "../../icicle-core", features = ["arkworks"] }
icicle-bn254 = { path = ".", features = ["arkworks"] }

[features]
default = []
arkworks = ["ark-bn254", "icicle-core/arkworks"]
wrappers/rust/icicle-curves/icicle-bn254/build.rs (new file, +30)
@@ -0,0 +1,30 @@
use cmake::Config;
use std::env::var;

fn main() {
    println!("cargo:rerun-if-env-changed=CXXFLAGS");
    println!("cargo:rerun-if-changed=../../../../icicle");

    let cargo_dir = var("CARGO_MANIFEST_DIR").unwrap();
    let profile = var("PROFILE").unwrap();

    let target_output_dir = format!("{}/../../target/{}", cargo_dir, profile);

    Config::new("./icicle")
        .define("BUILD_TESTS", "OFF") //TODO: feature
        // .define("CURVE", "bls12_381")
        .define("CURVE", "bn254")
        // .define("ECNTT_DEFINED", "") //TODO: feature
        .define("LIBRARY_OUTPUT_DIRECTORY", &target_output_dir)
        .define("CMAKE_BUILD_TYPE", "Release")
        .build_target("icicle")
        .build();

    println!("cargo:rustc-link-search={}", &target_output_dir);

    // println!("cargo:rustc-link-lib=icicle");
    println!("cargo:rustc-link-lib=ingo_bn254");
    println!("cargo:rustc-link-lib=stdc++");
    // println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
    println!("cargo:rustc-link-lib=cudart");
}
wrappers/rust/icicle-curves/icicle-bn254/src/curve.rs (new file, +151)
@@ -0,0 +1,151 @@
#[cfg(feature = "arkworks")]
use ark_bn254::{g1::Config as ArkG1Config, Fq, Fr};
use icicle_core::curve::{Affine, CurveConfig, Projective};
use icicle_core::field::{Field, FieldConfig};
use std::ffi::{c_uint, c_void};

#[derive(Debug, PartialEq, Copy, Clone)]
pub struct ScalarCfg {}

impl FieldConfig for ScalarCfg {
    #[cfg(feature = "arkworks")]
    type ArkField = Fr;
}

const SCALAR_LIMBS: usize = 8;

pub type ScalarField = Field<SCALAR_LIMBS, ScalarCfg>;

extern "C" {
    fn GenerateScalars(scalars: *mut ScalarField, size: usize);
}

pub(crate) fn generate_random_scalars(size: usize) -> Vec<ScalarField> {
    let mut res = vec![ScalarField::zero(); size];
    unsafe { GenerateScalars(res.as_mut_ptr(), size) };
    res
}

#[derive(Debug, PartialEq, Copy, Clone)]
pub struct BaseCfg {}

impl FieldConfig for BaseCfg {
    #[cfg(feature = "arkworks")]
    type ArkField = Fq;
}

pub const BASE_LIMBS: usize = 8;

#[derive(Debug, PartialEq, Copy, Clone)]
pub struct CurveCfg {}

pub type BaseField = Field<BASE_LIMBS, BaseCfg>;
pub type G1Affine = Affine<BaseField, CurveCfg>;
pub type G1Projective = Projective<BaseField, CurveCfg>;

extern "C" {
    fn Eq(point1: *const c_void, point2: *const c_void) -> c_uint;
    fn ToAffine(point: *const c_void, point_out: *mut c_void);
    fn GenerateProjectivePoints(points: *mut G1Projective, size: usize);
    fn GenerateAffinePoints(points: *mut G1Affine, size: usize);
}

impl CurveConfig for CurveCfg {
    fn eq_proj(point1: *const c_void, point2: *const c_void) -> c_uint {
        unsafe { Eq(point1, point2) }
    }

    fn to_affine(point: *const c_void, point_out: *mut c_void) {
        unsafe { ToAffine(point, point_out) };
    }

    #[cfg(feature = "arkworks")]
    type ArkSWConfig = ArkG1Config;
}

pub(crate) fn generate_random_projective_points(size: usize) -> Vec<G1Projective> {
    let mut res = vec![G1Projective::zero(); size];
    unsafe { GenerateProjectivePoints(res.as_mut_ptr(), size) };
    res
}

pub(crate) fn generate_random_affine_points(size: usize) -> Vec<G1Affine> {
    let mut res = vec![G1Affine::zero(); size];
    unsafe { GenerateAffinePoints(res.as_mut_ptr(), size) };
    res
}

#[cfg(test)]
mod tests {
    use super::{
        generate_random_affine_points, generate_random_projective_points, generate_random_scalars, BaseField, G1Affine,
        G1Projective, ScalarField, BASE_LIMBS,
    };
    use icicle_core::traits::ArkConvertible;

    use ark_bn254::G1Affine as ArkG1Affine;

    #[test]
    fn test_scalar_equality() {
        let left = ScalarField::zero();
        let right = ScalarField::one();
        assert_ne!(left, right);
        // set_limbs zero-pads short slices, so [1] is the field element one.
        let left = ScalarField::set_limbs(&[1]);
        assert_eq!(left, right);
    }

    #[test]
    fn test_ark_scalar_convert() {
        let size = 1 << 10;
        let scalars = generate_random_scalars(size);
        for scalar in scalars {
            // Round-trip through the arkworks representation.
            assert_eq!(ScalarField::from_ark(scalar.to_ark()), scalar)
        }
    }

    #[test]
    fn test_affine_projective_convert() {
        let size = 1 << 10;
        let affine_points = generate_random_affine_points(size);
        let projective_points = generate_random_projective_points(size);
        for affine_point in affine_points {
            let projective_equivalent: G1Projective = affine_point.into();
            assert_eq!(affine_point, projective_equivalent.into());
        }
        for projective_point in projective_points {
            let affine_equivalent: G1Affine = projective_point.into();
            assert_eq!(projective_point, affine_equivalent.into());
        }
    }

    #[test]
    fn test_point_equality() {
        let left = G1Projective::zero();
        let right = G1Projective::zero();
        assert_eq!(left, right);
        // A point with z = 0 is the point at infinity, whatever x and y hold.
        let right = G1Projective::set_limbs(&[0; BASE_LIMBS], &[2; BASE_LIMBS], &[0; BASE_LIMBS]);
        assert_eq!(left, right);
        let right = G1Projective::set_limbs(
            &[0; BASE_LIMBS],
            &[4; BASE_LIMBS],
            &BaseField::set_limbs(&[2]).get_limbs(),
        );
        assert_ne!(left, right);
        // (0, y, 1) and (0, 2y, 2) are the same point up to projective scaling.
        let left = G1Projective::set_limbs(&[0; BASE_LIMBS], &[2; BASE_LIMBS], &BaseField::one().get_limbs());
        assert_eq!(left, right);
    }

    #[test]
    fn test_ark_point_convert() {
        let size = 1 << 10;
        let affine_points = generate_random_affine_points(size);
        for affine_point in affine_points {
            let ark_projective = Into::<G1Projective>::into(affine_point).to_ark();
            let ark_affine: ArkG1Affine = ark_projective.into();
            assert!(ark_affine.is_on_curve());
            assert!(ark_affine.is_in_correct_subgroup_assuming_on_curve());
            let affine_after_conversion: G1Affine = G1Projective::from_ark(ark_projective).into();
            assert_eq!(affine_point, affine_after_conversion);
        }
    }
}
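Field elements are built from raw limbs; the exact limb width comes from icicle-core's Field type (assumed here to be little-endian 32-bit words, eight of them for BN254's 254-bit fields). A tiny sketch:

use icicle_bn254::curve::ScalarField;

fn limbs_sketch() {
    // Shorter slices are zero-padded, as the tests above rely on:
    // set_limbs(&[1]) equals ScalarField::one().
    let three = ScalarField::set_limbs(&[3]);
    assert_ne!(three, ScalarField::zero());
}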
wrappers/rust/icicle-curves/icicle-bn254/src/lib.rs (new file, +3)
@@ -0,0 +1,3 @@
pub mod curve;
pub mod msm;
pub mod ntt;
wrappers/rust/icicle-curves/icicle-bn254/src/msm/mod.rs (new file, +101)
@@ -0,0 +1,101 @@
use crate::curve::{G1Affine, G1Projective, ScalarField};
use icicle_core::msm::MSMConfig;
use icicle_cuda_runtime::error::{CudaError, CudaResult, CudaResultWrap};

extern "C" {
    #[link_name = "bn254MSMCuda"]
    fn msm_cuda<'a>(
        scalars: *const ScalarField,
        points: *const G1Affine,
        count: usize,
        config: MSMConfig<'a>,
        out: *mut G1Projective,
    ) -> CudaError;

    #[link_name = "bn254GetDefaultMSMConfig"]
    fn GetDefaultMSMConfig() -> MSMConfig<'static>;
}

pub fn get_default_msm_config() -> MSMConfig<'static> {
    unsafe { GetDefaultMSMConfig() }
}

pub fn msm<'a>(
    scalars: &[ScalarField],
    points: &[G1Affine],
    cfg: MSMConfig<'a>,
    results: &mut [G1Projective],
) -> CudaResult<()> {
    if points.len() != scalars.len() {
        return Err(CudaError::cudaErrorInvalidValue);
    }

    unsafe {
        msm_cuda(
            scalars.as_ptr(),
            points.as_ptr(),
            points.len(),
            cfg,
            results.as_mut_ptr(),
        )
        .wrap()
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use ark_bn254::G1Projective as ArkG1Projective;
    use ark_ec::scalar_mul::variable_base::VariableBaseMSM;

    use crate::{
        curve::{generate_random_affine_points, generate_random_scalars, G1Projective},
        msm::{get_default_msm_config, msm},
    };
    use icicle_core::traits::ArkConvertible;
    use icicle_cuda_runtime::memory::DeviceSlice;
    use icicle_cuda_runtime::stream::CudaStream;

    #[test]
    fn test_msm() {
        let log_test_sizes = [20];

        for log_test_size in log_test_sizes {
            let count = 1 << log_test_size;
            let points = generate_random_affine_points(count);
            let scalars = generate_random_scalars(count);

            let mut msm_results = DeviceSlice::cuda_malloc(1).unwrap();
            let stream = CudaStream::create().unwrap();
            let mut cfg = get_default_msm_config();
            cfg.ctx.stream = &stream;
            cfg.is_async = true;
            cfg.are_results_on_device = true;
            msm(&scalars, &points, cfg, msm_results.as_slice()).unwrap();

            // this happens on CPU in parallel to the GPU MSM computations
            let point_r_ark: Vec<_> = points
                .iter()
                .map(|x| x.to_ark())
                .collect();
            let scalars_r_ark: Vec<_> = scalars
                .iter()
                .map(|x| x.to_ark())
                .collect();
            let msm_result_ark: ArkG1Projective = VariableBaseMSM::msm(&point_r_ark, &scalars_r_ark).unwrap();

            let mut msm_host_result = vec![G1Projective::zero(); 1];
            msm_results
                .copy_to_host(&mut msm_host_result[..])
                .unwrap();
            stream.synchronize().unwrap();
            stream.destroy().unwrap();

            assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
        }
    }
}
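For the simple synchronous path the default config can be used as-is; this sketch assumes the defaults leave is_async off and expect host pointers, consistent with how the test above has to override both for the async case. The function is illustrative:

use icicle_bn254::curve::{G1Affine, G1Projective, ScalarField};
use icicle_bn254::msm::{get_default_msm_config, msm};

fn msm_sketch(scalars: &[ScalarField], points: &[G1Affine]) {
    let cfg = get_default_msm_config();
    let mut result = vec![G1Projective::zero(); 1];
    msm(scalars, points, cfg, &mut result).unwrap();
    // result[0] now holds sum_i scalars[i] * points[i].
}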
wrappers/rust/icicle-curves/icicle-bn254/src/ntt/config.rs (new file, +61)
@@ -0,0 +1,61 @@
use crate::curve::*;
use icicle_core::ntt::{Butterfly, Decimation, NTTConfigCuda, Ordering};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};

pub(super) type ECNTTConfig<'a> = NTTConfigCuda<'a, G1Projective, ScalarField>;
pub(super) type NTTConfig<'a> = NTTConfigCuda<'a, ScalarField, ScalarField>;

pub(super) fn get_ntt_config<E, S>(size: usize, ctx: DeviceContext) -> NTTConfigCuda<E, S> {
    //TODO: implement on CUDA side

    NTTConfigCuda::<E, S> {
        inout: 0 as _, // inout as *mut _ as *mut ScalarField,
        is_input_on_device: false,
        is_inverse: false,
        ordering: Ordering::kNN,
        decimation: Decimation::kDIF,
        butterfly: Butterfly::kCooleyTukey,
        is_coset: false,
        coset_gen: 0 as _,    //TODO: ?
        twiddles: 0 as _,     //TODO: ?
        inv_twiddles: 0 as _, //TODO: ?
        size: size as i32,
        batch_size: 0,
        is_preserving_twiddles: true,
        is_output_on_device: false,
        ctx,
    }
}

pub(super) fn get_ntt_default_config<E, S>(size: usize) -> NTTConfigCuda<'static, E, S> {
    //TODO: implement on CUDA side
    let ctx = get_default_device_context();

    // let root_of_unity = S::default(); //TODO: implement on CUDA side

    get_ntt_config(size, ctx)
}

pub(super) fn get_ntt_config_with_input(ntt_intt_result: &mut [ScalarField], size: usize, batches: usize) -> NTTConfig {
    NTTConfig {
        inout: ntt_intt_result.as_mut_ptr(),
        is_input_on_device: false,
        is_inverse: false,
        ordering: Ordering::kNN,
        decimation: Decimation::kDIF,
        butterfly: Butterfly::kCooleyTukey,
        is_coset: false,
        coset_gen: &[ScalarField::zero()] as _, //TODO: ?
        twiddles: 0 as *const ScalarField,      //TODO: ?
        inv_twiddles: 0 as *const ScalarField,  //TODO: ?
        size: size as _,
        batch_size: batches as i32,
        is_preserving_twiddles: true,
        is_output_on_device: true,
        ctx: get_default_device_context(),
    }
}
wrappers/rust/icicle-curves/icicle-bn254/src/ntt/domain.rs (new file, +192)
@@ -0,0 +1,192 @@
use icicle_core::ntt::{NTTConfigCuda, Ordering};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::memory::DeviceSlice;

pub(super) type ECNTTDomain<'a> = Domain<'a, G1Projective, ScalarField>;
pub(super) type NTTDomain<'a> = Domain<'a, ScalarField, ScalarField>;

use crate::curve::*;

use super::{config::*, ntt_internal};

/// Represents the NTT domain
pub struct Domain<'a, E, S> {
    config: NTTConfigCuda<'a, E, S>,
}

impl<'a, E, S> Domain<'a, E, S> {
    pub fn new(size: usize, ctx: DeviceContext<'a>) -> Self {
        Domain {
            config: get_ntt_config(size, ctx),
        }
    }

    pub fn get_output_on_device(&self) -> Result<*mut E, &'static str> {
        if self.config.is_output_on_device {
            Ok(self.config.inout)
        } else {
            Err("Output should be on device.")
        }
    }

    pub fn get_input_on_device(&self) -> Result<*mut E, &'static str> {
        if self.config.is_input_on_device {
            Ok(self.config.inout)
        } else {
            Err("Input should be on device.")
        }
    }

    pub fn get_input(&self) -> Result<*mut E, &'static str> {
        if !self.config.is_input_on_device {
            Ok(self.config.inout)
        } else {
            Err("Input is on device.")
        }
    }

    pub fn get_output(&self) -> Result<*mut E, &'static str> {
        if !self.config.is_output_on_device {
            Ok(self.config.inout)
        } else {
            Err("Output is on device.")
        }
    }

    pub(crate) fn new_for_default_context(size: usize) -> Self {
        let ctx = get_default_device_context();
        // let default_root_of_unity = S::default(); //TODO: implement
        Domain::new(size, ctx)
    }
}

// Add implementations for other methods and structs as needed.

impl<'a, E: 'static, S: 'static> Domain<'a, E, S> {
    // NTT methods
    pub fn ntt(&mut self, inout: &mut [E]) {
        let batch_size = 1;

        let size = inout.len();

        if size != self.config.size as usize {
            //TODO: test for this error
            panic!(
                "input length: {} does not match domain size: {}",
                size, self.config.size
            )
        }

        self.config.inout = inout.as_mut_ptr();
        self.config.is_inverse = false;
        self.config.is_input_on_device = false;
        self.config.is_output_on_device = false;
        // self.config.ordering = Ordering::default(); //TODO: each call?
        self.config.batch_size = batch_size as i32;

        ntt_internal(&mut self.config);
    }

    pub fn ntt_on_device(&mut self, inout: &mut DeviceSlice<E>) {
        // Implementation for NTT on device
    }

    pub fn ntt_batch(&mut self, inout: &mut [E]) {
        // Implementation for batched NTT
    }

    pub fn ntt_batch_on_device(&mut self, inout: &mut DeviceSlice<E>) {
        // Implementation for batched NTT on device
    }

    pub fn ntt_coset(&mut self, inout: &mut [E], coset: &mut [E]) {
        // Implementation for NTT with coset
    }

    pub fn ntt_coset_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
        // Implementation for NTT with coset on device
    }

    pub fn ntt_coset_batch(&mut self, inout: &mut [E], coset: &mut [E]) {
        // Implementation for batched NTT with coset
    }

    pub fn ntt_coset_batch_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
        // Implementation for batched NTT with coset on device
    }

    // iNTT methods
    pub fn intt(&mut self, inout: &mut [E]) {
        // Implementation for iNTT
    }

    pub fn intt_on_device(&mut self, inout: &mut DeviceSlice<E>) {
        // Implementation for iNTT on device
    }

    pub fn intt_batch(&mut self, inout: &mut [E]) {
        // Implementation for batched iNTT
    }

    pub fn intt_batch_on_device(&mut self, inout: &mut DeviceSlice<E>) {
        // Implementation for batched iNTT on device
    }

    pub fn intt_coset(&mut self, inout: &mut [E], coset: &mut [E]) {
        // Implementation for iNTT with coset
    }

    pub fn intt_coset_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
        // Implementation for iNTT with coset on device
    }

    pub fn intt_coset_batch(&mut self, inout: &mut [E], coset: &mut [E]) {
        // Implementation for batched iNTT with coset
    }

    pub fn intt_coset_batch_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
        // Implementation for batched iNTT with coset on device
    }

    // Ordering setter
    pub fn set_ordering(&mut self, ordering: Ordering) {
        self.config.ordering = ordering;
    }
}
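End to end, the domain is intended to be driven as sketched below; only ntt() is implemented above, the remaining methods are still stubs, and the function name here is hypothetical:

use icicle_bn254::curve::ScalarField;
use icicle_bn254::ntt::domain::Domain;
use icicle_cuda_runtime::device_context::get_default_device_context;

fn ntt_sketch(values: &mut [ScalarField]) {
    // The domain size must equal values.len(), or ntt() panics.
    let mut domain: Domain<ScalarField, ScalarField> =
        Domain::new(values.len(), get_default_device_context());
    domain.ntt(values);
}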
wrappers/rust/icicle-curves/icicle-bn254/src/ntt/mod.rs (new file, +322)
@@ -0,0 +1,322 @@
mod config;
pub mod domain;

use crate::curve::*;

use self::config::*;

use icicle_core::ntt::Ordering;
use icicle_cuda_runtime::error::CudaError;

extern "C" {
    #[link_name = "NTTDefaultContextCuda"]
    fn ntt_cuda(config: *mut NTTConfig) -> CudaError;
}

pub(crate) fn ntt_wip(
    inout: &mut [ScalarField],
    is_inverse: bool,
    is_input_on_device: bool,
    ordering: Ordering,
    is_output_on_device: bool,
    batch_size: usize,
) {
    let mut batch_size = batch_size;
    if batch_size == 0 {
        batch_size = 1;
    }

    let size = inout.len() / batch_size;

    let mut config = get_ntt_default_config::<ScalarField, ScalarField>(size);

    config.inout = inout.as_mut_ptr();
    config.is_inverse = is_inverse;
    config.is_input_on_device = is_input_on_device;
    config.is_output_on_device = is_output_on_device;
    config.ordering = ordering;
    config.batch_size = batch_size as i32;

    ntt_internal(&mut config);
}

pub(self) fn ntt_internal<TConfig>(config: *mut TConfig) -> CudaError {
    let result_code = unsafe { ntt_cuda(config as _) };
    // let typeid = TypeId::of::<TConfig>();
    // if typeid == TypeId::of::<NTTConfig>() {
    //     result_code = unsafe { ntt_cuda(config as _) };
    // } else {
    //     result_code = CudaError::cudaSuccess; //TODO: unsafe { ecntt_cuda(config as _) };
    // }

    // Propagate the CUDA status instead of unconditionally reporting success.
    result_code
}

pub(self) fn ecntt_internal(config: *mut ECNTTConfig) -> u32 {
    let result_code = 0; //TODO: unsafe { ecntt_cuda(config) };
    if result_code != 0 {
        println!("_result_code = {}", result_code);
    }

    result_code
}

#[cfg(test)]
pub(crate) mod tests {
    use ark_bn254::Fr;
    // use ark_bls12_381::{Fr, G1Projective};
    use ark_poly::EvaluationDomain;
    use ark_poly::GeneralEvaluationDomain;

    use crate::ntt::domain::NTTDomain;
    use crate::{curve::*, ntt::*};
    use icicle_core::traits::ArkConvertible;

    pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
        fn is_power_of_two(n: u32) -> bool {
            n != 0 && n & (n - 1) == 0
        }
        assert!(is_power_of_two(order));
        let mask = order - 1;
        let binary = format!("{:0width$b}", n, width = (32 - mask.leading_zeros()) as usize);
        let reversed = binary
            .chars()
            .rev()
            .collect::<String>();
        u32::from_str_radix(&reversed, 2).unwrap()
    }

    pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
        l.iter()
            .enumerate()
            .map(|(i, _)| l[reverse_bit_order(i as u32, l.len() as u32) as usize])
            .collect()
    }

    #[test]
    fn test_ntt() {
        //NTT
        let test_size = 1 << 11;
        let batches = 1;

        let full_test_size = test_size * batches;
        let scalars_batch: Vec<ScalarField> = generate_random_scalars(full_test_size);

        let mut ntt_result = scalars_batch.clone();

        let ark_domain = GeneralEvaluationDomain::<Fr>::new(test_size).unwrap();
        let mut domain = NTTDomain::new_for_default_context(test_size);

        let ark_scalars_batch = scalars_batch
            .iter()
            .map(|v| v.to_ark())
            .collect::<Vec<Fr>>();
        let mut ark_ntt_result = ark_scalars_batch.clone();

        ark_domain.fft_in_place(&mut ark_ntt_result);

        assert_ne!(ark_ntt_result, ark_scalars_batch);

        // do ntt
        // ntt_wip(&mut ntt_result, false, false, Ordering::kNN, false, batches);
        domain.ntt(&mut ntt_result); // single ntt
        let ntt_result_as_ark = ntt_result
            .iter()
            .map(|p| p.to_ark())
            .collect::<Vec<Fr>>();

        assert_ne!(ntt_result, scalars_batch);
        assert_eq!(ark_ntt_result, ntt_result_as_ark);

        let mut ark_intt_result = ark_ntt_result;

        ark_domain.ifft_in_place(&mut ark_intt_result);
        assert_eq!(ark_intt_result, ark_scalars_batch);

        // do intt
        let mut intt_result = ntt_result;

        ntt_wip(&mut intt_result, true, false, Ordering::kNN, false, batches);

        assert!(intt_result == scalars_batch);

        let mut ntt_intt_result = intt_result;
        ntt_wip(&mut ntt_intt_result, false, false, Ordering::kNR, false, batches);
        assert!(ntt_intt_result != scalars_batch);
        ntt_wip(&mut ntt_intt_result, true, false, Ordering::kRN, false, batches);
        assert!(ntt_intt_result == scalars_batch);

        let mut ntt_intt_result = list_to_reverse_bit_order(&ntt_intt_result);
        ntt_wip(&mut ntt_intt_result, false, false, Ordering::kRR, false, batches);
        assert!(ntt_intt_result != scalars_batch);
        ntt_wip(&mut ntt_intt_result, true, false, Ordering::kRN, false, batches);
        assert!(ntt_intt_result == scalars_batch);

        let size = ntt_intt_result.len() / batches;

        let mut config = get_ntt_config_with_input(&mut ntt_intt_result, size, batches);

        ntt_internal(&mut config);

        // host
        let mut ntt_result = scalars_batch.clone();
        ntt_wip(&mut ntt_result, false, false, Ordering::kNR, false, batches);

        // host - device - device - host
        let mut ntt_intt_result = scalars_batch.clone();

        let mut config = get_ntt_config_with_input(&mut ntt_intt_result, size, batches);

        config.is_input_on_device = false;
        config.is_output_on_device = true;
        // config.is_preserving_twiddles = true; // TODO: same as in get_ntt_config
        config.ordering = Ordering::kNR;

        ntt_internal(&mut config); // twiddles are preserved after first call

        // config.is_preserving_twiddles = true; //TODO: same as in get_ntt_config
        config.is_inverse = true;
        config.is_input_on_device = false;
        config.is_output_on_device = true;
        config.ordering = Ordering::kNR;

        ntt_internal(&mut config); // inv_twiddles are preserved after first call

        let ntt_intt_result = &mut scalars_batch.clone()[..];
        let raw_scalars_batch_copy = ntt_intt_result.as_mut_ptr();

        let config_inout2: &mut [ScalarField] =
            unsafe { std::slice::from_raw_parts_mut(raw_scalars_batch_copy, config.size as usize) };
        assert_eq!(config_inout2, scalars_batch);

        config.is_preserving_twiddles = true; //TODO: same as in get_ntt_config

        config.inout = raw_scalars_batch_copy;

        config.is_inverse = false;
        config.is_input_on_device = false;
        config.is_output_on_device = true;
        config.ordering = Ordering::kNR;

        ntt_internal(&mut config);

        config.is_inverse = true;
        config.is_input_on_device = true;
        config.is_output_on_device = false;
        config.ordering = Ordering::kRN;

        ntt_internal(&mut config);

        let result_from_device: &mut [ScalarField] =
            unsafe { std::slice::from_raw_parts_mut(config.inout, scalars_batch.len()) };

        assert_eq!(result_from_device, &scalars_batch);
    }

    #[test]
    fn test_batch_ntt() {
        //NTT
        let test_size = 1 << 11;
        let batches = 2;

        let full_test_size = test_size * batches;
        let scalars_batch: Vec<ScalarField> = generate_random_scalars(full_test_size);

        let mut scalar_vec_of_vec: Vec<Vec<ScalarField>> = Vec::new();

        for i in 0..batches {
            scalar_vec_of_vec.push(scalars_batch[i * test_size..(i + 1) * test_size].to_vec());
        }

        let mut ntt_result = scalars_batch.clone();

        // do batch ntt
        ntt_wip(&mut ntt_result, false, false, Ordering::kNN, false, batches);

        let mut ntt_result_vec_of_vec = Vec::new();

        // do ntt for every chunk
        for i in 0..batches {
            ntt_result_vec_of_vec.push(scalar_vec_of_vec[i].clone());

            ntt_wip(&mut ntt_result_vec_of_vec[i], false, false, Ordering::kNN, false, 1);
        }

        // check that the ntt of each vec of scalars is equal to the ntt of the specific batch
        for i in 0..batches {
            assert_eq!(ntt_result_vec_of_vec[i], ntt_result[i * test_size..(i + 1) * test_size]);
        }

        // check that ntt output is different from input
        assert_ne!(ntt_result, scalars_batch);

        let mut intt_result = ntt_result.clone();

        // do batch intt
        ntt_wip(&mut intt_result, true, false, Ordering::kNN, false, batches);

        let mut intt_result_vec_of_vec = Vec::new();

        // do intt for every chunk
        for i in 0..batches {
            intt_result_vec_of_vec.push(ntt_result_vec_of_vec[i].clone());
            ntt_wip(&mut intt_result_vec_of_vec[i], true, false, Ordering::kNN, false, 1);
        }

        // check that the intt of each vec of scalars is equal to the intt of the specific batch
        for i in 0..batches {
            assert_eq!(
                intt_result_vec_of_vec[i],
                intt_result[i * test_size..(i + 1) * test_size]
            );
        }

        assert_eq!(intt_result, scalars_batch);
    }
}
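For intuition about the orderings exercised above (kNR, kRN, kRR): reverse_bit_order flips an index's bits within log2(order) positions. Inside the tests module where the helper is defined:

// Order 8 uses 3 bits, so 3 = 0b011 maps to 0b110 = 6,
// and 1 = 0b001 maps to 0b100 = 4.
assert_eq!(reverse_bit_order(3, 8), 6);
assert_eq!(reverse_bit_order(1, 8), 4);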