Update Rust apis (#262)

* fix memory error in single_stage_multi_reduction_kernel (#235)

* refactor

* refactor

* revert

* refactor: clang format

* Update icicle/appUtils/msm/msm.cu

* Added separate device context struct, returned lde

* wip - msm and eq

* added lde to cmake

* Montgomery param added in lde.cu mul function

* fixed on_device for ntt and lde

* CamelCase

* fixed msm_test, int unification, google guide

* wip - ntt crash debugging

* async MSM with a rust wrapper

* wip ntt tests with correctness

* hotfix for correctness > 2^9

* wip on device inout mixing with correctness

* cleanup

* preserving twiddles after first call

* fixed twiddles preserving

* formatting

* removed some printing

* disable ecntt temporarily

* format

* rust fmt

* exclude target from format

* passing ntt after merge

* hotfix for linking issue

* format

* format

* draft of pr comments + correctness restored

* wip refactor + format

* domain wip

* rust format

* Merged feature branch in and Rust MSM correctness

* rust build for correct curve

* Slowdown fixed by passing release flag to cmake

* WIP field and curve

* still wip field and curve

* field and curve in rust 1.0

* Refactored rust into several crates

* Arkworks is now an option, bn254 crate created

* Rust msm and ntt wip

* A version of rust msm done, cuda runtime wrapped

* refactor rust by creating a curve folder

* vec_ops instead of lde for now

* format

---------

Co-authored-by: ImmanuelSegol <3ditds@gmail.com>
Co-authored-by: Vitalii <vitalii@ingonyama.com>
DmytroTym
2023-12-03 13:32:50 +02:00
committed by GitHub
parent 133a1b28bc
commit dfa5b10adb
60 changed files with 2376 additions and 9236 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@
**/.DS_Store
**/Cargo.lock
**/icicle/build/
**/wrappers/rust/icicle-cuda-runtime/src/bindings.rs

View File

@@ -1,49 +0,0 @@
[package]
name = "icicle"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "An implementation of the Ingonyama CUDA Library"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"
[[bench]]
name = "ntt"
path = "benches/ntt.rs"
harness = false
[[bench]]
name = "msm"
path = "benches/msm.rs"
harness = false
[dependencies]
hex = "*"
ark-std = "0.3.0"
ark-ff = "0.3.0"
ark-poly = "0.3.0"
ark-ec = { version = "0.3.0", features = [ "parallel" ] }
ark-bls12-381 = "0.3.0"
ark-bls12-377 = "0.3.0"
ark-bn254 = "0.3.0"
serde = { version = "1.0", features = ["derive"] }
serde_derive = "1.0"
serde_cbor = "0.11.2"
rustacuda = "0.1"
rustacuda_core = "0.1"
rustacuda_derive = "0.1"
rand = "*" #TODO: move rand and ark dependencies to dev once random scalar/point generation is done "natively"
[build-dependencies]
cc = { version = "1.0", features = ["parallel"] }
[dev-dependencies]
"criterion" = "0.4.0"
[features]
default = ["bls12_381"]
bls12_381 = ["ark-bls12-381/curve"]
g2 = []

View File

@@ -1,54 +0,0 @@
extern crate criterion;
use criterion::{criterion_group, criterion_main, Criterion};
use icicle::test_bls12_381::{commit_batch_bls12_381, generate_random_points_bls12_381, set_up_scalars_bls12_381};
use icicle::utils::*;
#[cfg(feature = "g2")]
use icicle::{commit_batch_g2, field::ExtensionField};
use rustacuda::prelude::*;
const LOG_MSM_SIZES: [usize; 1] = [12];
const BATCH_SIZES: [usize; 2] = [128, 256];
fn bench_msm(c: &mut Criterion) {
let mut group = c.benchmark_group("MSM");
for log_msm_size in LOG_MSM_SIZES {
for batch_size in BATCH_SIZES {
let msm_size = 1 << log_msm_size;
let (scalars, _, _) = set_up_scalars_bls12_381(msm_size, 0, false);
let batch_scalars = vec![scalars; batch_size].concat();
let mut d_scalars = DeviceBuffer::from_slice(&batch_scalars[..]).unwrap();
let points = generate_random_points_bls12_381(msm_size, get_rng(None));
let batch_points = vec![points; batch_size].concat();
let mut d_points = DeviceBuffer::from_slice(&batch_points[..]).unwrap();
#[cfg(feature = "g2")]
let g2_points = generate_random_points::<ExtensionField>(msm_size, get_rng(None));
#[cfg(feature = "g2")]
let g2_batch_points = vec![g2_points; batch_size].concat();
#[cfg(feature = "g2")]
let mut d_g2_points = DeviceBuffer::from_slice(&g2_batch_points[..]).unwrap();
group
.sample_size(30)
.bench_function(
&format!("MSM of size 2^{} in batch {}", log_msm_size, batch_size),
|b| b.iter(|| commit_batch_bls12_381(&mut d_points, &mut d_scalars, batch_size)),
);
#[cfg(feature = "g2")]
group
.sample_size(10)
.bench_function(
&format!("G2 MSM of size 2^{} in batch {}", log_msm_size, batch_size),
|b| b.iter(|| commit_batch_g2(&mut d_g2_points, &mut d_scalars, batch_size)),
);
}
}
}
criterion_group!(msm_benches, bench_msm);
criterion_main!(msm_benches);

View File

@@ -1,85 +0,0 @@
extern crate criterion;
use criterion::{criterion_group, criterion_main, Criterion};
use icicle::test_bls12_381::*;
const LOG_NTT_SIZES: [usize; 3] = [20, 9, 10];
const BATCH_SIZES: [usize; 3] = [1, 512, 1024];
fn bench_ntt(c: &mut Criterion) {
let mut group = c.benchmark_group("NTT");
for log_ntt_size in LOG_NTT_SIZES {
for batch_size in BATCH_SIZES {
let ntt_size = 1 << log_ntt_size;
if ntt_size * batch_size > 1 << 25 {
continue;
}
let scalar_samples = 20;
let (_, mut d_evals, mut d_domain) = set_up_scalars_bls12_381(ntt_size * batch_size, log_ntt_size, true);
group
.sample_size(scalar_samples)
.bench_function(
&format!("Scalar NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| evaluate_scalars_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size)),
);
group
.sample_size(scalar_samples)
.bench_function(
&format!("Scalar iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| interpolate_scalars_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size)),
);
group
.sample_size(scalar_samples)
.bench_function(
&format!("Scalar inplace NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| ntt_inplace_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size, false, 0)),
);
group
.sample_size(scalar_samples)
.bench_function(
&format!("Scalar inplace iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| ntt_inplace_batch_bls12_381(&mut d_evals, &mut d_domain, batch_size, true, 0)),
);
drop(d_evals);
drop(d_domain);
if ntt_size * batch_size > 1 << 18 {
continue;
}
let point_samples = 10;
let (_, mut d_points_evals, mut d_domain) =
set_up_points_bls12_381(ntt_size * batch_size, log_ntt_size, true);
group
.sample_size(point_samples)
.bench_function(
&format!("EC NTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| interpolate_points_batch_bls12_381(&mut d_points_evals, &mut d_domain, batch_size)),
);
group
.sample_size(point_samples)
.bench_function(
&format!("EC iNTT of size 2^{} in batch {}", log_ntt_size, batch_size),
|b| b.iter(|| evaluate_points_batch_bls12_381(&mut d_points_evals, &mut d_domain, batch_size)),
);
drop(d_points_evals);
drop(d_domain);
}
}
}
criterion_group!(ntt_benches, bench_ntt);
criterion_main!(ntt_benches);

View File

@@ -1,31 +0,0 @@
use std::env;
fn main() {
//TODO: check cargo features selected
//TODO: can conflict/duplicate with make ?
println!("cargo:rerun-if-env-changed=CXXFLAGS");
println!("cargo:rerun-if-changed=./icicle");
let arch_type = env::var("ARCH_TYPE").unwrap_or(String::from("native"));
let stream_type = env::var("DEFAULT_STREAM").unwrap_or(String::from("legacy"));
let mut arch = String::from("-arch=");
arch.push_str(&arch_type);
let mut stream = String::from("-default-stream=");
stream.push_str(&stream_type);
let mut nvcc = cc::Build::new();
println!("Compiling icicle library using arch: {}", &arch);
if cfg!(feature = "g2") {
nvcc.define("G2_DEFINED", None);
}
nvcc.cuda(true);
nvcc.debug(false);
nvcc.flag(&arch);
nvcc.flag(&stream);
nvcc.files(["./icicle/curves/index.cu"]);
nvcc.compile("ingo_icicle"); //TODO: extension??
}

View File

@@ -1,156 +0,0 @@
use std::time::Instant;
use icicle::{curves::bls12_381::ScalarField_BLS12_381, test_bls12_381::*};
use rustacuda::prelude::DeviceBuffer;
const LOG_NTT_SIZES: [usize; 3] = [20, 10, 9];
const BATCH_SIZES: [usize; 3] = [1, 1 << 9, 1 << 10];
const MAX_POINTS_LOG2: usize = 18;
const MAX_SCALARS_LOG2: usize = 26;
fn bench_lde() {
for log_ntt_size in LOG_NTT_SIZES {
for batch_size in BATCH_SIZES {
let ntt_size = 1 << log_ntt_size;
fn ntt_scalars_batch_bls12_381(
d_inout: &mut DeviceBuffer<ScalarField_BLS12_381>,
d_twiddles: &mut DeviceBuffer<ScalarField_BLS12_381>,
batch_size: usize,
) -> i32 {
ntt_inplace_batch_bls12_381(d_inout, d_twiddles, batch_size, false, 0);
0
}
fn intt_scalars_batch_bls12_381(
d_inout: &mut DeviceBuffer<ScalarField_BLS12_381>,
d_twiddles: &mut DeviceBuffer<ScalarField_BLS12_381>,
batch_size: usize,
) -> i32 {
ntt_inplace_batch_bls12_381(d_inout, d_twiddles, batch_size, true, 0);
0
}
// copy
bench_ntt_template(
MAX_SCALARS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_scalars_bls12_381,
evaluate_scalars_batch_bls12_381,
"NTT",
false,
100,
);
bench_ntt_template(
MAX_SCALARS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_scalars_bls12_381,
interpolate_scalars_batch_bls12_381,
"iNTT",
true,
100,
);
bench_ntt_template(
MAX_POINTS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_points_bls12_381,
evaluate_points_batch_bls12_381,
"EC NTT",
false,
20,
);
bench_ntt_template(
MAX_POINTS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_points_bls12_381,
interpolate_points_batch_bls12_381,
"EC iNTT",
true,
20,
);
// inplace
bench_ntt_template(
MAX_SCALARS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_scalars_bls12_381,
ntt_scalars_batch_bls12_381,
"NTT inplace",
false,
100,
);
bench_ntt_template(
MAX_SCALARS_LOG2,
ntt_size,
batch_size,
log_ntt_size,
set_up_scalars_bls12_381,
intt_scalars_batch_bls12_381,
"iNTT inplace",
true,
100,
);
}
}
}
fn bench_ntt_template<E, S, R>(
log_max_size: usize,
ntt_size: usize,
batch_size: usize,
log_ntt_size: usize,
set_data: fn(test_size: usize, log_domain_size: usize, inverse: bool) -> (Vec<E>, DeviceBuffer<E>, DeviceBuffer<S>),
bench_fn: fn(d_evaluations: &mut DeviceBuffer<E>, d_domain: &mut DeviceBuffer<S>, batch_size: usize) -> R,
id: &str,
inverse: bool,
samples: usize,
) -> Option<(Vec<E>, R)> {
let count = ntt_size * batch_size;
let bench_id = format!("{} of size 2^{} in batch {}", id, log_ntt_size, batch_size);
if count > 1 << log_max_size {
println!("Bench size exceeded: {}", bench_id);
return None;
}
println!("{}", bench_id);
let (input, mut d_evals, mut d_domain) = set_data(ntt_size * batch_size, log_ntt_size, inverse);
let first = bench_fn(&mut d_evals, &mut d_domain, batch_size);
let start = Instant::now();
for _ in 0..samples {
bench_fn(&mut d_evals, &mut d_domain, batch_size);
}
let elapsed = start.elapsed();
println!(
"{} {:0?} us x {} = {:?}",
bench_id,
elapsed.as_micros() as f32 / (samples as f32),
samples,
elapsed
);
Some((input, first))
}
fn main() {
bench_lde();
}

View File

@@ -19,8 +19,8 @@ set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
# when adding a new curve, append its name to the end of this list
set(SUPPORTED_CURVES bn254;bls12_381;bls12_377)
# when adding a new curve/field, append its name to the end of this list
set(SUPPORTED_CURVES bn254;bls12_381;bls12_377;bw6_671)
set(IS_CURVE_SUPPORTED FALSE)
set(I 0)
@@ -42,25 +42,48 @@ if (NOT BUILD_TESTS)
add_library(
icicle
utils/error_handler.cu
utils/utils_kernels.cu
utils/vec_ops.cu
primitives/field.cu
primitives/projective.cu
appUtils/msm/msm.cu
appUtils/ntt/ntt.cu
appUtils/lde/lde.cu
)
#set_target_properties(icicle PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
#set_target_properties(icicle PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# # cuda_add_library(cudanet STATIC ${cu})
# target_link_libraries(icicle ${CUDA_LIBRARIES})
# target_link_libraries(icicle cudart cuda cublas curand)
# message(STATUS "include & link cuda")
# link_directories(/usr/local/cuda/lib64)
target_compile_options(icicle PRIVATE -c)
# find_package(CUDA QUIET REQUIRED)
add_custom_command(
TARGET icicle
POST_BUILD
COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym msm_cuda=${CURVE}_msm_cuda ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
COMMAND ${CMAKE_AR} ARGS -rcs ${LIBRARY_OUTPUT_DIRECTORY}/libingo_${CURVE}.a ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym MSMCuda=${CURVE}MSMCuda ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
COMMAND ${CMAKE_OBJCOPY} ARGS --redefine-sym GetDefaultMSMConfig=${CURVE}GetDefaultMSMConfig ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
# COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
# COMMAND ${CMAKE_OBJCOPY} ARGS --prefix-symbols=${CURVE}_ ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o
COMMAND ${CMAKE_AR} ARGS -rcs ${LIBRARY_OUTPUT_DIRECTORY}/libingo_${CURVE}.a
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/error_handler.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/vec_ops.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/field.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/projective.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/utils_kernels.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o
${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o
)
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o)
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o)
file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/msm/msm.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/ntt/ntt.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/appUtils/lde/lde.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/primitives/projective.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/utils_kernels.cu.o)
# file(REMOVE ${PROJECT_BINARY_DIR}/CMakeFiles/icicle.dir/utils/error_handler.cu.o)
else()
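
The objcopy step above renames the curve-specific extern "C" symbols (MSMCuda and GetDefaultMSMConfig get a ${CURVE} prefix) so that several per-curve static libraries can coexist in one binary. A hypothetical consumer-side sketch of what the renaming means, assuming a bn254 build and an include path that are not part of this commit:

// consumer.cpp -- hypothetical sketch, not part of this commit.
// After building with CURVE=bn254, libingo_bn254.a exposes the renamed C symbols,
// e.g. GetDefaultMSMConfig becomes bn254GetDefaultMSMConfig (see the objcopy commands above);
// MSMCuda is renamed to bn254MSMCuda the same way.
#include "appUtils/msm/msm.cuh" // assumed include path for msm::MSMConfig

extern "C" msm::MSMConfig bn254GetDefaultMSMConfig(); // renamed by --redefine-sym

int main()
{
  msm::MSMConfig config = bn254GetDefaultMSMConfig();
  return config.batch_size == 1 ? 0 : 1; // the default config uses batch_size = 1
}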

View File

@@ -16,6 +16,7 @@
#include "../../primitives/projective.cuh"
#include "../../utils/cuda_utils.cuh"
#include "../../utils/error_handler.cuh"
#include "../../utils/mont.cuh"
namespace msm {
@@ -27,6 +28,12 @@ namespace msm {
// #define BIG_TRIANGLE
// #define SSM_SUM //WIP
template <typename S>
int get_optimal_c(int bitsize)
{
return ceil(log2(bitsize)) - 4;
}
template <typename P>
__global__ void single_stage_multi_reduction_kernel(
P* v,
@@ -132,32 +139,14 @@ namespace msm {
buckets_indices[current_index] =
(msm_index << (c + bm_bitsize)) | (bm << c) |
bucket_index; // the bucket module number and the msm number are appended at the msbs
if (scalar == S::zero() || scalar == S::one() || bucket_index == 0)
buckets_indices[current_index] = 0; // will be skipped
point_indices[current_index] = tid; // the point index is saved for later
if (scalar == S::zero() || bucket_index == 0) buckets_indices[current_index] = 0; // will be skipped
point_indices[current_index] = tid; // the point index is saved for later
#endif
}
}
}
template <typename P, typename A, typename S>
__global__ void
add_ones_kernel(A* points, S* scalars, P* results, const unsigned msm_size, const unsigned run_length)
{
unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
const unsigned nof_threads = (msm_size + run_length - 1) / run_length; // 129256
if (tid >= nof_threads) {
results[tid] = P::zero();
return;
}
const unsigned start_index = tid * run_length;
P sum = P::zero();
for (int i = start_index; i < min(start_index + run_length, msm_size); i++) {
if (scalars[i] == S::one()) sum = sum + points[i];
}
results[tid] = sum;
}
template <typename S>
__global__ void
find_cutoff_kernel(unsigned* v, unsigned size, unsigned cutoff, unsigned run_length, unsigned* result)
{
@@ -174,6 +163,7 @@ namespace msm {
if (tid == 0 && v[size - 1] > cutoff) { result[0] = size; }
}
template <typename S>
__global__ void
find_max_size(unsigned* bucket_sizes, unsigned* single_bucket_indices, unsigned c, unsigned* largest_bucket_size)
{
@@ -330,8 +320,8 @@ namespace msm {
// this kernel computes the final result using the double and add algorithm
// it is done by a single thread
template <typename P, typename S>
__global__ void final_accumulation_kernel(
P* final_sums, P* ones_result, P* final_results, unsigned nof_msms, unsigned nof_bms, unsigned c, bool add_ones)
__global__ void
final_accumulation_kernel(P* final_sums, P* final_results, unsigned nof_msms, unsigned nof_bms, unsigned c)
{
unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid > nof_msms) return;
@@ -343,38 +333,53 @@ namespace msm {
final_result = final_result + final_result;
}
}
if (add_ones)
final_results[tid] = final_result + final_sums[tid * nof_bms] + ones_result[0];
else
final_results[tid] = final_result + final_sums[tid * nof_bms];
final_results[tid] = final_result + final_sums[tid * nof_bms];
}
// this function computes msm using the bucket method
template <typename S, typename P, typename A>
void bucket_method_msm(
unsigned bitsize,
unsigned c,
int bitsize,
int c,
S* scalars,
A* points,
unsigned size,
int size,
P* final_result,
bool on_device,
bool big_triangle,
unsigned large_bucket_factor,
bool are_scalars_on_device,
bool are_scalars_montgomery_form,
bool are_points_on_device,
bool are_points_montgomery_form,
bool is_result_on_device,
bool is_big_triangle,
int large_bucket_factor,
bool is_async,
cudaStream_t stream)
{
S* d_scalars;
A* d_points;
if (!on_device) {
// copy scalars and points to gpu
if (!are_scalars_on_device) {
// copy scalars to gpu
cudaMallocAsync(&d_scalars, sizeof(S) * size, stream);
cudaMallocAsync(&d_points, sizeof(A) * size, stream);
cudaMemcpyAsync(d_scalars, scalars, sizeof(S) * size, cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(d_points, points, sizeof(A) * size, cudaMemcpyHostToDevice, stream);
} else {
d_scalars = scalars;
}
cudaStream_t stream_points;
if (!are_points_on_device || are_points_montgomery_form) cudaStreamCreate(&stream_points);
if (!are_points_on_device) {
// copy points to gpu
cudaMallocAsync(&d_points, sizeof(A) * size, stream_points);
cudaMemcpyAsync(d_points, points, sizeof(A) * size, cudaMemcpyHostToDevice, stream_points);
} else {
d_points = points;
}
if (are_points_montgomery_form) mont::FromMontgomery(d_points, size, stream_points);
if (are_scalars_montgomery_form) mont::FromMontgomery(d_scalars, size, stream);
cudaEvent_t event_points_uploaded;
if (!are_points_on_device || are_points_montgomery_form) {
cudaEventCreateWithFlags(&event_points_uploaded, cudaEventDisableTiming);
cudaEventRecord(event_points_uploaded, stream_points);
}
P* buckets;
// compute number of bucket modules and number of buckets in each module
@@ -393,22 +398,6 @@ namespace msm {
unsigned NUM_BLOCKS = (nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
initialize_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, nof_buckets);
// accumulate ones
P* ones_results; // fix whole division, in last run in kernel too
const unsigned nof_runs = msm_log_size > 10 ? (1 << (msm_log_size - 6)) : 16;
const unsigned run_length = (size + nof_runs - 1) / nof_runs;
cudaMallocAsync(&ones_results, sizeof(P) * nof_runs, stream);
NUM_THREADS = min(1 << 8, nof_runs);
NUM_BLOCKS = (nof_runs + NUM_THREADS - 1) / NUM_THREADS;
add_ones_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(d_points, d_scalars, ones_results, size, run_length);
for (int s = nof_runs >> 1; s > 0; s >>= 1) {
NUM_THREADS = min(MAX_TH, s);
NUM_BLOCKS = (s + NUM_THREADS - 1) / NUM_THREADS;
single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
ones_results, ones_results, s * 2, 0, 0, 0, s);
}
unsigned* bucket_indices;
unsigned* point_indices;
cudaMallocAsync(&bucket_indices, sizeof(unsigned) * size * (nof_bms + 1), stream);
@@ -482,7 +471,7 @@ namespace msm {
// if all points are 0 just return point 0
if (h_nof_buckets_to_compute == 0) {
if (!on_device)
if (!is_result_on_device)
final_result[0] = P::zero();
else {
P* h_final_result = (P*)malloc(sizeof(P));
@@ -533,7 +522,7 @@ namespace msm {
unsigned cutoff_nof_runs = (h_nof_buckets_to_compute + cutoff_run_length - 1) / cutoff_run_length;
NUM_THREADS = min(1 << 5, cutoff_nof_runs);
NUM_BLOCKS = (cutoff_nof_runs + NUM_THREADS - 1) / NUM_THREADS;
find_cutoff_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
find_cutoff_kernel<S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
sorted_bucket_sizes, h_nof_buckets_to_compute, bucket_th, cutoff_run_length, nof_large_buckets);
unsigned h_nof_large_buckets;
@@ -541,7 +530,7 @@ namespace msm {
unsigned* max_res;
cudaMallocAsync(&max_res, sizeof(unsigned) * 2, stream);
find_max_size<<<1, 1, 0, stream>>>(sorted_bucket_sizes, sorted_single_bucket_indices, c, max_res);
find_max_size<S><<<1, 1, 0, stream>>>(sorted_bucket_sizes, sorted_single_bucket_indices, c, max_res);
unsigned h_max_res[2];
cudaMemcpyAsync(h_max_res, max_res, sizeof(unsigned) * 2, cudaMemcpyDeviceToHost, stream);
@@ -550,11 +539,19 @@ namespace msm {
unsigned large_buckets_to_compute =
h_nof_large_buckets > h_nof_zero_large_buckets ? h_nof_large_buckets - h_nof_zero_large_buckets : 0;
cudaStream_t stream2;
cudaStreamCreate(&stream2);
P* large_buckets;
if (!are_points_on_device || are_points_montgomery_form) {
// by this point, points need to be already uploaded and un-Montgomeried
cudaStreamWaitEvent(stream, event_points_uploaded);
cudaStreamDestroy(stream_points);
}
cudaStream_t stream_large_buckets;
cudaEvent_t event_large_buckets_accumulated;
P* large_buckets;
if (large_buckets_to_compute > 0 && bucket_th > 0) {
cudaStreamCreate(&stream_large_buckets);
cudaEventCreateWithFlags(&event_large_buckets_accumulated, cudaEventDisableTiming);
unsigned threads_per_bucket =
1 << (unsigned)ceil(log2((h_largest_bucket_size + bucket_th - 1) / bucket_th)); // global param
unsigned max_bucket_size_run_length = (h_largest_bucket_size + threads_per_bucket - 1) / threads_per_bucket;
@@ -563,7 +560,7 @@ namespace msm {
NUM_THREADS = min(1 << 8, total_large_buckets_size);
NUM_BLOCKS = (total_large_buckets_size + NUM_THREADS - 1) / NUM_THREADS;
accumulate_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
accumulate_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
large_buckets, sorted_bucket_offsets + h_nof_zero_large_buckets,
sorted_bucket_sizes + h_nof_zero_large_buckets, sorted_single_bucket_indices + h_nof_zero_large_buckets,
point_indices, d_points, nof_buckets, large_buckets_to_compute, c + bm_bitsize, c, threads_per_bucket,
@@ -573,17 +570,18 @@ namespace msm {
for (int s = total_large_buckets_size >> 1; s > large_buckets_to_compute - 1; s >>= 1) {
NUM_THREADS = min(MAX_TH, s);
NUM_BLOCKS = (s + NUM_THREADS - 1) / NUM_THREADS;
single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
large_buckets, large_buckets, s * 2, 0, 0, 0, s);
CHECK_LAST_CUDA_ERROR();
}
// distribute
NUM_THREADS = min(MAX_TH, large_buckets_to_compute);
NUM_BLOCKS = (large_buckets_to_compute + NUM_THREADS - 1) / NUM_THREADS;
distribute_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream2>>>(
distribute_large_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_large_buckets>>>(
large_buckets, buckets, sorted_single_bucket_indices + h_nof_zero_large_buckets, large_buckets_to_compute);
cudaEventRecord(event_large_buckets_accumulated, stream_large_buckets);
cudaStreamDestroy(stream_large_buckets);
} else {
h_nof_large_buckets = 0;
}
@@ -598,9 +596,11 @@ namespace msm {
h_nof_buckets_to_compute - h_nof_large_buckets, c + bm_bitsize, c);
}
// all the large buckets need to be accumulated before the final summation
cudaStreamSynchronize(stream2);
cudaStreamDestroy(stream2);
if (large_buckets_to_compute > 0 && bucket_th > 0) {
// all the large buckets need to be accumulated before the final summation
cudaStreamWaitEvent(stream, event_large_buckets_accumulated);
cudaStreamDestroy(stream_large_buckets);
}
#ifdef SSM_SUM
// sum each bucket
@@ -618,10 +618,10 @@ namespace msm {
#endif
P* d_final_result;
if (!on_device) cudaMallocAsync(&d_final_result, sizeof(P), stream);
if (!is_result_on_device) cudaMallocAsync(&d_final_result, sizeof(P), stream);
P* final_results;
if (big_triangle) {
if (is_big_triangle) {
cudaMallocAsync(&final_results, sizeof(P) * nof_bms, stream);
// launch the bucket module sum kernel - a thread for each bucket module
NUM_THREADS = nof_bms;
@@ -697,19 +697,17 @@ namespace msm {
}
// launch the double and add kernel, a single thread
final_accumulation_kernel<P, S><<<1, 1, 0, stream>>>(
final_results, ones_results, on_device ? final_result : d_final_result, 1, nof_bms, c, true);
final_accumulation_kernel<P, S>
<<<1, 1, 0, stream>>>(final_results, is_result_on_device ? final_result : d_final_result, 1, nof_bms, c);
cudaFreeAsync(final_results, stream);
cudaStreamSynchronize(stream);
if (!on_device) cudaMemcpyAsync(final_result, d_final_result, sizeof(P), cudaMemcpyDeviceToHost, stream);
if (!is_result_on_device)
cudaMemcpyAsync(final_result, d_final_result, sizeof(P), cudaMemcpyDeviceToHost, stream);
// free memory
if (!on_device) {
cudaFreeAsync(d_points, stream);
cudaFreeAsync(d_scalars, stream);
cudaFreeAsync(d_final_result, stream);
}
if (!are_scalars_on_device) cudaFreeAsync(d_scalars, stream);
if (!are_points_on_device) cudaFreeAsync(d_points, stream);
if (!is_result_on_device) cudaFreeAsync(d_final_result, stream);
cudaFreeAsync(buckets, stream);
#ifndef PHASE1_TEST
cudaFreeAsync(bucket_indices, stream);
@@ -725,9 +723,8 @@ namespace msm {
cudaFreeAsync(nof_large_buckets, stream);
cudaFreeAsync(max_res, stream);
if (large_buckets_to_compute > 0 && bucket_th > 0) cudaFreeAsync(large_buckets, stream);
cudaFreeAsync(ones_results, stream);
cudaStreamSynchronize(stream);
if (!is_async) cudaStreamSynchronize(stream);
}
// this function computes multiple msms using the bucket method
@@ -879,7 +876,7 @@ namespace msm {
NUM_THREADS = 1 << 8;
NUM_BLOCKS = (batch_size + NUM_THREADS - 1) / NUM_THREADS;
final_accumulation_kernel<P, S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
bm_sums, bm_sums, on_device ? final_results : d_final_results, batch_size, nof_bms, c, false);
bm_sums, on_device ? final_results : d_final_results, batch_size, nof_bms, c);
// copy final result to host
if (!on_device)
@@ -907,26 +904,59 @@ namespace msm {
} // namespace
MSMConfig DefaultMSMConfig()
{
device_context::DeviceContext ctx = {
0, // device_id
(cudaStream_t)0, // stream
0, // mempool
};
MSMConfig config = {
false, // are_scalars_on_device
false, // are_scalars_montgomery_form
0, // points_size
1, // precompute_factor
false, // are_points_on_device
false, // are_points_montgomery_form
1, // batch_size
false, // are_results_on_device
0, // c
0, // bitsize
false, // is_big_triangle
10, // large_bucket_factor
false, // is_async
ctx, // DeviceContext
};
return config;
}
template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig config, P* results)
{
int bitsize = (config.bitsize == 0) ? S::NBITS : config.bitsize;
// TODO: DmytroTym/HadarIngonyama - unify the implementation of the bucket method and the batched bucket method in
// one function
// TODO: DmytroTym/HadarIngonyama - parameters to be included into the implementation: on deviceness of points,
// scalars and results, precompute factor, points size and device id
if (config.batch_size == 1)
bucket_method_msm(
config.bitsize, config.c, scalars, points, msm_size, results, config.are_scalars_on_device, config.big_triangle,
config.large_bucket_factor, config.ctx.stream);
bitsize, 16, scalars, points, msm_size, results, config.are_scalars_on_device,
config.are_scalars_montgomery_form, config.are_points_on_device, config.are_points_montgomery_form,
config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.is_async,
config.ctx.stream);
else
batched_bucket_method_msm(
config.bitsize, config.c, scalars, points, config.batch_size, msm_size, results, config.are_scalars_on_device,
config.ctx.stream);
bitsize, (config.c == 0) ? get_optimal_c<S>(bitsize) : config.c, scalars, points, config.batch_size, msm_size,
results, config.are_scalars_on_device, config.ctx.stream);
return cudaSuccess;
}
/**
* Extern version of [msm](@ref msm) function with the following values of template parameters
* Extern version of [DefaultMSMConfig](@ref DefaultMSMConfig) function.
* @return Default value of [MSMConfig](@ref MSMConfig).
*/
extern "C" MSMConfig GetDefaultMSMConfig() { return DefaultMSMConfig(); }
/**
* Extern version of [MSM](@ref MSM) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` is the [scalar field](@ref scalar_t) of the curve;
* - `A` is the [affine representation](@ref affine_t) of curve points;
@@ -947,7 +977,7 @@ namespace msm {
#if defined(G2_DEFINED)
/**
* Extern version of [msm](@ref msm) function with the following values of template parameters
* Extern version of [MSM](@ref MSM) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` is the [scalar field](@ref scalar_t) of the curve;
* - `A` is the [affine representation](@ref g2_affine_t) of G2 curve points;

View File

@@ -4,7 +4,13 @@
#include <cuda_runtime.h>
#include "../../curves/curve_config.cuh"
#include "../../primitives/affine.cuh"
#include "../../primitives/field.cuh"
#include "../../primitives/projective.cuh"
#include "../../utils/cuda_utils.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/error_handler.cuh"
/**
* @namespace msm
@@ -41,31 +47,41 @@ namespace msm {
* the MSM size). */
int precompute_factor; /**< The number of extra points to pre-compute for each point. Larger values decrease the
* number of computations to make, on-line memory footprint, but increase the static
* memory footprint. Default value: 1 (i.e. don't pre-compute). */
* memory footprint. Default value: 1 (i.e. don't pre-compute). */
bool are_points_on_device; /**< True if points are on device and false if they're on host. Default value: false. */
bool are_points_montgomery_form; /**< True if coordinates of points are in Montgomery form and false otherwise.
Default value: false. */
int batch_size; /**< The number of MSMs to compute. Default value: 1. */
bool are_result_on_device; /**< True if the results should be on device and false if they should be on host. Default
value: false. */
int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
* that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
* footprint but also more parallelism and less computational complexity (up to a certain point).
* Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a
* different (better) upper bound is known, it should be reflected in this variable. Default
* value: 0 (set to the bitsize of scalar field). */
bool big_triangle; /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but also
* greatly decreases parallelism, so only suitable for large batches of MSMs. Default value:
* false. */
bool are_results_on_device; /**< True if the results should be on device and false if they should be on host. If set
* to false, `is_async` won't take effect because a synchronization is needed to
* transfer results to the host. Default value: false. */
int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
* that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
* footprint but also more parallelism and less computational complexity (up to a certain point).
* Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a
* different (better) upper bound is known, it should be reflected in this variable. Default value: 0
* (set to the bitsize of scalar field). */
bool is_big_triangle; /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but
* also greatly decreases parallelism, so only suitable for large batches of MSMs. Default
* value: false. */
int large_bucket_factor; /**< Variable that controls how sensitive the algorithm is to the buckets that occur very
* frequently. Useful for efficient treatment of non-uniform distributions of scalars and
* "top windows" with few bits. Can be set to 0 to disable separate treatment of large
* buckets altogether. Default value: 10. */
int is_async; /**< Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
* and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or
* `cudaDeviceSynchronize`. If set to false, the MSM function will block the current CPU thread. Default
* value: false. */
device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See
[DeviceContext](@ref device_context::DeviceContext). */
[DeviceContext](@ref `device_context::DeviceContext`). */
};
/**
* A function that returns the default value of [MSMConfig](@ref MSMConfig) for the [MSM](@ref MSM) function.
* @return Default value of [MSMConfig](@ref MSMConfig).
*/
extern "C" MSMConfig DefaultMSMConfig();
/**
* A function that computes MSM: \f$ MSM(s_i, P_i) = \sum_{i=1}^N s_i \cdot P_i \f$.
* @param scalars Scalars \f$ s_i \f$. In case of batch MSM, the scalars from all MSMs are concatenated.
@@ -81,6 +97,15 @@ namespace msm {
* @tparam P Output type, which is typically a [projective
* Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) point in our codebase.
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*
* This function is asynchronous, and to sync it with the host, you need to call `cudaDeviceSynchronize()`. To
* synchronize with a different stream `stream1`, call `cudaStreamSynchronize(config.ctx.stream)` and
* `cudaStreamSynchronize(stream1)`.
*
* **Note:** this function is still WIP and the following [MSMConfig](@ref MSMConfig) members do not yet have any
* effect: `points_size` (it's always equal to the msm size currently), `precompute_factor` (always equals 1) and
* `ctx.device_id` (device 0 is always used). Also, it's currently better to use `batch_size=1` in most cases (except
* when dealing with very many MSMs).
*/
template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig config, P* results);
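
A minimal host-side sketch of how the config-based API documented above could be called. This is not part of the commit; the include path is an assumption, the curve types come from curve_config, and error handling is omitted:

#include "appUtils/msm/msm.cuh" // assumed include path

// Single MSM from host buffers; with results on host the call synchronizes internally.
cudaError_t msm_from_host(
  curve_config::scalar_t* scalars, curve_config::affine_t* points, int msm_size,
  curve_config::projective_t* result)
{
  msm::MSMConfig config = msm::DefaultMSMConfig(); // host data, batch_size = 1, c/bitsize picked automatically
  return msm::MSM<curve_config::scalar_t, curve_config::affine_t, curve_config::projective_t>(
    scalars, points, msm_size, config, result);
}

// Asynchronous variant: the result stays on device and the caller synchronizes the stream explicitly.
cudaError_t msm_async_on_device(
  curve_config::scalar_t* scalars, curve_config::affine_t* points, int msm_size,
  curve_config::projective_t* d_result /* device pointer */, cudaStream_t stream)
{
  msm::MSMConfig config = msm::DefaultMSMConfig();
  config.are_results_on_device = true;
  config.is_async = true;
  config.ctx.stream = stream;
  cudaError_t err = msm::MSM<curve_config::scalar_t, curve_config::affine_t, curve_config::projective_t>(
    scalars, points, msm_size, config, d_result);
  if (err != cudaSuccess) return err;
  return cudaStreamSynchronize(stream); // sync before reading d_result, as the note above says
}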

View File

@@ -130,18 +130,14 @@ int main()
test_scalar* scalars = new test_scalar[N];
test_affine* points = new test_affine[N];
for (unsigned i = 0; i < N; i++) {
// scalars[i] = (i%msm_size < 10)? test_scalar::rand_host() : scalars[i-10];
points[i] = (i % msm_size < 100) ? test_projective::to_affine(test_projective::rand_host()) : points[i - 100];
scalars[i] = test_scalar::rand_host();
// scalars[i] = i < N/2? test_scalar::rand_host() : test_scalar::one();
// points[i] = test_projective::to_affine(test_projective::rand_host());
}
test_scalar::RandHostMany(scalars, N);
test_projective::RandHostManyAffine(points, N);
std::cout << "finished generating" << std::endl;
// projective_t *short_res = (projective_t*)malloc(sizeof(projective_t));
// test_projective *large_res = (test_projective*)malloc(sizeof(test_projective));
test_projective large_res[batch_size * 2];
test_projective large_res[batch_size];
// test_projective batched_large_res[batch_size];
// fake_point *large_res = (fake_point*)malloc(sizeof(fake_point));
// fake_point batched_large_res[256];
@@ -175,44 +171,47 @@ int main()
0, // mempool
};
msm::MSMConfig config = {
false, // scalars_on_device
true, // scalars_montgomery_form
msm_size, // points_size
1, // precompute_factor
false, // points_on_device
true, // points_montgomery_form
1, // batch_size
false, // result_on_device
16, // c
test_scalar::NBITS, // bitsize
false, // big_triangle
10, // large_bucket_factor
ctx, // DeviceContext
false, // are_scalars_on_device
false, // are_scalars_montgomery_form
0, // points_size
1, // precompute_factor
false, // are_points_on_device
false, // are_points_montgomery_form
1, // batch_size
true, // are_results_on_device
0, // c
0, // bitsize
false, // is_big_triangle
10, // large_bucket_factor
true, // is_async
ctx, // DeviceContext
};
auto begin1 = std::chrono::high_resolution_clock::now();
msm::MSM<test_scalar, test_affine, test_projective>(scalars, points, msm_size, config, large_res);
msm::MSM<test_scalar, test_affine, test_projective>(scalars, points, msm_size, config, large_res_d);
cudaEvent_t msm_end_event;
cudaEventCreate(&msm_end_event);
auto end1 = std::chrono::high_resolution_clock::now();
auto elapsed1 = std::chrono::duration_cast<std::chrono::nanoseconds>(end1 - begin1);
printf("Big Triangle : %.3f seconds.\n", elapsed1.count() * 1e-9);
config.big_triangle = true;
printf("No Big Triangle : %.3f seconds.\n", elapsed1.count() * 1e-9);
config.is_big_triangle = true;
config.are_results_on_device = false;
// std::cout<<test_projective::to_affine(large_res[0])<<std::endl;
auto begin = std::chrono::high_resolution_clock::now();
msm::MSM<test_scalar, test_affine, test_projective>(scalars_d, points_d, msm_size, config, large_res_d);
msm::MSM<test_scalar, test_affine, test_projective>(scalars_d, points_d, msm_size, config, large_res);
// test_reduce_triangle(scalars);
// test_reduce_rectangle(scalars);
// test_reduce_single(scalars);
// test_reduce_var(scalars);
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
printf("On Device No Big Triangle: %.3f seconds.\n", elapsed.count() * 1e-9);
printf("Big Triangle: %.3f seconds.\n", elapsed.count() * 1e-9);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);
std::cout << test_projective::to_affine(large_res[0]) << std::endl;
cudaMemcpy(&large_res[1], large_res_d, sizeof(test_projective), cudaMemcpyDeviceToHost);
std::cout << test_projective::to_affine(large_res[1]) << std::endl;
// reference_msm<test_affine, test_scalar, test_projective>(scalars, points, msm_size);

View File

@@ -8,10 +8,10 @@ namespace ntt {
namespace {
const uint32_t MAX_NUM_THREADS = 1024;
const uint32_t MAX_NUM_THREADS = 512;
const uint32_t MAX_THREADS_BATCH = 512; // TODO: allows 100% occupancy for scalar NTT for sm_86..sm_89
const uint32_t MAX_SHARED_MEM_ELEMENT_SIZE = 32; // TODO: occupancy calculator, hardcoded for sm_86..sm_89
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * 1024;
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * MAX_NUM_THREADS;
/**
* Computes the twiddle factors.
@@ -21,13 +21,10 @@ namespace ntt {
* @param omega multiplying factor.
*/
template <typename S>
__global__ void twiddle_factors_kernel(S* d_twiddles, uint32_t n_twiddles, S omega)
__global__ void twiddle_factors_kernel(S* d_twiddles, int n_twiddles, S omega)
{
for (uint32_t i = 0; i < n_twiddles; i++) {
d_twiddles[i] = S::zero();
}
d_twiddles[0] = S::one();
for (uint32_t i = 0; i < n_twiddles - 1; i++) {
for (int i = 0; i < n_twiddles - 1; i++) {
d_twiddles[i + 1] = omega * d_twiddles[i];
}
}
@@ -61,7 +58,7 @@ namespace ntt {
int number_of_threads = MAX_THREADS_BATCH;
int number_of_blocks = (n * batch_size + number_of_threads - 1) / number_of_threads;
reverse_order_kernel<<<number_of_blocks, number_of_threads, 0, stream>>>(arr, arr_reversed, n, logn, batch_size);
cudaMemcpyAsync(arr, arr_reversed, n * batch_size * sizeof(E), cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(arr, arr_reversed, n * batch_size * sizeof(E), cudaMemcpyDefault, stream);
cudaFreeAsync(arr_reversed, stream);
}
@@ -313,15 +310,15 @@ namespace ntt {
}
if (is_coset)
utils_internal::batchVectorMult<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
utils_internal::BatchMulKernel<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
num_threads = max(min(n / 2, MAX_NUM_THREADS), 1);
num_blocks = (n * batch_size + num_threads - 1) / num_threads;
utils_internal::template_normalize_kernel<E, S>
utils_internal::NormalizeKernel<E, S>
<<<num_blocks, num_threads, 0, stream>>>(d_inout, S::inv_log_size(logn), n * batch_size);
} else {
if (is_coset)
utils_internal::batchVectorMult<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
utils_internal::BatchMulKernel<E, S><<<num_blocks, num_threads, 0, stream>>>(d_inout, coset, n, batch_size);
for (int s = logn - 1; s >= logn_shmem; s--) // TODO: this loop also can be unrolled
{
@@ -349,30 +346,44 @@ namespace ntt {
}
template <typename E, typename S>
cudaError_t NTT(E* input, int size, bool is_inverse, NTTConfig<S> config)
cudaError_t NTT(NTTConfig<E, S>* config)
{
uint32_t logn = uint32_t(log(size) / log(2));
uint32_t n_twiddles = size; // n_twiddles is set to 4096 as BLS12_381::scalar_t::omega() is of that order.
size_t input_size = size * config.batch_size * sizeof(E);
bool is_on_device = config.are_inputs_on_device;
bool generate_twiddles = (config.twiddles == nullptr);
cudaStream_t stream = config.ctx.stream;
CHECK_LAST_CUDA_ERROR();
cudaStream_t stream = config->ctx.stream;
int size = config->size;
int batch_size = config->batch_size;
bool is_inverse = config->is_inverse;
int n_twiddles = size;
int logn = int(log(size) / log(2));
int input_size_bytes = size * batch_size * sizeof(E);
bool is_input_on_device = config->are_inputs_on_device;
bool is_output_on_device = config->is_output_on_device;
bool is_forward_twiddle_empty = config->twiddles == nullptr;
bool is_inverse_twiddle_empty = config->inv_twiddles == nullptr;
bool is_generating_twiddles = (is_forward_twiddle_empty && is_inverse_twiddle_empty) ||
(is_forward_twiddle_empty && !is_inverse) || (is_inverse_twiddle_empty && is_inverse);
S* d_twiddles;
if (generate_twiddles) {
if (is_generating_twiddles) {
cudaMallocAsync(&d_twiddles, n_twiddles * sizeof(S), stream);
GenerateTwiddleFactors(d_twiddles, n_twiddles, is_inverse ? S::omega_inv(logn) : S::omega(logn), config.ctx);
S omega = is_inverse ? S::omega_inv(logn) : S::omega(logn);
GenerateTwiddleFactors(d_twiddles, n_twiddles, omega, config->ctx);
} else {
d_twiddles = is_inverse ? config->inv_twiddles : config->twiddles;
}
E* d_input;
if (!is_on_device) {
cudaMallocAsync(&d_input, input_size, stream);
cudaMemcpyAsync(d_input, input, input_size, cudaMemcpyHostToDevice, stream);
E* d_inout;
if (is_input_on_device) {
d_inout = config->inout;
} else {
cudaMallocAsync(&d_inout, input_size_bytes, stream);
cudaMemcpyAsync(d_inout, config->inout, input_size_bytes, cudaMemcpyHostToDevice, stream);
}
bool reverse_input;
bool reverse_output;
switch (config.ordering) {
switch (config->ordering) {
case Ordering::kNN:
reverse_input = is_inverse;
reverse_output = !is_inverse;
@@ -390,33 +401,50 @@ namespace ntt {
reverse_output = is_inverse;
break;
}
CHECK_LAST_CUDA_ERROR();
if (reverse_input) reverse_order_batch(d_inout, size, logn, config->batch_size, stream);
CHECK_LAST_CUDA_ERROR();
if (reverse_input) reverse_order_batch(is_on_device ? input : d_input, size, logn, config.batch_size, stream);
ntt_inplace_batch_template(
is_on_device ? input : d_input, generate_twiddles ? d_twiddles : config.twiddles, size, config.batch_size,
is_inverse, config.is_coset, config.coset_gen, stream, false);
if (reverse_output) reverse_order_batch(is_on_device ? input : d_input, size, logn, config.batch_size, stream);
d_inout, d_twiddles, size, batch_size, is_inverse, config->is_coset, config->coset_gen, stream, false);
CHECK_LAST_CUDA_ERROR();
if (!is_on_device) {
cudaMemcpyAsync(input, d_input, input_size, cudaMemcpyDeviceToHost, stream);
cudaFreeAsync(d_input, stream);
if (reverse_output) reverse_order_batch(d_inout, size, logn, batch_size, stream);
CHECK_LAST_CUDA_ERROR();
if (is_output_on_device) {
// free(config->inout); // TODO: ? or callback?+
config->inout = d_inout;
} else {
if (is_input_on_device) {
E* h_output = (E*)malloc(input_size_bytes); // TODO: caller responsible for memory management
cudaMemcpyAsync(h_output, d_inout, input_size_bytes, cudaMemcpyDeviceToHost, stream);
config->inout = h_output;
CHECK_LAST_CUDA_ERROR();
} else {
cudaMemcpyAsync(config->inout, d_inout, input_size_bytes, cudaMemcpyDeviceToHost, stream);
CHECK_LAST_CUDA_ERROR();
}
cudaFreeAsync(d_inout, stream); // TODO: make it optional? so can be reused
}
CHECK_LAST_CUDA_ERROR();
if (is_generating_twiddles && !config->is_preserving_twiddles) { cudaFreeAsync(d_twiddles, stream); }
if (config->is_preserving_twiddles) {
if (is_inverse)
config->inv_twiddles = d_twiddles;
else {
config->twiddles = d_twiddles;
}
}
if (generate_twiddles) { cudaFreeAsync(d_twiddles, stream); }
cudaStreamSynchronize(stream);
return cudaSuccess;
}
/**
* Extern version of [GenerateTwiddleFactors](@ref GenerateTwiddleFactors) function with the template parameter
* `S` being the [scalar field](@ref scalar_t) of the curve given by `-DCURVE` env variable during build.
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t GenerateTwiddleFactorsCuda(
curve_config::scalar_t* d_twiddles, int n_twiddles, curve_config::scalar_t omega, device_context::DeviceContext ctx)
{
return GenerateTwiddleFactors<curve_config::scalar_t>(d_twiddles, n_twiddles, omega, ctx);
CHECK_LAST_CUDA_ERROR();
return cudaSuccess;
}
/**
@@ -425,10 +453,43 @@ namespace ntt {
* - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t
NTTCuda(curve_config::scalar_t* input, int size, bool is_inverse, NTTConfig<curve_config::scalar_t> config)
extern "C" cudaError_t NTTCuda(NTTConfig<curve_config::scalar_t, curve_config::scalar_t>* config)
{
return NTT<curve_config::scalar_t, curve_config::scalar_t>(input, size, is_inverse, config);
return NTT<curve_config::scalar_t, curve_config::scalar_t>(config);
}
/**
* Extern version of [ntt](@ref ntt) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
template <typename E, typename S>
cudaError_t NTTDefaultContext(NTTConfig<E, S>* config)
{
// TODO: if empty - create default
cudaMemPool_t mempool;
cudaDeviceGetDefaultMemPool(&mempool, config->ctx.device_id);
device_context::DeviceContext context = {
config->ctx.device_id,
0, // default stream
mempool};
config->ctx = context;
return NTT<E, S>(config);
}
/**
* Extern version of [ntt](@ref ntt) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` and `E` are both the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t NTTDefaultContextCuda(NTTConfig<curve_config::scalar_t, curve_config::scalar_t>* config)
{
return NTTDefaultContext(config);
}
#if defined(ECNTT_DEFINED)
@@ -440,10 +501,9 @@ namespace ntt {
* - `E` is the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t
ECNTTCuda(curve_config::projective_t* input, int size, bool is_inverse, NTTConfig<curve_config::scalar_t> config)
extern "C" cudaError_t ECNTTCuda(NTTConfig<curve_config::projective_t, curve_config::scalar_t>* config)
{
return NTT<curve_config::projective_t, curve_config::scalar_t>(input, size, is_inverse, config);
return NTT<curve_config::projective_t, curve_config::scalar_t>(config);
}
#endif
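
A rough host-side sketch of driving the reworked in-place NTT API above (a sketch only: the include path, the enum spelling and the zero-initialized remaining fields are assumptions; error handling is omitted). It also shows the twiddle-preservation path added in this commit: with is_preserving_twiddles set, the forward call keeps its twiddles on device and the inverse call generates and keeps inv_twiddles, so repeated transforms on the same config skip regeneration.

#include "appUtils/ntt/ntt.cuh" // assumed include path

// In-place forward NTT followed by the inverse NTT on the same host buffer.
cudaError_t ntt_roundtrip(curve_config::scalar_t* host_data, int n, device_context::DeviceContext ctx)
{
  ntt::NTTConfig<curve_config::scalar_t, curve_config::scalar_t> config = {}; // unspecified fields stay zero/nullptr
  config.inout = host_data;
  config.size = n;
  config.batch_size = 1;
  config.ordering = ntt::Ordering::kNN; // natural order in, natural order out
  config.is_preserving_twiddles = true; // keep generated twiddles on device for later calls
  config.ctx = ctx;

  config.is_inverse = false;
  cudaError_t err = ntt::NTT<curve_config::scalar_t, curve_config::scalar_t>(&config); // forward NTT
  if (err != cudaSuccess) return err;

  config.is_inverse = true; // inverse transform; inv_twiddles are generated (and preserved) on this call
  return ntt::NTT<curve_config::scalar_t, curve_config::scalar_t>(&config);
}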

View File

@@ -4,7 +4,11 @@
#include <cuda_runtime.h>
#include "../../curves/curve_config.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/error_handler.cuh"
#include "../../utils/sharedmem.cuh"
#include "../../utils/utils_kernels.cuh"
/**
* @namespace ntt
@@ -51,10 +55,14 @@ namespace ntt {
* @struct NTTConfig
* Struct that encodes NTT parameters to be passed into the [ntt](@ref ntt) function.
*/
template <typename S>
template <typename E, typename S>
struct NTTConfig {
E* inout; /**< Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot
* config.batch_size \f$. Note that if inputs are in Montgomery form, the outputs will be as well and
* vice versa: non-Montgomery inputs produce non-Montgomery outputs. */
bool are_inputs_on_device; /**< True if inputs/outputs are on device and false if they're on host. Default value:
false. */
bool is_inverse; /**< True to compute the inverse NTT (iNTT), false for the forward NTT. Default value: false. */
Ordering
ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`. */
Decimation
@@ -64,35 +72,40 @@ namespace ntt {
* If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
* `Decimation::kDIT` and if ordering is `Ordering::kNR` — to `Decimation::kDIF`. */
Butterfly
butterfly; /**< Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value:
* `Butterfly::kCooleyTukey`.
* __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
* If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
* `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`. */
bool is_coset; /**< If false, NTT is computed on a subfield given by [twiddles](@ref twiddles). If true, NTT is
* computed on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen),
* so: \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\}
* \f$. Default value: false. */
S* coset_gen; /**< The field element that generates a coset if [is_coset](@ref is_coset) is true.
* Otherwise should be set to `nullptr`. Default value: `nullptr`. */
S* twiddles; /**< "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
* This pointer is expected to live on device. The order is as follows:
* \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
* factors are generated online using the default generator (TODO: link to twiddle gen here) and
* function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
int batch_size; /**< The number of NTTs to compute. Default value: 1. */
butterfly; /**< Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value:
* `Butterfly::kCooleyTukey`.
* __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
* If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
* `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`. */
bool is_coset; /**< If false, NTT is computed on a subfield given by [twiddles](@ref twiddles). If true, NTT is
* computed on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen),
* so: \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\}
* \f$. Default value: false. */
S* coset_gen; /**< The field element that generates a coset if [is_coset](@ref is_coset) is true.
* Otherwise should be set to `nullptr`. Default value: `nullptr`. */
S* twiddles; /**< "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
* This pointer is expected to live on device. The order is as follows:
* \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
* factors are generated online using the default generator (TODO: link to twiddle gen here) and
* function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
S* inv_twiddles; /**< "Inverse twiddle factors", (or "domain", or "roots of unity") on which the iNTT is evaluated.
* This pointer is expected to live on device. The order is as follows:
* \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle
* factors are generated online using the default generator (TODO: link to twiddle gen here) and
* function [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
int size; /**< NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is
the size of 1 NTT. */
int batch_size; /**< The number of NTTs to compute. Default value: 1. */
bool is_preserving_twiddles; /**< If true, twiddle factors are preserved on device for subsequent use in config and
not freed after calculation. Default value: false. */
bool is_output_on_device; /**< If true, the output stays on device after the call (`inout` is updated to point to
the device buffer) instead of being copied back to host. Default value: false. */
device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See
[DeviceContext](@ref device_context::DeviceContext). */
};
/**
* A function that computes NTT or iNTT in-place.
* @param input Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot
* config.batch_size \f$. Note that if inputs are in Montgomery form, the outputs will be as well and vice-verse:
* non-Montgomery inputs produce non-Montgomety outputs.
* @param size NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is the
* size of 1 NTT.
* @param is_inverse If true, inverse NTT is computed, otherwise — regular forward NTT.
* @param config [NTTConfig](@ref NTTConfig) used in this NTT.
* @tparam E The type of inputs and outputs (i.e. coefficients \f$ \{p_i\} \f$ and values \f$ p(x) \f$). Must be a
* group.
@@ -100,7 +113,7 @@ namespace ntt {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
template <typename E, typename S>
cudaError_t NTT(E* input, int size, bool is_inverse, NTTConfig<S> config);
cudaError_t NTT(NTTConfig<E, S>* config);
/**
* Generates twiddles \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$ from root of unity \f$ \omega \f$ and

View File

@@ -25,11 +25,10 @@ using namespace bls12_377;
namespace curve_config {
typedef Field<fp_config> scalar_field_t;
typedef scalar_field_t scalar_t;
typedef Field<fp_config> scalar_t;
typedef Field<fq_config> point_field_t;
static constexpr point_field_t b = point_field_t{weierstrass_b};
typedef Projective<point_field_t, scalar_field_t, b> projective_t;
typedef Projective<point_field_t, scalar_t, b> projective_t;
typedef Affine<point_field_t> affine_t;
#if defined(G2_DEFINED)

View File

@@ -11,6 +11,16 @@ public:
static HOST_DEVICE_INLINE Affine neg(const Affine& point) { return {point.x, FF::neg(point.y)}; }
static HOST_DEVICE_INLINE Affine ToMontgomery(const Affine& point)
{
return {FF::ToMontgomery(point.x), FF::ToMontgomery(point.y)};
}
static HOST_DEVICE_INLINE Affine FromMontgomery(const Affine& point)
{
return {FF::FromMontgomery(point.x), FF::FromMontgomery(point.y)};
}
friend HOST_DEVICE_INLINE bool operator==(const Affine& xs, const Affine& ys)
{
return (xs.x == ys.x) && (xs.y == ys.y);

View File

@@ -0,0 +1,6 @@
#include "../curves/curve_config.cuh"
#include "field.cuh"
#define scalar_t curve_config::scalar_t
extern "C" void GenerateScalars(scalar_t* scalars, int size) { scalar_t::RandHostMany(scalars, size); }

View File

@@ -69,10 +69,6 @@ public:
static constexpr HOST_DEVICE_INLINE Field modulus() { return Field{CONFIG::modulus}; }
static constexpr HOST_DEVICE_INLINE Field montgomery_r() { return Field{CONFIG::montgomery_r}; }
static constexpr HOST_DEVICE_INLINE Field montgomery_r_inv() { return Field{CONFIG::montgomery_r_inv}; }
// private:
typedef storage<TLC> ff_storage;
typedef storage<2 * TLC> ff_wide_storage;
@@ -703,6 +699,12 @@ public:
return value;
}
static void RandHostMany(Field* out, int size)
{
for (int i = 0; i < size; i++)
out[i] = rand_host();
}
template <unsigned REDUCTION_SIZE = 1>
static constexpr HOST_DEVICE_INLINE Field sub_modulus(const Field& xs)
{
@@ -718,7 +720,7 @@ public:
hex_string << std::hex << std::setfill('0');
for (int i = 0; i < TLC; i++) {
hex_string << std::setw(8) << xs.limbs_storage.limbs[i];
hex_string << std::setw(8) << xs.limbs_storage.limbs[TLC - i - 1];
}
os << "0x" << hex_string.str();
@@ -869,6 +871,13 @@ public:
return xs * xs;
}
static constexpr HOST_DEVICE_INLINE Field ToMontgomery(const Field& xs) { return xs * Field{CONFIG::montgomery_r}; }
static constexpr HOST_DEVICE_INLINE Field FromMontgomery(const Field& xs)
{
return xs * Field{CONFIG::montgomery_r_inv};
}
template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE Field neg(const Field& xs)
{

View File

@@ -1,61 +1,54 @@
#include "../curves/bls12_377/curve_config.cuh"
#include "../curves/bls12_381/curve_config.cuh"
#include "../curves/bn254/curve_config.cuh"
#include "../curves/curve_config.cuh"
#include "projective.cuh"
#include <cuda.h>
extern "C" bool eq_bls12_381(BLS12_381::projective_t* point1, BLS12_381::projective_t* point2)
#define projective_t curve_config::projective_t // TODO: global to avoid lengthy texts
#define affine_t curve_config::affine_t
#define point_field_t curve_config::point_field_t
extern "C" bool Eq(projective_t* point1, projective_t* point2)
{
return (*point1 == *point2) &&
!((point1->x == BLS12_381::point_field_t::zero()) && (point1->y == BLS12_381::point_field_t::zero()) &&
(point1->z == BLS12_381::point_field_t::zero())) &&
!((point2->x == BLS12_381::point_field_t::zero()) && (point2->y == BLS12_381::point_field_t::zero()) &&
(point2->z == BLS12_381::point_field_t::zero()));
!((point1->x == point_field_t::zero()) && (point1->y == point_field_t::zero()) &&
(point1->z == point_field_t::zero())) &&
!((point2->x == point_field_t::zero()) && (point2->y == point_field_t::zero()) &&
(point2->z == point_field_t::zero()));
}
extern "C" bool eq_bls12_377(BLS12_377::projective_t* point1, BLS12_377::projective_t* point2)
{
return (*point1 == *point2) &&
!((point1->x == BLS12_377::point_field_t::zero()) && (point1->y == BLS12_377::point_field_t::zero()) &&
(point1->z == BLS12_377::point_field_t::zero())) &&
!((point2->x == BLS12_377::point_field_t::zero()) && (point2->y == BLS12_377::point_field_t::zero()) &&
(point2->z == BLS12_377::point_field_t::zero()));
}
extern "C" void ToAffine(projective_t* point, affine_t* point_out) { *point_out = projective_t::to_affine(*point); }
extern "C" bool eq_bn254(BN254::projective_t* point1, BN254::projective_t* point2)
{
return (*point1 == *point2) &&
!((point1->x == BN254::point_field_t::zero()) && (point1->y == BN254::point_field_t::zero()) &&
(point1->z == BN254::point_field_t::zero())) &&
!((point2->x == BN254::point_field_t::zero()) && (point2->y == BN254::point_field_t::zero()) &&
(point2->z == BN254::point_field_t::zero()));
}
extern "C" void GenerateProjectivePoints(projective_t* points, int size) { projective_t::RandHostMany(points, size); }
extern "C" void GenerateAffinePoints(affine_t* points, int size) { projective_t::RandHostManyAffine(points, size); }
#if defined(G2_DEFINED)
extern "C" bool eq_g2_bls12_381(BLS12_381::g2_projective_t* point1, BLS12_381::g2_projective_t* point2)
#define g2_projective_t curve_config::g2_projective_t
#define g2_affine_t curve_config::g2_affine_t
#define g2_point_field_t curve_config::g2_point_field_t
extern "C" bool EqG2(g2_projective_t* point1, g2_projective_t* point2)
{
return (*point1 == *point2) &&
!((point1->x == BLS12_381::g2_point_field_t::zero()) && (point1->y == BLS12_381::g2_point_field_t::zero()) &&
(point1->z == BLS12_381::g2_point_field_t::zero())) &&
!((point2->x == BLS12_381::g2_point_field_t::zero()) && (point2->y == BLS12_381::g2_point_field_t::zero()) &&
(point2->z == BLS12_381::g2_point_field_t::zero()));
!((point1->x == g2_point_field_t::zero()) && (point1->y == g2_point_field_t::zero()) &&
(point1->z == g2_point_field_t::zero())) &&
!((point2->x == g2_point_field_t::zero()) && (point2->y == g2_point_field_t::zero()) &&
(point2->z == g2_point_field_t::zero()));
}
extern "C" bool eq_g2_bls12_377(BLS12_377::g2_projective_t* point1, BLS12_377::g2_projective_t* point2)
extern "C" void ToAffineG2(g2_projective_t* point, g2_affine_t* point_out)
{
return (*point1 == *point2) &&
!((point1->x == BLS12_377::g2_point_field_t::zero()) && (point1->y == BLS12_377::g2_point_field_t::zero()) &&
(point1->z == BLS12_377::g2_point_field_t::zero())) &&
!((point2->x == BLS12_377::g2_point_field_t::zero()) && (point2->y == BLS12_377::g2_point_field_t::zero()) &&
(point2->z == BLS12_377::g2_point_field_t::zero()));
*point_out = g2_projective_t::to_affine(*point);
}
extern "C" bool eq_g2_bn254(BN254::g2_projective_t* point1, BN254::g2_projective_t* point2)
extern "C" void GenerateProjectivePointsG2(g2_projective_t* points, int size)
{
return (*point1 == *point2) &&
!((point1->x == BN254::g2_point_field_t::zero()) && (point1->y == BN254::g2_point_field_t::zero()) &&
(point1->z == BN254::g2_point_field_t::zero())) &&
!((point2->x == BN254::g2_point_field_t::zero()) && (point2->y == BN254::g2_point_field_t::zero()) &&
(point2->z == BN254::g2_point_field_t::zero()));
g2_projective_t::RandHostMany(points, size);
}
#endif
extern "C" void GenerateAffinePointsG2(g2_affine_t* points, int size)
{
g2_projective_t::RandHostManyAffine(points, size);
}
#endif
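A hedged C++ sketch of exercising the curve-generic API that replaces the per-curve eq_* functions above; it relies only on the symbols shown in this file plus the curve_config typedefs, and the include path is an assumption.
#include <vector>
#include "curve_config.cuh" // assumed include path

extern "C" bool Eq(curve_config::projective_t* point1, curve_config::projective_t* point2);
extern "C" void ToAffine(curve_config::projective_t* point, curve_config::affine_t* point_out);
extern "C" void GenerateProjectivePoints(curve_config::projective_t* points, int size);

void DemoCurveApi()
{
  std::vector<curve_config::projective_t> points(4);
  GenerateProjectivePoints(points.data(), 4); // fill with random projective points
  curve_config::affine_t aff;
  ToAffine(&points[0], &aff);                 // normalize away the z coordinate
  bool same = Eq(&points[0], &points[0]);     // a random (nonzero) point equals itself
  (void)same;
}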

View File

@@ -22,6 +22,16 @@ public:
static HOST_DEVICE_INLINE Projective from_affine(const Affine<FF>& point) { return {point.x, point.y, FF::one()}; }
static HOST_DEVICE_INLINE Projective ToMontgomery(const Projective& point)
{
return {FF::ToMontgomery(point.x), FF::ToMontgomery(point.y), FF::ToMontgomery(point.z)};
}
static HOST_DEVICE_INLINE Projective FromMontgomery(const Projective& point)
{
return {FF::FromMontgomery(point.x), FF::FromMontgomery(point.y), FF::FromMontgomery(point.z)};
}
static HOST_DEVICE_INLINE Projective generator() { return {FF::generator_x(), FF::generator_y(), FF::one()}; }
static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; }
@@ -163,4 +173,16 @@ public:
SCALAR_FF rand_scalar = SCALAR_FF::rand_host();
return rand_scalar * generator();
}
static void RandHostMany(Projective* out, int size)
{
for (int i = 0; i < size; i++)
out[i] = (i % size < 100) ? rand_host() : out[i - 100];
}
static void RandHostManyAffine(Affine<FF>* out, int size)
{
for (int i = 0; i < size; i++)
out[i] = (i % size < 100) ? to_affine(rand_host()) : out[i - 100];
}
};
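A short hedged sketch of the host-side helpers added to Projective above, combined with the Montgomery conversions; curve_config::projective_t stands in for any concrete instantiation.
void DemoRandomPoints()
{
  curve_config::projective_t points[128];
  curve_config::affine_t affine_points[128];
  curve_config::projective_t::RandHostMany(points, 128);              // random projective points
  curve_config::projective_t::RandHostManyAffine(affine_points, 128); // random affine points

  auto p_mont = curve_config::projective_t::ToMontgomery(points[0]);
  auto p_back = curve_config::projective_t::FromMontgomery(p_mont);
  // p_back equals points[0] because montgomery_r * montgomery_r_inv == 1 (mod p)
}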

View File

@@ -0,0 +1,28 @@
#include <iostream>
template <typename T>
void check(T err, const char* const func, const char* const file, const int line)
{
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
}
}
void checkLast(const char* const file, const int line)
{
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
}
}
void syncDevice(const char* const file, const int line)
{
cudaError_t err{cudaDeviceSynchronize()};
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
}
}

View File

@@ -1,32 +1,17 @@
#pragma once
#ifndef ERR_H
#define ERR_H
#include <iostream>
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
void check(T err, const char* const func, const char* const file, const int line)
{
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
}
}
void check(T err, const char* const func, const char* const file, const int line);
#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)
void checkLast(const char* const file, const int line)
{
cudaError_t err{cudaGetLastError()};
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
}
}
void checkLast(const char* const file, const int line);
#define CHECK_SYNC_DEVICE_ERROR() syncDevice(__FILE__, __LINE__)
void syncDevice(const char* const file, const int line)
{
cudaError_t err{cudaDeviceSynchronize()};
if (err != cudaSuccess) {
std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
std::cerr << cudaGetErrorString(err) << std::endl;
}
}
void syncDevice(const char* const file, const int line);
#endif
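A short usage sketch for the macros declared above, wrapping an allocation and a device synchronization; the header name in the include is an assumption.
#include <cuda_runtime.h>
// #include "error_handler.cuh" // assumed name of the header declaring the macros

void AllocateChecked(void** d_ptr, size_t bytes)
{
  CHECK_CUDA_ERROR(cudaMalloc(d_ptr, bytes)); // prints file, line and error string on failure
  CHECK_LAST_CUDA_ERROR();                    // picks up errors from earlier async launches
  CHECK_SYNC_DEVICE_ERROR();                  // cudaDeviceSynchronize plus error check
}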

View File

@@ -2,23 +2,26 @@
#ifndef MONT_H
#define MONT_H
#include "utils_kernels.cuh"
namespace mont {
namespace {
#define MAX_THREADS_PER_BLOCK 256
// TODO (DmytroTym): do valid conversion for point types too
template <typename E>
int convert_montgomery(E* d_inout, int n, bool is_into, cudaStream_t stream)
template <typename E, bool is_into>
__global__ void MontgomeryKernel(E* inout, int n)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { inout[tid] = is_into ? E::ToMontgomery(inout[tid]) : E::FromMontgomery(inout[tid]); }
}
template <typename E, bool is_into>
int ConvertMontgomery(E* d_inout, int n, cudaStream_t stream)
{
// Set the grid and block dimensions
int num_threads = MAX_THREADS_PER_BLOCK;
int num_blocks = (n + num_threads - 1) / num_threads;
E mont = is_into ? E::montgomery_r() : E::montgomery_r_inv();
utils_internal::template_normalize_kernel<<<num_blocks, num_threads, 0, stream>>>(d_inout, mont, n);
MontgomeryKernel<E, is_into><<<num_blocks, num_threads, 0, stream>>>(d_inout, n);
return 0; // TODO: void with proper error handling
}
@@ -26,15 +29,15 @@ namespace mont {
} // namespace
template <typename E>
int to_montgomery(E* d_inout, int n, cudaStream_t stream)
int ToMontgomery(E* d_inout, int n, cudaStream_t stream)
{
return convert_montgomery(d_inout, n, true, stream);
return ConvertMontgomery<E, true>(d_inout, n, stream);
}
template <typename E>
int from_montgomery(E* d_inout, int n, cudaStream_t stream)
int FromMontgomery(E* d_inout, int n, cudaStream_t stream)
{
return convert_montgomery(d_inout, n, false, stream);
return ConvertMontgomery<E, false>(d_inout, n, stream);
}
} // namespace mont
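With the conversion direction now a template parameter selected by the two wrappers above, a round trip on a device-resident array looks like this (a minimal sketch; the buffer is assumed to already live on the device).
#include <cuda_runtime.h>

template <typename E>
void MontgomeryRoundTripOnDevice(E* d_elements, int n, cudaStream_t stream)
{
  mont::ToMontgomery(d_elements, n, stream);   // launches MontgomeryKernel<E, true>
  mont::FromMontgomery(d_elements, n, stream); // launches MontgomeryKernel<E, false>
  cudaStreamSynchronize(stream);               // both launches are asynchronous
}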

View File

@@ -1,16 +1,23 @@
#include "utils_kernels.cuh"
namespace utils_internal {
// TODO: weird linking issue - only works in headers
// template <typename E, typename S>
// __global__ void NormalizeKernel(E* arr, S scalar, unsigned n)
// {
// int tid = blockIdx.x * blockDim.x + threadIdx.x;
// if (tid < n) { arr[tid] = scalar * arr[tid]; }
// }
template <typename E, typename S>
__global__ void template_normalize_kernel(E* arr, S scalar, uint32_t n)
__global__ void NormalizeKernel(E* arr, S scalar, int n)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { arr[tid] = scalar * arr[tid]; }
}
template <typename E, typename S>
__global__ void batchVectorMult(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size)
__global__ void BatchMulKernel(E* element_vec, S* scalar_vec, int n_scalars, int batch_size)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n_scalars * batch_size) {

View File

@@ -4,17 +4,22 @@
namespace utils_internal {
/**
* Multiply the elements of an input array by a scalar in-place. Used for normalization in iNTT.
* @param arr input array.
* @param n size of arr.
* @param scalar The scalar of type S by which every element of arr is multiplied (e.g. n_inv for iNTT normalization).
*/
template <typename E, typename S>
__global__ void template_normalize_kernel(E* arr, S scalar, uint32_t n);
__global__ void NormalizeKernel(E* arr, S scalar, unsigned n)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { arr[tid] = scalar * arr[tid]; }
}
template <typename E, typename S>
__global__ void batchVectorMult(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size);
__global__ void BatchMulKernel(E* element_vec, S* scalar_vec, unsigned n_scalars, unsigned batch_size)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n_scalars * batch_size) {
int scalar_id = tid % n_scalars;
element_vec[tid] = scalar_vec[scalar_id] * element_vec[tid];
}
}
} // namespace utils_internal
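A minimal launch sketch for NormalizeKernel used as an iNTT epilogue, scaling every device element by n_inv; it follows the 256-threads-per-block convention (MAX_THREADS_PER_BLOCK) used elsewhere in this diff, and the wrapper name is illustrative.
template <typename E, typename S>
void NormalizeForInverseNtt(E* d_arr, S n_inv, int n, cudaStream_t stream)
{
  const int num_threads = 256;
  const int num_blocks = (n + num_threads - 1) / num_threads;
  utils_internal::NormalizeKernel<E, S>
    <<<num_blocks, num_threads, 0, stream>>>(d_arr, n_inv, n); // arr[i] = n_inv * arr[i]
}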

View File

@@ -1,34 +1,32 @@
#include "lde.cuh"
#include <cuda.h>
#include <stdexcept>
#include "../../curves/curve_config.cuh"
#include "../../utils/device_context.cuh"
#include "../../utils/mont.cuh"
#include "../curves/curve_config.cuh"
#include "device_context.cuh"
#include "mont.cuh"
namespace lde {
namespace vec_ops {
namespace {
#define MAX_THREADS_PER_BLOCK 256
template <typename E, typename S>
__global__ void mul_kernel(S* scalar_vec, E* element_vec, int n, E* result)
__global__ void MulKernel(S* scalar_vec, E* element_vec, int n, E* result)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n) { result[tid] = scalar_vec[tid] * element_vec[tid]; }
}
template <typename E>
__global__ void add_kernel(E* element_vec1, E* element_vec2, int n, E* result)
__global__ void AddKernel(E* element_vec1, E* element_vec2, int n, E* result)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { result[tid] = element_vec1[tid] + element_vec2[tid]; }
}
template <typename E>
__global__ void sub_kernel(E* element_vec1, E* element_vec2, int n, E* result)
__global__ void SubKernel(E* element_vec1, E* element_vec2, int n, E* result)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { result[tid] = element_vec1[tid] - element_vec2[tid]; }
@@ -58,9 +56,9 @@ namespace lde {
}
// Call the kernel to perform element-wise modular multiplication
mul_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
MulKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);
if (is_montgomery) mont::from_montgomery(is_on_device ? result : d_result, n, ctx.stream);
if (is_montgomery) mont::FromMontgomery(is_on_device ? result : d_result, n, ctx.stream);
if (!is_on_device) {
cudaMemcpyAsync(result, d_result, n * sizeof(E), cudaMemcpyDeviceToHost, ctx.stream);
@@ -93,7 +91,7 @@ namespace lde {
}
// Call the kernel to perform element-wise addition
add_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
AddKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);
if (!is_on_device) {
@@ -127,7 +125,7 @@ namespace lde {
}
// Call the kernel to perform element-wise subtraction
sub_kernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
SubKernel<<<num_blocks, num_threads, 0, ctx.stream>>>(
is_on_device ? vec_a : d_vec_a, is_on_device ? vec_b : d_vec_b, n, is_on_device ? result : d_result);
if (!is_on_device) {
@@ -190,4 +188,4 @@ namespace lde {
return Sub<curve_config::scalar_t>(vec_a, vec_b, n, is_on_device, ctx, result);
}
} // namespace lde
} // namespace vec_ops

View File

@@ -2,15 +2,13 @@
#ifndef LDE_H
#define LDE_H
#include "../../utils/device_context.cuh"
#include "device_context.cuh"
/**
* @namespace lde
* LDE (stands for low degree extension) contains [NTT](@ref ntt)-based methods for translating between coefficient and
* evaluation domains of polynomials. It also contains methods for element-wise manipulation of vectors, which is useful
* for working with polynomials in evaluation domain.
* @namespace vec_ops
* This namespace contains methods for performing element-wise arithmetic operations on vectors.
*/
namespace lde {
namespace vec_ops {
/**
* A function that multiplies two vectors element-wise.
@@ -62,6 +60,6 @@ namespace lde {
template <typename E>
cudaError_t Sub(E* vec_a, E* vec_b, int n, bool is_on_device, device_context::DeviceContext ctx, E* result);
} // namespace lde
} // namespace vec_ops
#endif
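A hedged sketch of calling into the renamed vec_ops namespace with host-resident buffers, following the Sub declaration above; passing is_on_device = false makes the implementation stage the data on the GPU itself.
template <typename E>
cudaError_t SubtractHostBuffers(E* a, E* b, E* out, int n, device_context::DeviceContext ctx)
{
  return vec_ops::Sub<E>(a, b, n, /*is_on_device=*/false, ctx, out);
}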

View File

@@ -1,329 +0,0 @@
use std::ffi::c_uint;
use ark_CURVE_NAME_L::{Fq as Fq_CURVE_NAME_U, Fr as Fr_CURVE_NAME_U, G1Affine as G1Affine_CURVE_NAME_U, G1Projective as G1Projective_CURVE_NAME_U};
use ark_ec::AffineCurve;
use ark_ff::{BigInteger_limbs_q, BigInteger_limbs_p, PrimeField};
use std::mem::transmute;
use ark_ff::Field;
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_CURVE_NAME_U<const NUM_LIMBS: usize> {
pub s: [u32; NUM_LIMBS],
}
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_CURVE_NAME_U<NUM_LIMBS> {}
impl<const NUM_LIMBS: usize> Default for Field_CURVE_NAME_U<NUM_LIMBS> {
fn default() -> Self {
Field_CURVE_NAME_U::zero()
}
}
impl<const NUM_LIMBS: usize> Field_CURVE_NAME_U<NUM_LIMBS> {
pub fn zero() -> Self {
Field_CURVE_NAME_U {
s: [0u32; NUM_LIMBS],
}
}
pub fn one() -> Self {
let mut s = [0u32; NUM_LIMBS];
s[0] = 1;
Field_CURVE_NAME_U { s }
}
fn to_bytes_le(&self) -> Vec<u8> {
self.s
.iter()
.map(|s| s.to_le_bytes().to_vec())
.flatten()
.collect::<Vec<_>>()
}
}
pub const BASE_LIMBS_CURVE_NAME_U: usize = limbs_q;
pub const SCALAR_LIMBS_CURVE_NAME_U: usize = limbs_p;
pub type BaseField_CURVE_NAME_U = Field_CURVE_NAME_U<BASE_LIMBS_CURVE_NAME_U>;
pub type ScalarField_CURVE_NAME_U = Field_CURVE_NAME_U<SCALAR_LIMBS_CURVE_NAME_U>;
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val.try_into().unwrap(),
_ => panic!("slice has too many elements"),
}
}
//
impl BaseField_CURVE_NAME_U {
pub fn limbs(&self) -> [u32; BASE_LIMBS_CURVE_NAME_U] {
self.s
}
pub fn from_limbs(value: &[u32]) -> Self {
Self {
s: get_fixed_limbs(value),
}
}
pub fn to_ark(&self) -> BigInteger_limbs_q {
BigInteger_limbs_q::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
}
pub fn from_ark(ark: BigInteger_limbs_q) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
}
//
impl ScalarField_CURVE_NAME_U {
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_CURVE_NAME_U] {
self.s
}
pub fn to_ark(&self) -> BigInteger_limbs_p {
BigInteger_limbs_p::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
}
pub fn from_ark(ark: BigInteger_limbs_p) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
pub fn to_ark_transmute(&self) -> BigInteger_limbs_p {
unsafe { transmute(*self) }
}
pub fn from_ark_transmute(v: BigInteger_limbs_p) -> ScalarField_CURVE_NAME_U {
unsafe { transmute(v) }
}
}
#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_CURVE_NAME_U {
pub x: BaseField_CURVE_NAME_U,
pub y: BaseField_CURVE_NAME_U,
pub z: BaseField_CURVE_NAME_U,
}
impl Default for Point_CURVE_NAME_U {
fn default() -> Self {
Point_CURVE_NAME_U::zero()
}
}
impl Point_CURVE_NAME_U {
pub fn zero() -> Self {
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::zero(),
y: BaseField_CURVE_NAME_U::one(),
z: BaseField_CURVE_NAME_U::zero(),
}
}
pub fn infinity() -> Self {
Self::zero()
}
pub fn to_ark(&self) -> G1Projective_CURVE_NAME_U {
//TODO: generic conversion
self.to_ark_affine().into_projective()
}
pub fn to_ark_affine(&self) -> G1Affine_CURVE_NAME_U {
//TODO: generic conversion
use std::ops::Mul;
let proj_x_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.x.to_bytes_le());
let proj_y_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.y.to_bytes_le());
let proj_z_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.z.to_bytes_le());
let inverse_z = proj_z_field.inverse().unwrap();
let aff_x = proj_x_field.mul(inverse_z);
let aff_y = proj_y_field.mul(inverse_z);
G1Affine_CURVE_NAME_U::new(aff_x, aff_y, false)
}
pub fn from_ark(ark: G1Projective_CURVE_NAME_U) -> Point_CURVE_NAME_U {
let z_inv = ark.z.inverse().unwrap();
let z_invsq = z_inv * z_inv;
let z_invq3 = z_invsq * z_inv;
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark((ark.x * z_invsq).into_repr()),
y: BaseField_CURVE_NAME_U::from_ark((ark.y * z_invq3).into_repr()),
z: BaseField_CURVE_NAME_U::one(),
}
}
}
extern "C" {
fn eq_CURVE_NAME_L(point1: *const Point_CURVE_NAME_U, point2: *const Point_CURVE_NAME_U) -> c_uint;
}
impl PartialEq for Point_CURVE_NAME_U {
fn eq(&self, other: &Self) -> bool {
unsafe { eq_CURVE_NAME_L(self, other) != 0 }
}
}
#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_CURVE_NAME_U {
pub x: BaseField_CURVE_NAME_U,
pub y: BaseField_CURVE_NAME_U,
}
impl Default for PointAffineNoInfinity_CURVE_NAME_U {
fn default() -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::zero(),
y: BaseField_CURVE_NAME_U::zero(),
}
}
}
impl PointAffineNoInfinity_CURVE_NAME_U {
// TODO: generics
///From u32 limbs x,y
pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(x),
},
y: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(y),
},
}
}
pub fn limbs(&self) -> Vec<u32> {
[self.x.limbs(), self.y.limbs()].concat()
}
pub fn to_projective(&self) -> Point_CURVE_NAME_U {
Point_CURVE_NAME_U {
x: self.x,
y: self.y,
z: BaseField_CURVE_NAME_U::one(),
}
}
pub fn to_ark(&self) -> G1Affine_CURVE_NAME_U {
G1Affine_CURVE_NAME_U::new(Fq_CURVE_NAME_U::new(self.x.to_ark()), Fq_CURVE_NAME_U::new(self.y.to_ark()), false)
}
pub fn to_ark_repr(&self) -> G1Affine_CURVE_NAME_U {
G1Affine_CURVE_NAME_U::new(
Fq_CURVE_NAME_U::from_repr(self.x.to_ark()).unwrap(),
Fq_CURVE_NAME_U::from_repr(self.y.to_ark()).unwrap(),
false,
)
}
pub fn from_ark(p: &G1Affine_CURVE_NAME_U) -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark(p.x.into_repr()),
y: BaseField_CURVE_NAME_U::from_ark(p.y.into_repr()),
}
}
}
impl Point_CURVE_NAME_U {
// TODO: generics
pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(x),
},
y: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(y),
},
z: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(z),
},
}
}
pub fn from_xy_limbs(value: &[u32]) -> Point_CURVE_NAME_U {
let l = value.len();
assert_eq!(l, 3 * BASE_LIMBS_CURVE_NAME_U, "length must be 3 * {}", BASE_LIMBS_CURVE_NAME_U);
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: value[..BASE_LIMBS_CURVE_NAME_U].try_into().unwrap(),
},
y: BaseField_CURVE_NAME_U {
s: value[BASE_LIMBS_CURVE_NAME_U..BASE_LIMBS_CURVE_NAME_U * 2].try_into().unwrap(),
},
z: BaseField_CURVE_NAME_U {
s: value[BASE_LIMBS_CURVE_NAME_U * 2..].try_into().unwrap(),
},
}
}
pub fn to_affine(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
let ark_affine = self.to_ark_affine();
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark(ark_affine.x.into_repr()),
y: BaseField_CURVE_NAME_U::from_ark(ark_affine.y.into_repr()),
}
}
pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
PointAffineNoInfinity_CURVE_NAME_U {
x: self.x,
y: self.y,
}
}
}
impl ScalarField_CURVE_NAME_U {
pub fn from_limbs(value: &[u32]) -> ScalarField_CURVE_NAME_U {
ScalarField_CURVE_NAME_U {
s: get_fixed_limbs(value),
}
}
}
#[cfg(test)]
mod tests {
use ark_CURVE_NAME_L::{Fr as Fr_CURVE_NAME_U};
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}, curves::CURVE_NAME_L::{Point_CURVE_NAME_U, ScalarField_CURVE_NAME_U}};
#[test]
fn test_ark_scalar_convert() {
let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
let scalar = ScalarField_CURVE_NAME_U::from_limbs(&limbs);
assert_eq!(
scalar.to_ark(),
scalar.to_ark_transmute(),
"{:08X?} {:08X?}",
scalar.to_ark(),
scalar.to_ark_transmute()
)
}
#[test]
#[allow(non_snake_case)]
fn test_point_equality() {
let left = Point_CURVE_NAME_U::zero();
let right = Point_CURVE_NAME_U::zero();
assert_eq!(left, right);
let right = Point_CURVE_NAME_U::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
assert_eq!(left, right);
let right = Point_CURVE_NAME_U::from_limbs(
&[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
&[0; 12],
&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
);
assert!(left != right);
}
}

View File

@@ -1,309 +0,0 @@
use std::ffi::c_uint;
use ark_CURVE_NAME_L::{Fq as Fq_CURVE_NAME_U, Fr as Fr_CURVE_NAME_U, G1Affine as G1Affine_CURVE_NAME_U, G1Projective as G1Projective_CURVE_NAME_U};
use ark_ec::AffineCurve;
use ark_ff::{BigInteger_limbs_p, PrimeField};
use std::mem::transmute;
use ark_ff::Field;
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_CURVE_NAME_U<const NUM_LIMBS: usize> {
pub s: [u32; NUM_LIMBS],
}
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_CURVE_NAME_U<NUM_LIMBS> {}
impl<const NUM_LIMBS: usize> Default for Field_CURVE_NAME_U<NUM_LIMBS> {
fn default() -> Self {
Field_CURVE_NAME_U::zero()
}
}
impl<const NUM_LIMBS: usize> Field_CURVE_NAME_U<NUM_LIMBS> {
pub fn zero() -> Self {
Field_CURVE_NAME_U {
s: [0u32; NUM_LIMBS],
}
}
pub fn one() -> Self {
let mut s = [0u32; NUM_LIMBS];
s[0] = 1;
Field_CURVE_NAME_U { s }
}
fn to_bytes_le(&self) -> Vec<u8> {
self.s
.iter()
.map(|s| s.to_le_bytes().to_vec())
.flatten()
.collect::<Vec<_>>()
}
}
pub const BASE_LIMBS_CURVE_NAME_U: usize = limbs_p;
pub const SCALAR_LIMBS_CURVE_NAME_U: usize = limbs_p;
pub type BaseField_CURVE_NAME_U = Field_CURVE_NAME_U<BASE_LIMBS_CURVE_NAME_U>;
pub type ScalarField_CURVE_NAME_U = Field_CURVE_NAME_U<SCALAR_LIMBS_CURVE_NAME_U>;
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val.try_into().unwrap(),
_ => panic!("slice has too many elements"),
}
}
impl ScalarField_CURVE_NAME_U {
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_CURVE_NAME_U] {
self.s
}
pub fn to_ark(&self) -> BigInteger_limbs_p {
BigInteger_limbs_p::new(u32_vec_to_u64_vec(&self.limbs()).try_into().unwrap())
}
pub fn from_ark(ark: BigInteger_limbs_p) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
pub fn to_ark_transmute(&self) -> BigInteger_limbs_p {
unsafe { transmute(*self) }
}
pub fn from_ark_transmute(v: BigInteger_limbs_p) -> ScalarField_CURVE_NAME_U {
unsafe { transmute(v) }
}
}
#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_CURVE_NAME_U {
pub x: BaseField_CURVE_NAME_U,
pub y: BaseField_CURVE_NAME_U,
pub z: BaseField_CURVE_NAME_U,
}
impl Default for Point_CURVE_NAME_U {
fn default() -> Self {
Point_CURVE_NAME_U::zero()
}
}
impl Point_CURVE_NAME_U {
pub fn zero() -> Self {
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::zero(),
y: BaseField_CURVE_NAME_U::one(),
z: BaseField_CURVE_NAME_U::zero(),
}
}
pub fn infinity() -> Self {
Self::zero()
}
pub fn to_ark(&self) -> G1Projective_CURVE_NAME_U {
//TODO: generic conversion
self.to_ark_affine().into_projective()
}
pub fn to_ark_affine(&self) -> G1Affine_CURVE_NAME_U {
//TODO: generic conversion
use ark_ff::Field;
use std::ops::Mul;
let proj_x_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.x.to_bytes_le());
let proj_y_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.y.to_bytes_le());
let proj_z_field = Fq_CURVE_NAME_U::from_le_bytes_mod_order(&self.z.to_bytes_le());
let inverse_z = proj_z_field.inverse().unwrap();
let aff_x = proj_x_field.mul(inverse_z);
let aff_y = proj_y_field.mul(inverse_z);
G1Affine_CURVE_NAME_U::new(aff_x, aff_y, false)
}
pub fn from_ark(ark: G1Projective_CURVE_NAME_U) -> Point_CURVE_NAME_U {
use ark_ff::Field;
let z_inv = ark.z.inverse().unwrap();
let z_invsq = z_inv * z_inv;
let z_invq3 = z_invsq * z_inv;
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark((ark.x * z_invsq).into_repr()),
y: BaseField_CURVE_NAME_U::from_ark((ark.y * z_invq3).into_repr()),
z: BaseField_CURVE_NAME_U::one(),
}
}
}
extern "C" {
fn eq_CURVE_NAME_L(point1: *const Point_CURVE_NAME_U, point2: *const Point_CURVE_NAME_U) -> c_uint;
}
impl PartialEq for Point_CURVE_NAME_U {
fn eq(&self, other: &Self) -> bool {
unsafe { eq_CURVE_NAME_L(self, other) != 0 }
}
}
#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_CURVE_NAME_U {
pub x: BaseField_CURVE_NAME_U,
pub y: BaseField_CURVE_NAME_U,
}
impl Default for PointAffineNoInfinity_CURVE_NAME_U {
fn default() -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::zero(),
y: BaseField_CURVE_NAME_U::zero(),
}
}
}
impl PointAffineNoInfinity_CURVE_NAME_U {
// TODO: generics
///From u32 limbs x,y
pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(x),
},
y: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(y),
},
}
}
pub fn limbs(&self) -> Vec<u32> {
[self.x.limbs(), self.y.limbs()].concat()
}
pub fn to_projective(&self) -> Point_CURVE_NAME_U {
Point_CURVE_NAME_U {
x: self.x,
y: self.y,
z: BaseField_CURVE_NAME_U::one(),
}
}
pub fn to_ark(&self) -> G1Affine_CURVE_NAME_U {
G1Affine_CURVE_NAME_U::new(Fq_CURVE_NAME_U::new(self.x.to_ark()), Fq_CURVE_NAME_U::new(self.y.to_ark()), false)
}
pub fn to_ark_repr(&self) -> G1Affine_CURVE_NAME_U {
G1Affine_CURVE_NAME_U::new(
Fq_CURVE_NAME_U::from_repr(self.x.to_ark()).unwrap(),
Fq_CURVE_NAME_U::from_repr(self.y.to_ark()).unwrap(),
false,
)
}
pub fn from_ark(p: &G1Affine_CURVE_NAME_U) -> Self {
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark(p.x.into_repr()),
y: BaseField_CURVE_NAME_U::from_ark(p.y.into_repr()),
}
}
}
impl Point_CURVE_NAME_U {
// TODO: generics
pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(x),
},
y: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(y),
},
z: BaseField_CURVE_NAME_U {
s: get_fixed_limbs(z),
},
}
}
pub fn from_xy_limbs(value: &[u32]) -> Point_CURVE_NAME_U {
let l = value.len();
assert_eq!(l, 3 * BASE_LIMBS_CURVE_NAME_U, "length must be 3 * {}", BASE_LIMBS_CURVE_NAME_U);
Point_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U {
s: value[..BASE_LIMBS_CURVE_NAME_U].try_into().unwrap(),
},
y: BaseField_CURVE_NAME_U {
s: value[BASE_LIMBS_CURVE_NAME_U..BASE_LIMBS_CURVE_NAME_U * 2].try_into().unwrap(),
},
z: BaseField_CURVE_NAME_U {
s: value[BASE_LIMBS_CURVE_NAME_U * 2..].try_into().unwrap(),
},
}
}
pub fn to_affine(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
let ark_affine = self.to_ark_affine();
PointAffineNoInfinity_CURVE_NAME_U {
x: BaseField_CURVE_NAME_U::from_ark(ark_affine.x.into_repr()),
y: BaseField_CURVE_NAME_U::from_ark(ark_affine.y.into_repr()),
}
}
pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_CURVE_NAME_U {
PointAffineNoInfinity_CURVE_NAME_U {
x: self.x,
y: self.y,
}
}
}
impl ScalarField_CURVE_NAME_U {
pub fn from_limbs(value: &[u32]) -> ScalarField_CURVE_NAME_U {
ScalarField_CURVE_NAME_U {
s: get_fixed_limbs(value),
}
}
}
#[cfg(test)]
mod tests {
use ark_CURVE_NAME_L::{Fr as Fr_CURVE_NAME_U};
use crate::{utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec}, curves::CURVE_NAME_L::{Point_CURVE_NAME_U, ScalarField_CURVE_NAME_U}};
#[test]
fn test_ark_scalar_convert() {
let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
let scalar = ScalarField_CURVE_NAME_U::from_limbs(&limbs);
assert_eq!(
scalar.to_ark(),
scalar.to_ark_transmute(),
"{:08X?} {:08X?}",
scalar.to_ark(),
scalar.to_ark_transmute()
)
}
#[test]
#[allow(non_snake_case)]
fn test_point_equality() {
let left = Point_CURVE_NAME_U::zero();
let right = Point_CURVE_NAME_U::zero();
assert_eq!(left, right);
let right = Point_CURVE_NAME_U::from_limbs(&[0; 8], &[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8]);
assert_eq!(left, right);
let right = Point_CURVE_NAME_U::from_limbs(
&[2, 0, 0, 0, 0, 0, 0, 0],
&[0; 8],
&[1, 0, 0, 0, 0, 0, 0, 0],
);
assert!(left != right);
}
}

File diff suppressed because it is too large

View File

@@ -1,385 +0,0 @@
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
use ark_bls12_377::{Fq as Fq_BLS12_377, G1Affine as G1Affine_BLS12_377, G1Projective as G1Projective_BLS12_377};
use ark_ec::AffineCurve;
use ark_ff::Field;
use ark_ff::{BigInteger256, BigInteger384, PrimeField};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
use std::ffi::c_uint;
use std::mem::transmute;
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_BLS12_377<const NUM_LIMBS: usize> {
pub s: [u32; NUM_LIMBS],
}
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BLS12_377<NUM_LIMBS> {}
impl<const NUM_LIMBS: usize> Default for Field_BLS12_377<NUM_LIMBS> {
fn default() -> Self {
Field_BLS12_377::zero()
}
}
impl<const NUM_LIMBS: usize> Field_BLS12_377<NUM_LIMBS> {
pub fn zero() -> Self {
Field_BLS12_377 { s: [0u32; NUM_LIMBS] }
}
pub fn one() -> Self {
let mut s = [0u32; NUM_LIMBS];
s[0] = 1;
Field_BLS12_377 { s }
}
fn to_bytes_le(&self) -> Vec<u8> {
self.s
.iter()
.map(|s| {
s.to_le_bytes()
.to_vec()
})
.flatten()
.collect::<Vec<_>>()
}
}
pub const BASE_LIMBS_BLS12_377: usize = 12;
pub const SCALAR_LIMBS_BLS12_377: usize = 8;
#[allow(non_camel_case_types)]
pub type BaseField_BLS12_377 = Field_BLS12_377<BASE_LIMBS_BLS12_377>;
#[allow(non_camel_case_types)]
pub type ScalarField_BLS12_377 = Field_BLS12_377<SCALAR_LIMBS_BLS12_377>;
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val
.try_into()
.unwrap(),
_ => panic!("slice has too many elements"),
}
}
impl BaseField_BLS12_377 {
pub fn limbs(&self) -> [u32; BASE_LIMBS_BLS12_377] {
self.s
}
pub fn from_limbs(value: &[u32]) -> Self {
Self {
s: get_fixed_limbs(value),
}
}
pub fn to_ark(&self) -> BigInteger384 {
BigInteger384::new(
u32_vec_to_u64_vec(&self.limbs())
.try_into()
.unwrap(),
)
}
pub fn from_ark(ark: BigInteger384) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
}
impl ScalarField_BLS12_377 {
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BLS12_377] {
self.s
}
pub fn to_ark(&self) -> BigInteger256 {
BigInteger256::new(
u32_vec_to_u64_vec(&self.limbs())
.try_into()
.unwrap(),
)
}
pub fn from_ark(ark: BigInteger256) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
pub fn to_ark_transmute(&self) -> BigInteger256 {
unsafe { transmute(*self) }
}
pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BLS12_377 {
unsafe { transmute(v) }
}
}
#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_BLS12_377 {
pub x: BaseField_BLS12_377,
pub y: BaseField_BLS12_377,
pub z: BaseField_BLS12_377,
}
impl Default for Point_BLS12_377 {
fn default() -> Self {
Point_BLS12_377::zero()
}
}
impl Point_BLS12_377 {
pub fn zero() -> Self {
Point_BLS12_377 {
x: BaseField_BLS12_377::zero(),
y: BaseField_BLS12_377::one(),
z: BaseField_BLS12_377::zero(),
}
}
pub fn infinity() -> Self {
Self::zero()
}
pub fn to_ark(&self) -> G1Projective_BLS12_377 {
//TODO: generic conversion
self.to_ark_affine()
.into_projective()
}
pub fn to_ark_affine(&self) -> G1Affine_BLS12_377 {
//TODO: generic conversion
use std::ops::Mul;
let proj_x_field = Fq_BLS12_377::from_le_bytes_mod_order(
&self
.x
.to_bytes_le(),
);
let proj_y_field = Fq_BLS12_377::from_le_bytes_mod_order(
&self
.y
.to_bytes_le(),
);
let proj_z_field = Fq_BLS12_377::from_le_bytes_mod_order(
&self
.z
.to_bytes_le(),
);
let inverse_z = proj_z_field
.inverse()
.unwrap();
let aff_x = proj_x_field.mul(inverse_z);
let aff_y = proj_y_field.mul(inverse_z);
G1Affine_BLS12_377::new(aff_x, aff_y, false)
}
pub fn from_ark(ark: G1Projective_BLS12_377) -> Point_BLS12_377 {
let z_inv = ark
.z
.inverse()
.unwrap();
let z_invsq = z_inv * z_inv;
let z_invq3 = z_invsq * z_inv;
Point_BLS12_377 {
x: BaseField_BLS12_377::from_ark((ark.x * z_invsq).into_repr()),
y: BaseField_BLS12_377::from_ark((ark.y * z_invq3).into_repr()),
z: BaseField_BLS12_377::one(),
}
}
}
extern "C" {
fn eq_bls12_377(point1: *const Point_BLS12_377, point2: *const Point_BLS12_377) -> c_uint;
}
impl PartialEq for Point_BLS12_377 {
fn eq(&self, other: &Self) -> bool {
unsafe { eq_bls12_377(self, other) != 0 }
}
}
#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BLS12_377 {
pub x: BaseField_BLS12_377,
pub y: BaseField_BLS12_377,
}
impl Default for PointAffineNoInfinity_BLS12_377 {
fn default() -> Self {
PointAffineNoInfinity_BLS12_377 {
x: BaseField_BLS12_377::zero(),
y: BaseField_BLS12_377::zero(),
}
}
}
impl PointAffineNoInfinity_BLS12_377 {
// TODO: generics
///From u32 limbs x,y
pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
PointAffineNoInfinity_BLS12_377 {
x: BaseField_BLS12_377 { s: get_fixed_limbs(x) },
y: BaseField_BLS12_377 { s: get_fixed_limbs(y) },
}
}
pub fn limbs(&self) -> Vec<u32> {
[
self.x
.limbs(),
self.y
.limbs(),
]
.concat()
}
pub fn to_projective(&self) -> Point_BLS12_377 {
Point_BLS12_377 {
x: self.x,
y: self.y,
z: BaseField_BLS12_377::one(),
}
}
pub fn to_ark(&self) -> G1Affine_BLS12_377 {
G1Affine_BLS12_377::new(
Fq_BLS12_377::new(
self.x
.to_ark(),
),
Fq_BLS12_377::new(
self.y
.to_ark(),
),
false,
)
}
pub fn to_ark_repr(&self) -> G1Affine_BLS12_377 {
G1Affine_BLS12_377::new(
Fq_BLS12_377::from_repr(
self.x
.to_ark(),
)
.unwrap(),
Fq_BLS12_377::from_repr(
self.y
.to_ark(),
)
.unwrap(),
false,
)
}
pub fn from_ark(p: &G1Affine_BLS12_377) -> Self {
PointAffineNoInfinity_BLS12_377 {
x: BaseField_BLS12_377::from_ark(p.x.into_repr()),
y: BaseField_BLS12_377::from_ark(p.y.into_repr()),
}
}
}
impl Point_BLS12_377 {
// TODO: generics
pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Point_BLS12_377 {
x: BaseField_BLS12_377 { s: get_fixed_limbs(x) },
y: BaseField_BLS12_377 { s: get_fixed_limbs(y) },
z: BaseField_BLS12_377 { s: get_fixed_limbs(z) },
}
}
pub fn from_xy_limbs(value: &[u32]) -> Point_BLS12_377 {
let l = value.len();
assert_eq!(
l,
3 * BASE_LIMBS_BLS12_377,
"length must be 3 * {}",
BASE_LIMBS_BLS12_377
);
Point_BLS12_377 {
x: BaseField_BLS12_377 {
s: value[..BASE_LIMBS_BLS12_377]
.try_into()
.unwrap(),
},
y: BaseField_BLS12_377 {
s: value[BASE_LIMBS_BLS12_377..BASE_LIMBS_BLS12_377 * 2]
.try_into()
.unwrap(),
},
z: BaseField_BLS12_377 {
s: value[BASE_LIMBS_BLS12_377 * 2..]
.try_into()
.unwrap(),
},
}
}
pub fn to_affine(&self) -> PointAffineNoInfinity_BLS12_377 {
let ark_affine = self.to_ark_affine();
PointAffineNoInfinity_BLS12_377 {
x: BaseField_BLS12_377::from_ark(
ark_affine
.x
.into_repr(),
),
y: BaseField_BLS12_377::from_ark(
ark_affine
.y
.into_repr(),
),
}
}
pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BLS12_377 {
PointAffineNoInfinity_BLS12_377 { x: self.x, y: self.y }
}
}
impl ScalarField_BLS12_377 {
pub fn from_limbs(value: &[u32]) -> ScalarField_BLS12_377 {
ScalarField_BLS12_377 {
s: get_fixed_limbs(value),
}
}
}
#[cfg(test)]
mod tests {
use crate::curves::bls12_377::{Point_BLS12_377, ScalarField_BLS12_377};
#[test]
fn test_ark_scalar_convert() {
let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
let scalar = ScalarField_BLS12_377::from_limbs(&limbs);
assert_eq!(
scalar.to_ark(),
scalar.to_ark_transmute(),
"{:08X?} {:08X?}",
scalar.to_ark(),
scalar.to_ark_transmute()
)
}
#[test]
#[allow(non_snake_case)]
fn test_point_equality() {
let left = Point_BLS12_377::zero();
let right = Point_BLS12_377::zero();
assert_eq!(left, right);
let right = Point_BLS12_377::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
assert_eq!(left, right);
let right = Point_BLS12_377::from_limbs(
&[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
&[0; 12],
&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
);
assert!(left != right);
}
}

View File

@@ -1,407 +0,0 @@
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
use ark_bls12_381::{Fq as Fq_BLS12_381, G1Affine as G1Affine_BLS12_381, G1Projective as G1Projective_BLS12_381};
use ark_ec::AffineCurve;
use ark_ff::Field;
use ark_ff::{BigInteger256, BigInteger384, PrimeField};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
use serde::{Deserialize, Serialize};
use std::ffi::c_uint;
use std::mem::transmute;
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_BLS12_381<const NUM_LIMBS: usize> {
pub s: [u32; NUM_LIMBS],
}
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BLS12_381<NUM_LIMBS> {}
impl<const NUM_LIMBS: usize> Default for Field_BLS12_381<NUM_LIMBS> {
fn default() -> Self {
Field_BLS12_381::zero()
}
}
impl<const NUM_LIMBS: usize> Field_BLS12_381<NUM_LIMBS> {
pub fn zero() -> Self {
Field_BLS12_381 { s: [0u32; NUM_LIMBS] }
}
pub fn one() -> Self {
let mut s = [0u32; NUM_LIMBS];
s[0] = 1;
Field_BLS12_381 { s }
}
fn to_bytes_le(&self) -> Vec<u8> {
self.s
.iter()
.map(|s| {
s.to_le_bytes()
.to_vec()
})
.flatten()
.collect::<Vec<_>>()
}
}
pub const BASE_LIMBS_BLS12_381: usize = 12;
pub const SCALAR_LIMBS_BLS12_381: usize = 8;
#[allow(non_camel_case_types)]
pub type BaseField_BLS12_381 = Field_BLS12_381<BASE_LIMBS_BLS12_381>;
#[allow(non_camel_case_types)]
pub type ScalarField_BLS12_381 = Field_BLS12_381<SCALAR_LIMBS_BLS12_381>;
impl Serialize for ScalarField_BLS12_381 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.s
.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for ScalarField_BLS12_381 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = <[u32; SCALAR_LIMBS_BLS12_381] as Deserialize<'de>>::deserialize(deserializer)?;
Ok(ScalarField_BLS12_381 { s })
}
}
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val
.try_into()
.unwrap(),
_ => panic!("slice has too many elements"),
}
}
impl BaseField_BLS12_381 {
pub fn limbs(&self) -> [u32; BASE_LIMBS_BLS12_381] {
self.s
}
pub fn from_limbs(value: &[u32]) -> Self {
Self {
s: get_fixed_limbs(value),
}
}
pub fn to_ark(&self) -> BigInteger384 {
BigInteger384::new(
u32_vec_to_u64_vec(&self.limbs())
.try_into()
.unwrap(),
)
}
pub fn from_ark(ark: BigInteger384) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
}
impl ScalarField_BLS12_381 {
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BLS12_381] {
self.s
}
pub fn to_ark(&self) -> BigInteger256 {
BigInteger256::new(
u32_vec_to_u64_vec(&self.limbs())
.try_into()
.unwrap(),
)
}
pub fn from_ark(ark: BigInteger256) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
pub fn to_ark_transmute(&self) -> BigInteger256 {
unsafe { transmute(*self) }
}
pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BLS12_381 {
unsafe { transmute(v) }
}
}
#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_BLS12_381 {
pub x: BaseField_BLS12_381,
pub y: BaseField_BLS12_381,
pub z: BaseField_BLS12_381,
}
impl Default for Point_BLS12_381 {
fn default() -> Self {
Point_BLS12_381::zero()
}
}
impl Point_BLS12_381 {
pub fn zero() -> Self {
Point_BLS12_381 {
x: BaseField_BLS12_381::zero(),
y: BaseField_BLS12_381::one(),
z: BaseField_BLS12_381::zero(),
}
}
pub fn infinity() -> Self {
Self::zero()
}
pub fn to_ark(&self) -> G1Projective_BLS12_381 {
//TODO: generic conversion
self.to_ark_affine()
.into_projective()
}
pub fn to_ark_affine(&self) -> G1Affine_BLS12_381 {
//TODO: generic conversion
use std::ops::Mul;
let proj_x_field = Fq_BLS12_381::from_le_bytes_mod_order(
&self
.x
.to_bytes_le(),
);
let proj_y_field = Fq_BLS12_381::from_le_bytes_mod_order(
&self
.y
.to_bytes_le(),
);
let proj_z_field = Fq_BLS12_381::from_le_bytes_mod_order(
&self
.z
.to_bytes_le(),
);
let inverse_z = proj_z_field
.inverse()
.unwrap();
let aff_x = proj_x_field.mul(inverse_z);
let aff_y = proj_y_field.mul(inverse_z);
G1Affine_BLS12_381::new(aff_x, aff_y, false)
}
pub fn from_ark(ark: G1Projective_BLS12_381) -> Point_BLS12_381 {
let z_inv = ark
.z
.inverse()
.unwrap();
let z_invsq = z_inv * z_inv;
let z_invq3 = z_invsq * z_inv;
Point_BLS12_381 {
x: BaseField_BLS12_381::from_ark((ark.x * z_invsq).into_repr()),
y: BaseField_BLS12_381::from_ark((ark.y * z_invq3).into_repr()),
z: BaseField_BLS12_381::one(),
}
}
}
extern "C" {
fn eq_bls12_381(point1: *const Point_BLS12_381, point2: *const Point_BLS12_381) -> c_uint;
}
impl PartialEq for Point_BLS12_381 {
fn eq(&self, other: &Self) -> bool {
unsafe { eq_bls12_381(self, other) != 0 }
}
}
#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BLS12_381 {
pub x: BaseField_BLS12_381,
pub y: BaseField_BLS12_381,
}
impl Default for PointAffineNoInfinity_BLS12_381 {
fn default() -> Self {
PointAffineNoInfinity_BLS12_381 {
x: BaseField_BLS12_381::zero(),
y: BaseField_BLS12_381::zero(),
}
}
}
impl PointAffineNoInfinity_BLS12_381 {
// TODO: generics
///From u32 limbs x,y
pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
PointAffineNoInfinity_BLS12_381 {
x: BaseField_BLS12_381 { s: get_fixed_limbs(x) },
y: BaseField_BLS12_381 { s: get_fixed_limbs(y) },
}
}
pub fn limbs(&self) -> Vec<u32> {
[
self.x
.limbs(),
self.y
.limbs(),
]
.concat()
}
pub fn to_projective(&self) -> Point_BLS12_381 {
Point_BLS12_381 {
x: self.x,
y: self.y,
z: BaseField_BLS12_381::one(),
}
}
pub fn to_ark(&self) -> G1Affine_BLS12_381 {
G1Affine_BLS12_381::new(
Fq_BLS12_381::new(
self.x
.to_ark(),
),
Fq_BLS12_381::new(
self.y
.to_ark(),
),
false,
)
}
pub fn to_ark_repr(&self) -> G1Affine_BLS12_381 {
G1Affine_BLS12_381::new(
Fq_BLS12_381::from_repr(
self.x
.to_ark(),
)
.unwrap(),
Fq_BLS12_381::from_repr(
self.y
.to_ark(),
)
.unwrap(),
false,
)
}
pub fn from_ark(p: &G1Affine_BLS12_381) -> Self {
PointAffineNoInfinity_BLS12_381 {
x: BaseField_BLS12_381::from_ark(p.x.into_repr()),
y: BaseField_BLS12_381::from_ark(p.y.into_repr()),
}
}
}
impl Point_BLS12_381 {
// TODO: generics
pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Point_BLS12_381 {
x: BaseField_BLS12_381 { s: get_fixed_limbs(x) },
y: BaseField_BLS12_381 { s: get_fixed_limbs(y) },
z: BaseField_BLS12_381 { s: get_fixed_limbs(z) },
}
}
pub fn from_xy_limbs(value: &[u32]) -> Point_BLS12_381 {
let l = value.len();
assert_eq!(
l,
3 * BASE_LIMBS_BLS12_381,
"length must be 3 * {}",
BASE_LIMBS_BLS12_381
);
Point_BLS12_381 {
x: BaseField_BLS12_381 {
s: value[..BASE_LIMBS_BLS12_381]
.try_into()
.unwrap(),
},
y: BaseField_BLS12_381 {
s: value[BASE_LIMBS_BLS12_381..BASE_LIMBS_BLS12_381 * 2]
.try_into()
.unwrap(),
},
z: BaseField_BLS12_381 {
s: value[BASE_LIMBS_BLS12_381 * 2..]
.try_into()
.unwrap(),
},
}
}
pub fn to_affine(&self) -> PointAffineNoInfinity_BLS12_381 {
let ark_affine = self.to_ark_affine();
PointAffineNoInfinity_BLS12_381 {
x: BaseField_BLS12_381::from_ark(
ark_affine
.x
.into_repr(),
),
y: BaseField_BLS12_381::from_ark(
ark_affine
.y
.into_repr(),
),
}
}
pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BLS12_381 {
PointAffineNoInfinity_BLS12_381 { x: self.x, y: self.y }
}
}
impl ScalarField_BLS12_381 {
pub fn from_limbs(value: &[u32]) -> ScalarField_BLS12_381 {
ScalarField_BLS12_381 {
s: get_fixed_limbs(value),
}
}
}
#[cfg(test)]
mod tests {
use crate::curves::bls12_381::{Point_BLS12_381, ScalarField_BLS12_381};
#[test]
fn test_ark_scalar_convert() {
let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
let scalar = ScalarField_BLS12_381::from_limbs(&limbs);
assert_eq!(
scalar.to_ark(),
scalar.to_ark_transmute(),
"{:08X?} {:08X?}",
scalar.to_ark(),
scalar.to_ark_transmute()
)
}
#[test]
#[allow(non_snake_case)]
fn test_point_equality() {
let left = Point_BLS12_381::zero();
let right = Point_BLS12_381::zero();
assert_eq!(left, right);
let right = Point_BLS12_381::from_limbs(&[0; 12], &[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], &[0; 12]);
assert_eq!(left, right);
let right = Point_BLS12_381::from_limbs(
&[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
&[0; 12],
&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
);
assert!(left != right);
}
}

View File

@@ -1,353 +0,0 @@
use crate::utils::{u32_vec_to_u64_vec, u64_vec_to_u32_vec};
use ark_bn254::{Fq as Fq_BN254, G1Affine as G1Affine_BN254, G1Projective as G1Projective_BN254};
use ark_ec::AffineCurve;
use ark_ff::Field;
use ark_ff::{BigInteger256, PrimeField};
use rustacuda_core::DeviceCopy;
use rustacuda_derive::DeviceCopy;
use std::ffi::c_uint;
use std::mem::transmute;
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field_BN254<const NUM_LIMBS: usize> {
pub s: [u32; NUM_LIMBS],
}
unsafe impl<const NUM_LIMBS: usize> DeviceCopy for Field_BN254<NUM_LIMBS> {}
impl<const NUM_LIMBS: usize> Default for Field_BN254<NUM_LIMBS> {
fn default() -> Self {
Field_BN254::zero()
}
}
impl<const NUM_LIMBS: usize> Field_BN254<NUM_LIMBS> {
pub fn zero() -> Self {
Field_BN254 { s: [0u32; NUM_LIMBS] }
}
pub fn one() -> Self {
let mut s = [0u32; NUM_LIMBS];
s[0] = 1;
Field_BN254 { s }
}
fn to_bytes_le(&self) -> Vec<u8> {
self.s
.iter()
.map(|s| {
s.to_le_bytes()
.to_vec()
})
.flatten()
.collect::<Vec<_>>()
}
}
pub const BASE_LIMBS_BN254: usize = 8;
pub const SCALAR_LIMBS_BN254: usize = 8;
#[allow(non_camel_case_types)]
pub type BaseField_BN254 = Field_BN254<BASE_LIMBS_BN254>;
#[allow(non_camel_case_types)]
pub type ScalarField_BN254 = Field_BN254<SCALAR_LIMBS_BN254>;
fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val
.try_into()
.unwrap(),
_ => panic!("slice has too many elements"),
}
}
impl ScalarField_BN254 {
pub fn limbs(&self) -> [u32; SCALAR_LIMBS_BN254] {
self.s
}
pub fn to_ark(&self) -> BigInteger256 {
BigInteger256::new(
u32_vec_to_u64_vec(&self.limbs())
.try_into()
.unwrap(),
)
}
pub fn from_ark(ark: BigInteger256) -> Self {
Self::from_limbs(&u64_vec_to_u32_vec(&ark.0))
}
pub fn to_ark_transmute(&self) -> BigInteger256 {
unsafe { transmute(*self) }
}
pub fn from_ark_transmute(v: BigInteger256) -> ScalarField_BN254 {
unsafe { transmute(v) }
}
}
#[derive(Debug, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct Point_BN254 {
pub x: BaseField_BN254,
pub y: BaseField_BN254,
pub z: BaseField_BN254,
}
impl Default for Point_BN254 {
fn default() -> Self {
Point_BN254::zero()
}
}
impl Point_BN254 {
pub fn zero() -> Self {
Point_BN254 {
x: BaseField_BN254::zero(),
y: BaseField_BN254::one(),
z: BaseField_BN254::zero(),
}
}
pub fn infinity() -> Self {
Self::zero()
}
pub fn to_ark(&self) -> G1Projective_BN254 {
//TODO: generic conversion
self.to_ark_affine()
.into_projective()
}
pub fn to_ark_affine(&self) -> G1Affine_BN254 {
//TODO: generic conversion
use std::ops::Mul;
let proj_x_field = Fq_BN254::from_le_bytes_mod_order(
&self
.x
.to_bytes_le(),
);
let proj_y_field = Fq_BN254::from_le_bytes_mod_order(
&self
.y
.to_bytes_le(),
);
let proj_z_field = Fq_BN254::from_le_bytes_mod_order(
&self
.z
.to_bytes_le(),
);
let inverse_z = proj_z_field
.inverse()
.unwrap();
let aff_x = proj_x_field.mul(inverse_z);
let aff_y = proj_y_field.mul(inverse_z);
G1Affine_BN254::new(aff_x, aff_y, false)
}
pub fn from_ark(ark: G1Projective_BN254) -> Point_BN254 {
let z_inv = ark
.z
.inverse()
.unwrap();
let z_invsq = z_inv * z_inv;
let z_invq3 = z_invsq * z_inv;
Point_BN254 {
x: BaseField_BN254::from_ark((ark.x * z_invsq).into_repr()),
y: BaseField_BN254::from_ark((ark.y * z_invq3).into_repr()),
z: BaseField_BN254::one(),
}
}
}
extern "C" {
fn eq_bn254(point1: *const Point_BN254, point2: *const Point_BN254) -> c_uint;
}
impl PartialEq for Point_BN254 {
fn eq(&self, other: &Self) -> bool {
unsafe { eq_bn254(self, other) != 0 }
}
}
#[derive(Debug, PartialEq, Clone, Copy, DeviceCopy)]
#[repr(C)]
pub struct PointAffineNoInfinity_BN254 {
pub x: BaseField_BN254,
pub y: BaseField_BN254,
}
impl Default for PointAffineNoInfinity_BN254 {
fn default() -> Self {
PointAffineNoInfinity_BN254 {
x: BaseField_BN254::zero(),
y: BaseField_BN254::zero(),
}
}
}
impl PointAffineNoInfinity_BN254 {
// TODO: generics
///From u32 limbs x,y
pub fn from_limbs(x: &[u32], y: &[u32]) -> Self {
PointAffineNoInfinity_BN254 {
x: BaseField_BN254 { s: get_fixed_limbs(x) },
y: BaseField_BN254 { s: get_fixed_limbs(y) },
}
}
pub fn limbs(&self) -> Vec<u32> {
[
self.x
.limbs(),
self.y
.limbs(),
]
.concat()
}
pub fn to_projective(&self) -> Point_BN254 {
Point_BN254 {
x: self.x,
y: self.y,
z: BaseField_BN254::one(),
}
}
pub fn to_ark(&self) -> G1Affine_BN254 {
G1Affine_BN254::new(
Fq_BN254::new(
self.x
.to_ark(),
),
Fq_BN254::new(
self.y
.to_ark(),
),
false,
)
}
pub fn to_ark_repr(&self) -> G1Affine_BN254 {
G1Affine_BN254::new(
Fq_BN254::from_repr(
self.x
.to_ark(),
)
.unwrap(),
Fq_BN254::from_repr(
self.y
.to_ark(),
)
.unwrap(),
false,
)
}
pub fn from_ark(p: &G1Affine_BN254) -> Self {
PointAffineNoInfinity_BN254 {
x: BaseField_BN254::from_ark(p.x.into_repr()),
y: BaseField_BN254::from_ark(p.y.into_repr()),
}
}
}
impl Point_BN254 {
// TODO: generics
pub fn from_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Point_BN254 {
x: BaseField_BN254 { s: get_fixed_limbs(x) },
y: BaseField_BN254 { s: get_fixed_limbs(y) },
z: BaseField_BN254 { s: get_fixed_limbs(z) },
}
}
pub fn from_xy_limbs(value: &[u32]) -> Point_BN254 {
let l = value.len();
assert_eq!(l, 3 * BASE_LIMBS_BN254, "length must be 3 * {}", BASE_LIMBS_BN254);
Point_BN254 {
x: BaseField_BN254 {
s: value[..BASE_LIMBS_BN254]
.try_into()
.unwrap(),
},
y: BaseField_BN254 {
s: value[BASE_LIMBS_BN254..BASE_LIMBS_BN254 * 2]
.try_into()
.unwrap(),
},
z: BaseField_BN254 {
s: value[BASE_LIMBS_BN254 * 2..]
.try_into()
.unwrap(),
},
}
}
pub fn to_affine(&self) -> PointAffineNoInfinity_BN254 {
let ark_affine = self.to_ark_affine();
PointAffineNoInfinity_BN254 {
x: BaseField_BN254::from_ark(
ark_affine
.x
.into_repr(),
),
y: BaseField_BN254::from_ark(
ark_affine
.y
.into_repr(),
),
}
}
pub fn to_xy_strip_z(&self) -> PointAffineNoInfinity_BN254 {
PointAffineNoInfinity_BN254 { x: self.x, y: self.y }
}
}
impl ScalarField_BN254 {
pub fn from_limbs(value: &[u32]) -> ScalarField_BN254 {
ScalarField_BN254 {
s: get_fixed_limbs(value),
}
}
}
#[cfg(test)]
mod tests {
use crate::curves::bn254::{Point_BN254, ScalarField_BN254};
#[test]
fn test_ark_scalar_convert() {
let limbs = [0x0fffffff, 1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7];
let scalar = ScalarField_BN254::from_limbs(&limbs);
assert_eq!(
scalar.to_ark(),
scalar.to_ark_transmute(),
"{:08X?} {:08X?}",
scalar.to_ark(),
scalar.to_ark_transmute()
)
}
#[test]
#[allow(non_snake_case)]
fn test_point_equality() {
let left = Point_BN254::zero();
let right = Point_BN254::zero();
assert_eq!(left, right);
let right = Point_BN254::from_limbs(&[0; 8], &[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8]);
assert_eq!(left, right);
let right = Point_BN254::from_limbs(&[2, 0, 0, 0, 0, 0, 0, 0], &[0; 8], &[1, 0, 0, 0, 0, 0, 0, 0]);
assert!(left != right);
}
}

View File

@@ -1,3 +0,0 @@
pub mod bls12_377;
pub mod bls12_381;
pub mod bn254;

View File

@@ -1,5 +0,0 @@
pub mod curves;
pub mod test_bls12_377;
pub mod test_bls12_381;
pub mod test_bn254;
pub mod utils;

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,99 +0,0 @@
use rand::rngs::StdRng;
use rand::RngCore;
use rand::SeedableRng;
pub fn from_limbs<T>(limbs: Vec<u32>, chunk_size: usize, f: fn(&[u32]) -> T) -> Vec<T> {
let points = limbs
.chunks(chunk_size)
.map(|lmbs| f(lmbs))
.collect::<Vec<T>>();
points
}
pub fn u32_vec_to_u64_vec(arr_u32: &[u32]) -> Vec<u64> {
let len = (arr_u32.len() / 2) as usize;
let mut arr_u64 = vec![0u64; len];
for i in 0..len {
arr_u64[i] = u64::from(arr_u32[i * 2]) | (u64::from(arr_u32[i * 2 + 1]) << 32);
}
arr_u64
}
pub fn u64_vec_to_u32_vec(arr_u64: &[u64]) -> Vec<u32> {
let len = arr_u64.len() * 2;
let mut arr_u32 = vec![0u32; len];
for i in 0..arr_u64.len() {
arr_u32[i * 2] = arr_u64[i] as u32;
arr_u32[i * 2 + 1] = (arr_u64[i] >> 32) as u32;
}
arr_u32
}
pub fn get_rng(seed: Option<u64>) -> Box<dyn RngCore> {
//TODO: this func is universal
let rng: Box<dyn RngCore> = match seed {
Some(seed) => Box::new(StdRng::seed_from_u64(seed)),
None => Box::new(rand::thread_rng()),
};
rng
}
#[cfg(test)]
mod tests {
use ark_ff::BigInteger256;
use crate::curves::bls12_381::ScalarField_BLS12_381 as ScalarField;
use super::*;
#[test]
fn test_u32_vec_to_u64_vec() {
let arr_u32 = [1, 0x0fffffff, 3, 0x2fffffff, 5, 0x4fffffff, 7, 0x6fffffff];
let s = ScalarField::from_ark_transmute(BigInteger256::new(
u32_vec_to_u64_vec(&arr_u32)
.try_into()
.unwrap(),
))
.limbs();
assert_eq!(arr_u32.to_vec(), s);
let arr_u64_expected = [
0x0FFFFFFF00000001,
0x2FFFFFFF00000003,
0x4FFFFFFF00000005,
0x6FFFFFFF00000007,
];
assert_eq!(
u32_vec_to_u64_vec(&arr_u32),
arr_u64_expected,
"{:016X?}",
u32_vec_to_u64_vec(&arr_u32)
);
}
#[test]
fn test_u64_vec_to_u32_vec() {
let arr_u64 = [
0x2FFFFFFF00000001,
0x4FFFFFFF00000003,
0x6FFFFFFF00000005,
0x8FFFFFFF00000007,
];
let arr_u32_expected = [1, 0x2fffffff, 3, 0x4fffffff, 5, 0x6fffffff, 7, 0x8fffffff];
assert_eq!(
u64_vec_to_u32_vec(&arr_u64),
arr_u32_expected,
"{:016X?}",
u64_vec_to_u32_vec(&arr_u64)
);
}
}

wrappers/rust/Cargo.toml (new file)
View File

@@ -0,0 +1,3 @@
[workspace]
resolver = "2"
members = ["icicle-cuda-runtime", "icicle-core", "icicle-curves/icicle-bn254"]

View File

@@ -0,0 +1,36 @@
[package]
name = "icicle-core"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "A library for GPU ZK acceleration by Ingonyama"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"
[dependencies]
icicle-cuda-runtime = { path = "../icicle-cuda-runtime" }
ark-ff = { version = "0.4.0", optional = true }
ark-ec = { version = "0.4.0", optional = true, features = [ "parallel" ] }
# [build-dependencies]
# cc = { version = "1.0", features = ["parallel"] }
# cmake = "*"
# bindgen = "*"
# libc = "*" #TODO: move libc dependencies to build
# [dev-dependencies]
# rustacuda = "0.1"
# rustacuda_core = "0.1"
# rustacuda_derive = "0.1"
# "criterion" = "*"
[features]
default = []
arkworks = ["ark-ff", "ark-ec"]
# TODO: impl G2 and EC NTT
g2 = []
ec_ntt = []

View File

@@ -0,0 +1,193 @@
use crate::field::{Field, FieldConfig};
#[cfg(feature = "arkworks")]
use crate::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_ec::models::CurveConfig as ArkCurveConfig;
#[cfg(feature = "arkworks")]
use ark_ec::short_weierstrass::{Affine as ArkAffine, Projective as ArkProjective, SWCurveConfig};
use std::ffi::{c_uint, c_void};
use std::marker::PhantomData;
pub trait CurveConfig: PartialEq + Copy + Clone {
fn eq_proj(point1: *const c_void, point2: *const c_void) -> c_uint;
fn to_affine(point: *const c_void, point_aff: *mut c_void);
#[cfg(feature = "arkworks")]
type ArkSWConfig: SWCurveConfig;
}
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct Projective<T, C: CurveConfig> {
pub x: T,
pub y: T,
pub z: T,
p: PhantomData<C>,
}
#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(C)]
pub struct Affine<T, C: CurveConfig> {
pub x: T,
pub y: T,
p: PhantomData<C>,
}
impl<const NUM_LIMBS: usize, F, C> Affine<Field<NUM_LIMBS, F>, C>
where
F: FieldConfig,
C: CurveConfig,
{
// While this is not a true zero point and not even a valid point, it's still useful
// both as a handy default as well as a representation of zero points in other codebases
pub fn zero() -> Self {
Affine {
x: Field::<NUM_LIMBS, F>::zero(),
y: Field::<NUM_LIMBS, F>::zero(),
p: PhantomData,
}
}
pub fn set_limbs(x: &[u32], y: &[u32]) -> Self {
Affine {
x: Field::<NUM_LIMBS, F>::set_limbs(x),
y: Field::<NUM_LIMBS, F>::set_limbs(y),
p: PhantomData,
}
}
pub fn to_projective(&self) -> Projective<Field<NUM_LIMBS, F>, C> {
Projective {
x: self.x,
y: self.y,
z: Field::<NUM_LIMBS, F>::one(),
p: PhantomData,
}
}
}
impl<const NUM_LIMBS: usize, F, C> From<Affine<Field<NUM_LIMBS, F>, C>> for Projective<Field<NUM_LIMBS, F>, C>
where
F: FieldConfig,
C: CurveConfig,
{
fn from(item: Affine<Field<NUM_LIMBS, F>, C>) -> Self {
Self {
x: item.x,
y: item.y,
z: Field::<NUM_LIMBS, F>::one(),
p: PhantomData,
}
}
}
impl<const NUM_LIMBS: usize, F, C> Projective<Field<NUM_LIMBS, F>, C>
where
F: FieldConfig,
C: CurveConfig,
{
pub fn zero() -> Self {
Projective {
x: Field::<NUM_LIMBS, F>::zero(),
y: Field::<NUM_LIMBS, F>::one(),
z: Field::<NUM_LIMBS, F>::zero(),
p: PhantomData,
}
}
pub fn set_limbs(x: &[u32], y: &[u32], z: &[u32]) -> Self {
Projective {
x: Field::<NUM_LIMBS, F>::set_limbs(x),
y: Field::<NUM_LIMBS, F>::set_limbs(y),
z: Field::<NUM_LIMBS, F>::set_limbs(z),
p: PhantomData,
}
}
}
impl<const NUM_LIMBS: usize, F, C> PartialEq for Projective<Field<NUM_LIMBS, F>, C>
where
F: FieldConfig,
C: CurveConfig,
{
fn eq(&self, other: &Self) -> bool {
C::eq_proj(self as *const _ as *const c_void, other as *const _ as *const c_void) != 0
}
}
impl<const NUM_LIMBS: usize, F, C> From<Projective<Field<NUM_LIMBS, F>, C>> for Affine<Field<NUM_LIMBS, F>, C>
where
F: FieldConfig,
C: CurveConfig,
{
fn from(item: Projective<Field<NUM_LIMBS, F>, C>) -> Self {
let mut aff = Self::zero();
C::to_affine(&item as *const _ as *const c_void, &mut aff as *mut _ as *mut c_void);
aff
}
}
#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F, C> ArkConvertible for Affine<Field<NUM_LIMBS, F>, C>
where
C: CurveConfig,
F: FieldConfig<ArkField = <<C as CurveConfig>::ArkSWConfig as ArkCurveConfig>::BaseField>,
{
type ArkEquivalent = ArkAffine<C::ArkSWConfig>;
fn to_ark(&self) -> Self::ArkEquivalent {
let proj_x = self
.x
.to_ark();
let proj_y = self
.y
.to_ark();
Self::ArkEquivalent::new_unchecked(proj_x, proj_y)
}
fn from_ark(ark: Self::ArkEquivalent) -> Self {
Self {
x: Field::<NUM_LIMBS, F>::from_ark(ark.x),
y: Field::<NUM_LIMBS, F>::from_ark(ark.y),
p: PhantomData,
}
}
}
#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F, C> ArkConvertible for Projective<Field<NUM_LIMBS, F>, C>
where
C: CurveConfig,
F: FieldConfig<ArkField = <<C as CurveConfig>::ArkSWConfig as ArkCurveConfig>::BaseField>,
{
type ArkEquivalent = ArkProjective<C::ArkSWConfig>;
fn to_ark(&self) -> Self::ArkEquivalent {
let proj_x = self
.x
.to_ark();
let proj_y = self
.y
.to_ark();
let proj_z = self
.z
.to_ark();
// conversion between projective used in icicle and Jacobian used in arkworks
let proj_x = proj_x * proj_z;
let proj_y = proj_y * proj_z * proj_z;
Self::ArkEquivalent::new_unchecked(proj_x, proj_y, proj_z)
}
fn from_ark(ark: Self::ArkEquivalent) -> Self {
// conversion between Jacobian used in arkworks and projective used in icicle
let proj_x = ark.x * ark.z;
let proj_z = ark.z * ark.z * ark.z;
Self {
x: Field::<NUM_LIMBS, F>::from_ark(proj_x),
y: Field::<NUM_LIMBS, F>::from_ark(ark.y),
z: Field::<NUM_LIMBS, F>::from_ark(proj_z),
p: PhantomData,
}
}
}
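The coordinate shuffling in `to_ark`/`from_ark` above comes from the two representations involved: an icicle projective triple (x, y, z) denotes the affine point (x/z, y/z), while an arkworks Jacobian triple (x', y', z') denotes (x'/z'^2, y'/z'^3). A small sketch of that identity in plain BN254 base-field arithmetic (the function name is illustrative; `ark_bn254`, `ark_ff` and `ark_std` are the arkworks crates this commit pulls in elsewhere):

use ark_bn254::Fq;
use ark_ff::{Field, UniformRand};

fn projective_to_jacobian_is_consistent() {
    let mut rng = ark_std::test_rng();
    let (x, y, z) = (Fq::rand(&mut rng), Fq::rand(&mut rng), Fq::rand(&mut rng));
    // Affine point described by the projective triple (x, y, z).
    let z_inv = z.inverse().unwrap();
    let (ax, ay) = (x * z_inv, y * z_inv);
    // `to_ark` builds the Jacobian triple (x*z, y*z^2, z); it must describe the same point.
    let (jx, jy, jz) = (x * z, y * z * z, z);
    let jz2_inv = (jz * jz).inverse().unwrap();
    let jz3_inv = jz2_inv * z_inv;
    assert_eq!(jx * jz2_inv, ax);
    assert_eq!(jy * jz3_inv, ay);
}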

View File

@@ -0,0 +1,98 @@
#[cfg(feature = "arkworks")]
use crate::traits::ArkConvertible;
#[cfg(feature = "arkworks")]
use ark_ff::{BigInteger, PrimeField};
use std::marker::PhantomData;
#[cfg(feature = "arkworks")]
pub trait FieldConfig: PartialEq + Copy + Clone {
type ArkField: PrimeField;
}
#[cfg(not(feature = "arkworks"))]
pub trait FieldConfig: PartialEq + Copy + Clone {}
#[derive(Debug, PartialEq, Copy, Clone)]
#[repr(C)]
pub struct Field<const NUM_LIMBS: usize, F: FieldConfig> {
limbs: [u32; NUM_LIMBS],
p: PhantomData<F>,
}
pub(crate) fn get_fixed_limbs<const NUM_LIMBS: usize>(val: &[u32]) -> [u32; NUM_LIMBS] {
match val.len() {
n if n < NUM_LIMBS => {
let mut padded: [u32; NUM_LIMBS] = [0; NUM_LIMBS];
padded[..val.len()].copy_from_slice(&val);
padded
}
n if n == NUM_LIMBS => val
.try_into()
.unwrap(),
_ => panic!("slice has too many elements"),
}
}
impl<const NUM_LIMBS: usize, F: FieldConfig> Field<NUM_LIMBS, F> {
pub fn get_limbs(&self) -> [u32; NUM_LIMBS] {
self.limbs
}
pub fn set_limbs(value: &[u32]) -> Self {
Self {
limbs: get_fixed_limbs(value),
p: PhantomData,
}
}
pub fn to_bytes_le(&self) -> Vec<u8> {
self.limbs
.iter()
.map(|limb| {
limb.to_le_bytes()
.to_vec()
})
.flatten()
.collect::<Vec<_>>()
}
pub fn from_bytes_le(bytes: &[u8]) -> Self {
let limbs = bytes
.chunks(4)
.map(|chunk| {
u32::from_le_bytes(
chunk
.try_into()
.unwrap(),
)
})
.collect::<Vec<_>>();
Self::set_limbs(&limbs)
}
pub fn zero() -> Self {
Field {
limbs: [0u32; NUM_LIMBS],
p: PhantomData,
}
}
pub fn one() -> Self {
let mut limbs = [0u32; NUM_LIMBS];
limbs[0] = 1;
Field { limbs, p: PhantomData }
}
}
#[cfg(feature = "arkworks")]
impl<const NUM_LIMBS: usize, F: FieldConfig> ArkConvertible for Field<NUM_LIMBS, F> {
type ArkEquivalent = F::ArkField;
fn to_ark(&self) -> Self::ArkEquivalent {
F::ArkField::from_le_bytes_mod_order(&self.to_bytes_le())
}
fn from_ark(ark: Self::ArkEquivalent) -> Self {
let ark_bigint: <Self::ArkEquivalent as PrimeField>::BigInt = ark.into();
Self::from_bytes_le(&ark_bigint.to_bytes_le())
}
}
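A quick usage sketch for the limb and byte accessors above, written against the BN254 `ScalarField` alias that this commit adds further down (the function name is illustrative):

use icicle_bn254::curve::ScalarField;

fn limb_and_byte_roundtrip() {
    // Short limb slices are zero-padded up to NUM_LIMBS (8 for the BN254 scalar field).
    let x = ScalarField::set_limbs(&[1, 2, 3]);
    assert_eq!(x.get_limbs()[..3], [1, 2, 3]);
    // Byte serialization is little-endian, 4 bytes per limb, and roundtrips.
    let bytes = x.to_bytes_le();
    assert_eq!(bytes.len(), 8 * 4);
    assert_eq!(ScalarField::from_bytes_le(&bytes), x);
}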

View File

@@ -0,0 +1,5 @@
pub mod curve;
pub mod field;
pub mod msm;
pub mod ntt;
pub mod traits;

View File

@@ -0,0 +1,98 @@
use icicle_cuda_runtime::device_context::DeviceContext;
/*
/**
* @struct MSMConfig
* Struct that encodes MSM parameters to be passed into the [msm](@ref msm) function.
*/
struct MSMConfig {
bool are_scalars_on_device; /**< True if scalars are on device and false if they're on host. Default value: false. */
bool are_scalars_montgomery_form; /**< True if scalars are in Montgomery form and false otherwise. Default value: true. */
int points_size; /**< Number of points in the MSM. If a batch of MSMs needs to be computed, this should be a number
* of different points. So, if each MSM re-uses the same set of points, this variable is set equal
* to the MSM size. And if every MSM uses a distinct set of points, it should be set to the product of
* MSM size and [batch_size](@ref batch_size). Default value: 0 (meaning it's equal to the MSM size). */
int precompute_factor; /**< The number of extra points to pre-compute for each point. Larger values decrease the number of computations
* to make, on-line memory footprint, but increase the static memory footprint. Default value: 1 (i.e. don't pre-compute). */
bool are_points_on_device; /**< True if points are on device and false if they're on host. Default value: false. */
bool are_points_montgomery_form; /**< True if coordinates of points are in Montgomery form and false otherwise. Default value: true. */
int batch_size; /**< The number of MSMs to compute. Default value: 1. */
bool are_results_on_device; /**< True if the results should be on device and false if they should be on host. If set to false,
* `is_async` won't take effect because a synchronization is needed to transfer results to the host. Default value: false. */
int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket method"
* that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
* footprint but also more parallelism and less computational complexity (up to a certain point).
* Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a different
* (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field). */
bool is_big_triangle; /**< Whether to do "bucket accumulation" serially. Decreases computational complexity, but also greatly
* decreases parallelism, so only suitable for large batches of MSMs. Default value: false. */
int large_bucket_factor; /**< Variable that controls how sensitive the algorithm is to the buckets that occur very frequently.
* Useful for efficient treatment of non-uniform distributions of scalars and "top windows" with few bits.
* Can be set to 0 to disable separate treatment of large buckets altogether. Default value: 10. */
int is_async; /**< Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
* and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`.
* If set to false, the MSM function will block the current CPU thread. */
device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. See [DeviceContext](@ref `device_context::DeviceContext`). */
};
*/
/// Struct that encodes MSM parameters to be passed into the `msm` function.
#[repr(C)]
pub struct MSMConfig<'a> {
/// True if scalars are on device and false if they're on host. Default value: false.
pub are_scalars_on_device: bool,
/// True if scalars are in Montgomery form and false otherwise. Default value: true.
pub are_scalars_montgomery_form: bool,
/// Number of points in the MSM. If a batch of MSMs needs to be computed, this should be a number
/// of different points. So, if each MSM re-uses the same set of points, this variable is set equal
/// to the MSM size. And if every MSM uses a distinct set of points, it should be set to the product of
/// MSM size and batch_size. Default value: 0 (meaning it's equal to the MSM size).
pub points_size: usize, // Note: the C++ struct above declares this as `int`
/// The number of extra points to pre-compute for each point. Larger values decrease the number of computations
/// to make and the on-line memory footprint, but increase the static memory footprint. Default value: 1 (i.e. don't pre-compute).
pub precompute_factor: usize,
/// True if points are on device and false if they're on host. Default value: false.
pub are_points_on_device: bool,
/// True if coordinates of points are in Montgomery form and false otherwise. Default value: true.
pub are_points_montgomery_form: bool,
/// The number of MSMs to compute. Default value: 1.
pub batch_size: usize,
/// True if the results should be on device and false if they should be on host. If set to false,
/// `is_async` won't take effect because a synchronization is needed to transfer results to the host. Default value: false.
pub are_results_on_device: bool,
/// `c` value, or "window bitsize" which is the main parameter of the "bucket method"
/// that we use to solve the MSM problem. As a rule of thumb, larger value means more on-line memory
/// footprint but also more parallelism and less computational complexity (up to a certain point).
/// Default value: 0 (the optimal value of `c` is chosen automatically).
pub c: usize,
/// Number of bits of the largest scalar. Typically equals the bitsize of scalar field, but if a different
/// (better) upper bound is known, it should be reflected in this variable. Default value: 0 (set to the bitsize of scalar field).
pub bitsize: usize,
/// Whether to do "bucket accumulation" serially. Decreases computational complexity, but also greatly
/// decreases parallelism, so only suitable for large batches of MSMs. Default value: false.
pub is_big_triangle: bool,
/// Variable that controls how sensitive the algorithm is to the buckets that occur very frequently.
/// Useful for efficient treatment of non-uniform distributions of scalars and "top windows" with few bits.
/// Can be set to 0 to disable separate treatment of large buckets altogether. Default value: 10.
pub large_bucket_factor: usize,
/// Whether to run the MSM asynchronously. If set to `true`, the MSM function will be non-blocking
/// and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`.
/// If set to `false`, the MSM function will block the current CPU thread.
pub is_async: bool,
/// Details related to the device such as its id and stream id.
pub ctx: DeviceContext<'a>,
}
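To make the field descriptions concrete, here is a sketch of how a caller might obtain and tweak a config, using the BN254 `get_default_msm_config` binding added later in this commit (the chosen values are purely illustrative):

use icicle_bn254::msm::get_default_msm_config;
use icicle_core::msm::MSMConfig;

fn tuned_config() -> MSMConfig<'static> {
    let mut cfg = get_default_msm_config();
    cfg.batch_size = 4; // compute four MSMs in one call
    cfg.points_size = 1 << 16; // all four MSMs reuse the same set of 2^16 points
    cfg.c = 16; // pin the window bitsize instead of relying on auto-selection
    cfg.are_results_on_device = true;
    cfg.is_async = true; // caller synchronizes the stream in cfg.ctx explicitly
    cfg
}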

View File

@@ -0,0 +1,104 @@
use icicle_cuda_runtime::device_context::DeviceContext;
use std::os::raw::c_int;
/**
* @enum Ordering
* How to order inputs and outputs of the NTT:
* - kNN: inputs and outputs are natural-order (example of natural ordering: \f$ \{a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7\} \f$).
* - kNR: inputs are natural-order and outputs are bit-reversed-order (example of bit-reversed ordering: \f$ \{a_0, a_4, a_2, a_6, a_1, a_5, a_3, a_7\} \f$).
* - kRN: inputs are bit-reversed-order and outputs are natural-order.
* - kRR: inputs and outputs are bit-reversed-order.
*/
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Ordering {
kNN,
kNR,
kRN,
kRR,
}
/**
* @enum Decimation
* Decimation of the NTT algorithm:
* - kDIT: decimation in time.
* - kDIF: decimation in frequency.
*/
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Decimation {
kDIT,
kDIF,
}
/**
* @enum Butterfly
* [Butterfly](https://en.wikipedia.org/wiki/Butterfly_diagram) used in the NTT algorithm (i.e. what happens to each pair of inputs on every iteration):
* - kCooleyTukey: Cooley-Tukey butterfly.
* - kGentlemanSande: Gentleman-Sande butterfly.
*/
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Butterfly {
kCooleyTukey,
kGentlemanSande,
}
/**
* @struct NTTConfig
* Struct that encodes NTT parameters to be passed into the [ntt](@ref ntt) function.
*/
#[repr(C)]
#[derive(Debug)]
pub struct NTTConfigCuda<'a, E, S> {
pub inout: *mut E,
/**< Input that's mutated in-place by this function. Length of this array needs to be \f$ size \cdot config.batch_size \f$.
* Note that if inputs are in Montgomery form, the outputs will be as well and vice versa: non-Montgomery inputs produce non-Montgomery outputs.*/
pub is_input_on_device: bool,
/**< True if inputs/outputs are on device and false if they're on host. Default value: false. */
pub is_inverse: bool,
/**< True to compute the inverse NTT (iNTT) and false for the forward NTT. Default value: false. */
pub ordering: Ordering,
/**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`. */
pub decimation: Decimation,
/**< Decimation of the algorithm, see [Decimation](@ref Decimation). Default value: `Decimation::kDIT`.
* __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
* If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
* `Decimation::kDIT` and if ordering is `Ordering::kNR` — to `Decimation::kDIF`. */
pub butterfly: Butterfly,
/**< Butterfly used by the NTT. See [Butterfly](@ref Butterfly). Default value: `Butterfly::kCooleyTukey`.
* __Note:__ this variable exists mainly for compatibility with codebases that use similar notation.
* If [ordering](@ref ordering) is `Ordering::kRN`, the value of this variable will be overridden to
* `Butterfly::kCooleyTukey` and if ordering is `Ordering::kNR` — to `Butterfly::kGentlemanSande`. */
pub is_coset: bool,
/**< If false, NTT is computed on the subgroup given by [twiddles](@ref twiddles). If true, NTT is computed
* on a coset of [twiddles](@ref twiddles) given by [the coset generator](@ref coset_gen), so:
* \f$ \{coset\_gen\cdot\omega^0, coset\_gen\cdot\omega^1, \dots, coset\_gen\cdot\omega^{n-1}\} \f$. Default value: false. */
pub coset_gen: *const S,
/**< The field element that generates a coset if [is_coset](@ref is_coset) is true.
* Otherwise should be set to `nullptr`. Default value: `nullptr`. */
pub twiddles: *const S,
/**< "Twiddle factors", (or "domain", or "roots of unity") on which the NTT is evaluated.
* This pointer is expected to live on device. The order is as follows:
* \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle factors
* are generated online using the default generator (TODO: link to twiddle gen here) and function
* [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
pub inv_twiddles: *const S,
/**< "Inverse twiddle factors", (or "domain", or "roots of unity") on which the iNTT is evaluated.
* This pointer is expected to live on device. The order is as follows:
* \f$ \{\omega^0=1, \omega^1, \dots, \omega^{n-1}\} \f$. If this pointer is `nullptr`, twiddle factors
* are generated online using the default generator (TODO: link to twiddle gen here) and function
* [GenerateTwiddleFactors](@ref GenerateTwiddleFactors). Default value: `nullptr`. */
pub size: c_int,
/**< NTT size \f$ n \f$. If a batch of NTTs (which all need to have the same size) is computed, this is the size of 1 NTT. */
pub batch_size: c_int,
/**< The number of NTTs to compute. Default value: 1. */
pub is_preserving_twiddles: bool,
/**< If true, twiddle factors are preserved on device for subsequent use in config and not freed after calculation. Default value: false. */
pub is_output_on_device: bool,
/**< If true, the output stays on device (accessible through `inout`) for subsequent use and is not copied back to the host. Default value: false. */
pub ctx: DeviceContext<'a>, /*< Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext). */
}
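As a rough illustration of how these fields fit together, a host-resident forward NTT over BN254 scalars could be configured as below; this mirrors the `get_ntt_config_with_input` helper added later in this commit, and the `ScalarField` and `get_default_device_context` names come from other files in this diff (a sketch, not an authoritative API):

use icicle_bn254::curve::ScalarField;
use icicle_core::ntt::{Butterfly, Decimation, NTTConfigCuda, Ordering};
use icicle_cuda_runtime::device_context::get_default_device_context;

fn host_forward_ntt_config(data: &mut [ScalarField]) -> NTTConfigCuda<'static, ScalarField, ScalarField> {
    NTTConfigCuda {
        inout: data.as_mut_ptr(), // in-place buffer of length size * batch_size
        is_input_on_device: false, // data lives on the host
        is_inverse: false, // forward NTT
        ordering: Ordering::kNN,
        decimation: Decimation::kDIT,
        butterfly: Butterfly::kCooleyTukey,
        is_coset: false,
        coset_gen: std::ptr::null(), // no coset
        twiddles: std::ptr::null(), // generated on the fly on first use
        inv_twiddles: std::ptr::null(),
        size: data.len() as i32,
        batch_size: 1,
        is_preserving_twiddles: true, // keep twiddles on device for the next call
        is_output_on_device: false,
        ctx: get_default_device_context(),
    }
}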

View File

@@ -0,0 +1,7 @@
#[cfg(feature = "arkworks")]
pub trait ArkConvertible {
type ArkEquivalent;
fn to_ark(&self) -> Self::ArkEquivalent;
fn from_ark(ark: Self::ArkEquivalent) -> Self;
}

View File

@@ -0,0 +1,14 @@
[package]
name = "icicle-cuda-runtime"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "Ingonyama's Rust wrapper of CUDA runtime"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"
[dependencies]
bitflags = "2.4"
[build-dependencies]
bindgen = "*"

View File

@@ -0,0 +1,82 @@
// Based on https://github.com/matter-labs/z-prize-msm-gpu/blob/main/bellman-cuda-rust/cudart-sys/build.rs
use std::fs;
use std::path::PathBuf;
fn cuda_include_path() -> &'static str {
#[cfg(target_os = "windows")]
{
concat!(env!("CUDA_PATH"), "/include")
}
#[cfg(target_os = "linux")]
{
"/usr/local/cuda/include"
}
}
fn cuda_lib_path() -> &'static str {
#[cfg(target_os = "windows")]
{
concat!(env!("CUDA_PATH"), "/lib/x64")
}
#[cfg(target_os = "linux")]
{
"/usr/local/cuda/lib64"
}
}
fn main() {
let cuda_runtime_api_path = PathBuf::from(cuda_include_path())
.join("cuda_runtime_api.h")
.to_string_lossy()
.to_string();
println!("cargo:rustc-link-search=native={}", cuda_lib_path());
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rerun-if-changed={}", cuda_runtime_api_path);
let bindings = bindgen::Builder::default()
.header(cuda_runtime_api_path)
.size_t_is_usize(true)
.generate_comments(false)
.layout_tests(false)
.allowlist_type("cudaError")
.rustified_enum("cudaError")
.must_use_type("cudaError")
// device management
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html
.allowlist_function("cudaSetDevice")
// error handling
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__ERROR.html
.allowlist_function("cudaGetLastError")
// stream management
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
.allowlist_function("cudaStreamCreate")
.allowlist_var("cudaStreamDefault")
.allowlist_var("cudaStreamNonBlocking")
.allowlist_function("cudaStreamCreateWithFlags")
.allowlist_function("cudaStreamDestroy")
.allowlist_function("cudaStreamQuery")
.allowlist_function("cudaStreamSynchronize")
.allowlist_var("cudaEventWaitDefault")
.allowlist_var("cudaEventWaitExternal")
.allowlist_function("cudaStreamWaitEvent")
// memory management
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html
.allowlist_function("cudaFree")
.allowlist_function("cudaMalloc")
.allowlist_function("cudaMemcpy")
.allowlist_function("cudaMemcpyAsync")
.allowlist_function("cudaMemset")
.allowlist_function("cudaMemsetAsync")
.rustified_enum("cudaMemcpyKind")
// Stream Ordered Memory Allocator
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html
.allowlist_function("cudaFreeAsync")
.allowlist_function("cudaMallocAsync")
//
.generate()
.expect("Unable to generate bindings");
fs::write(PathBuf::from("src").join("bindings.rs"), bindings.to_string()).expect("Couldn't write bindings!");
}

View File

@@ -0,0 +1,27 @@
use crate::memory::CudaMemPool;
use crate::stream::CudaStream;
/// Properties of the device used in icicle functions.
#[repr(C)]
#[derive(Debug)]
pub struct DeviceContext<'a> {
/// Index of the currently used GPU. Default value: 0.
pub device_id: usize,
/// Stream to use. Default value: 0.
pub stream: &'a CudaStream, // Assuming the type is provided by a CUDA binding crate
/// Mempool to use. Default value: 0.
pub mempool: CudaMemPool, // Assuming the type is provided by a CUDA binding crate
}
pub fn get_default_device_context() -> DeviceContext<'static> {
static DEFAULT_STREAM: CudaStream = CudaStream {
handle: std::ptr::null_mut(),
};
DeviceContext {
device_id: 0,
stream: &DEFAULT_STREAM,
mempool: 0,
}
}
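A small sketch of building a context around a caller-owned stream instead of the defaults above (field values are illustrative; `CudaMemPool` is still the placeholder alias from memory.rs in this same commit):

use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::stream::CudaStream;

fn context_on_stream(stream: &CudaStream) -> DeviceContext<'_> {
    DeviceContext {
        device_id: 0, // first GPU
        stream, // run all queued work on the caller's stream
        mempool: 0, // default pool (placeholder type for now)
    }
}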

View File

@@ -0,0 +1,35 @@
use crate::bindings::{cudaError, cudaGetLastError};
use std::mem::MaybeUninit;
pub type CudaError = cudaError;
pub type CudaResult<T> = Result<T, CudaError>;
pub trait CudaResultWrap {
fn wrap(self) -> CudaResult<()>;
fn wrap_value<T>(self, value: T) -> CudaResult<T>;
fn wrap_maybe_uninit<T>(self, value: MaybeUninit<T>) -> CudaResult<T>;
}
impl CudaResultWrap for CudaError {
fn wrap(self) -> CudaResult<()> {
self.wrap_value(())
}
fn wrap_value<T>(self, value: T) -> CudaResult<T> {
if self == CudaError::cudaSuccess {
Ok(value)
} else {
Err(self)
}
}
fn wrap_maybe_uninit<T>(self, value: MaybeUninit<T>) -> CudaResult<T> {
self.wrap_value(value)
.map(|x| unsafe { x.assume_init() })
}
}
pub fn get_last_error() -> CudaError {
unsafe { cudaGetLastError() }
}
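For example (a sketch), the status of the most recent runtime call can be funneled straight into the wrapper:

use icicle_cuda_runtime::error::{get_last_error, CudaResult, CudaResultWrap};

fn check_last_cuda_call() -> CudaResult<()> {
    // cudaSuccess becomes Ok(()); any other status becomes Err(status).
    get_last_error().wrap()
}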

View File

@@ -0,0 +1,8 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
mod bindings;
pub mod device_context;
pub mod error;
pub mod memory;
pub mod stream;

View File

@@ -0,0 +1,156 @@
use crate::bindings::{cudaMalloc, cudaMallocAsync, cudaMemcpy, cudaMemcpyAsync, cudaMemcpyKind};
use crate::error::{CudaError, CudaResult, CudaResultWrap};
use crate::stream::CudaStream;
use std::mem::{size_of, MaybeUninit};
use std::os::raw::c_void;
use std::{ptr, slice};
/// Fixed-size device-side slice.
#[derive(Debug)]
#[repr(C)]
pub struct DeviceSlice<'a, T>(&'a mut [T]);
impl<'a, T> DeviceSlice<'a, T> {
pub fn len(&self) -> usize {
self.0
.len()
}
pub fn is_empty(&self) -> bool {
self.0
.is_empty()
}
pub fn as_slice(&mut self) -> &mut [T] {
self.0
}
pub fn as_ptr(&self) -> *const T {
self.0
.as_ptr()
}
pub fn as_mut_ptr(&mut self) -> *mut T {
self.0
.as_mut_ptr()
}
pub fn cuda_malloc(count: usize) -> CudaResult<Self> {
let size = count
.checked_mul(size_of::<T>())
.unwrap_or(0);
if size == 0 {
return Err(CudaError::cudaErrorMemoryAllocation);
}
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
unsafe {
cudaMalloc(device_ptr.as_mut_ptr(), size).wrap()?;
Ok(DeviceSlice(slice::from_raw_parts_mut(device_ptr.assume_init() as *mut T, count)))
}
}
pub fn cuda_malloc_async(count: usize, stream: &mut CudaStream) -> CudaResult<Self> {
let size = count
.checked_mul(size_of::<T>())
.unwrap_or(0);
if size == 0 {
return Err(CudaError::cudaErrorMemoryAllocation);
}
let mut device_ptr = MaybeUninit::<*mut c_void>::uninit();
unsafe {
cudaMallocAsync(device_ptr.as_mut_ptr(), size, stream as *mut _ as *mut _).wrap()?;
Ok(DeviceSlice(slice::from_raw_parts_mut(device_ptr.assume_init() as *mut T, count)))
}
}
pub fn copy_from_host(&mut self, val: &[T]) -> CudaResult<()> {
assert!(
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = size_of::<T>() * self.len();
if size != 0 {
unsafe {
cudaMemcpy(
self.as_mut_ptr() as *mut c_void,
val.as_ptr() as *const c_void,
size,
cudaMemcpyKind::cudaMemcpyHostToDevice,
)
.wrap()?
}
}
Ok(())
}
pub fn copy_to_host(&self, val: &mut [T]) -> CudaResult<()> {
assert!(
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = size_of::<T>() * self.len();
if size != 0 {
unsafe {
cudaMemcpy(
val.as_mut_ptr() as *mut c_void,
self.as_ptr() as *const c_void,
size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
.wrap()?
}
}
Ok(())
}
pub fn copy_from_host_async(&mut self, val: &[T], stream: &mut CudaStream) -> CudaResult<()> {
assert!(
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = size_of::<T>() * self.len();
if size != 0 {
unsafe {
cudaMemcpyAsync(
self.as_mut_ptr() as *mut c_void,
val.as_ptr() as *const c_void,
size,
cudaMemcpyKind::cudaMemcpyHostToDevice,
stream as *mut _ as *mut _,
)
.wrap()?
}
}
Ok(())
}
pub fn copy_to_host_async(&self, val: &mut [T], stream: &mut CudaStream) -> CudaResult<()> {
assert!(
self.len() == val.len(),
"destination and source slices have different lengths"
);
let size = size_of::<T>() * self.len();
if size != 0 {
unsafe {
cudaMemcpyAsync(
val.as_mut_ptr() as *mut c_void,
self.as_ptr() as *const c_void,
size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
stream as *mut _ as *mut _,
)
.wrap()?
}
}
Ok(())
}
}
#[allow(non_camel_case_types)]
pub type CudaMemPool = usize; // This is a placeholder, TODO: actually make this into a proper CUDA wrapper
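A blocking host-to-device roundtrip with the API above might look like this (a sketch; the element type is the BN254 `ScalarField` added elsewhere in this commit, and no explicit free is issued, matching the current state of the wrapper):

use icicle_bn254::curve::ScalarField;
use icicle_cuda_runtime::error::CudaResult;
use icicle_cuda_runtime::memory::DeviceSlice;

fn device_roundtrip(host_in: &[ScalarField]) -> CudaResult<Vec<ScalarField>> {
    // Allocate a device buffer of the same length and copy the input over.
    let mut device = DeviceSlice::cuda_malloc(host_in.len())?;
    device.copy_from_host(host_in)?;
    // Copy it back with the blocking, stream-less variants.
    let mut host_out = vec![ScalarField::zero(); host_in.len()];
    device.copy_to_host(&mut host_out)?;
    Ok(host_out)
}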

View File

@@ -0,0 +1,74 @@
use crate::bindings::{
cudaStreamCreate, cudaStreamDefault, cudaStreamDestroy, cudaStreamNonBlocking, cudaStreamSynchronize, cudaStream_t,
};
use crate::error::{CudaResult, CudaResultWrap};
use bitflags::bitflags;
use std::mem::{forget, MaybeUninit};
#[repr(transparent)]
#[derive(Debug)]
pub struct CudaStream {
pub(crate) handle: cudaStream_t,
}
unsafe impl Sync for CudaStream {}
bitflags! {
pub struct CudaStreamCreateFlags: u32 {
const DEFAULT = cudaStreamDefault;
const NON_BLOCKING = cudaStreamNonBlocking;
}
}
impl CudaStream {
pub(crate) fn from_handle(handle: cudaStream_t) -> Self {
Self { handle }
}
pub fn create() -> CudaResult<Self> {
let mut handle = MaybeUninit::<cudaStream_t>::uninit();
unsafe {
cudaStreamCreate(handle.as_mut_ptr())
.wrap_maybe_uninit(handle)
.map(CudaStream::from_handle)
}
}
pub fn destroy(self) -> CudaResult<()> {
let handle = self.handle;
forget(self);
if handle.is_null() {
Ok(())
} else {
unsafe { cudaStreamDestroy(handle).wrap() }
}
}
pub fn synchronize(&self) -> CudaResult<()> {
unsafe { cudaStreamSynchronize(self.handle).wrap() }
}
}
impl Default for CudaStream {
fn default() -> Self {
Self {
handle: std::ptr::null_mut(),
}
}
}
impl Drop for CudaStream {
fn drop(&mut self) {
let handle = self.handle;
if handle.is_null() {
return;
}
let _ = unsafe { cudaStreamDestroy(handle) };
}
}
impl From<&CudaStream> for cudaStream_t {
fn from(stream: &CudaStream) -> Self {
stream.handle
}
}
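Combining the stream with the async allocation and copy helpers from memory.rs gives the intended non-blocking flow (a sketch; the element type and function name are illustrative):

use icicle_bn254::curve::ScalarField;
use icicle_cuda_runtime::error::CudaResult;
use icicle_cuda_runtime::memory::DeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;

fn async_upload(host_in: &[ScalarField]) -> CudaResult<()> {
    let mut stream = CudaStream::create()?;
    // Both the allocation and the copy are queued on the stream...
    let mut device = DeviceSlice::cuda_malloc_async(host_in.len(), &mut stream)?;
    device.copy_from_host_async(host_in, &mut stream)?;
    // ...and only guaranteed complete after synchronize().
    stream.synchronize()?;
    stream.destroy()
}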

View File

@@ -0,0 +1,29 @@
[package]
name = "icicle-bn254"
version = "0.1.0"
edition = "2021"
authors = [ "Ingonyama" ]
description = "Rust wrapper for the CUDA implementation of BN254 pairing friendly elliptic curve by Ingonyama"
homepage = "https://www.ingonyama.com"
repository = "https://github.com/ingonyama-zk/icicle"
[dependencies]
icicle-core = { path = "../../icicle-core" }
icicle-cuda-runtime = { path = "../../icicle-cuda-runtime" }
ark-bn254 = { version = "0.4.0", optional = true }
[build-dependencies]
cmake = "*"
[dev-dependencies]
ark-bn254 = "0.4.0"
ark-std = "0.4.0"
ark-ff = "0.4.0"
ark-ec = "0.4.0"
ark-poly = "0.4.0"
icicle-core = { path = "../../icicle-core", features = ["arkworks"] }
icicle-bn254 = { path = ".", features = ["arkworks"] }
[features]
default = []
arkworks = ["ark-bn254", "icicle-core/arkworks"]

View File

@@ -0,0 +1,30 @@
use cmake::Config;
use std::env::var;
fn main() {
println!("cargo:rerun-if-env-changed=CXXFLAGS");
println!("cargo:rerun-if-changed=../../../../icicle");
let cargo_dir = var("CARGO_MANIFEST_DIR").unwrap();
let profile = var("PROFILE").unwrap();
let target_output_dir = format!("{}/../../target/{}", cargo_dir, profile);
Config::new("./icicle")
.define("BUILD_TESTS", "OFF") //TODO: feature
// .define("CURVE", "bls12_381")
.define("CURVE", "bn254")
// .define("ECNTT_DEFINED", "") //TODO: feature
.define("LIBRARY_OUTPUT_DIRECTORY", &target_output_dir)
.define("CMAKE_BUILD_TYPE", "Release")
.build_target("icicle")
.build();
println!("cargo:rustc-link-search={}", &target_output_dir);
// println!("cargo:rustc-link-lib=icicle");
println!("cargo:rustc-link-lib=ingo_bn254");
println!("cargo:rustc-link-lib=stdc++");
// println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-lib=cudart");
}

View File

@@ -0,0 +1,151 @@
#[cfg(feature = "arkworks")]
use ark_bn254::{g1::Config as ArkG1Config, Fq, Fr};
use icicle_core::curve::{Affine, CurveConfig, Projective};
use icicle_core::field::{Field, FieldConfig};
use std::ffi::{c_uint, c_void};
#[derive(Debug, PartialEq, Copy, Clone)]
pub struct ScalarCfg {}
impl FieldConfig for ScalarCfg {
#[cfg(feature = "arkworks")]
type ArkField = Fr;
}
const SCALAR_LIMBS: usize = 8;
pub type ScalarField = Field<SCALAR_LIMBS, ScalarCfg>;
extern "C" {
fn GenerateScalars(scalars: *mut ScalarField, size: usize);
}
pub(crate) fn generate_random_scalars(size: usize) -> Vec<ScalarField> {
let mut res = vec![ScalarField::zero(); size];
unsafe { GenerateScalars(&mut res[..] as *mut _ as *mut ScalarField, size) };
res
}
#[derive(Debug, PartialEq, Copy, Clone)]
pub struct BaseCfg {}
impl FieldConfig for BaseCfg {
#[cfg(feature = "arkworks")]
type ArkField = Fq;
}
pub const BASE_LIMBS: usize = 8;
#[derive(Debug, PartialEq, Copy, Clone)]
pub struct CurveCfg {}
pub type BaseField = Field<BASE_LIMBS, BaseCfg>;
pub type G1Affine = Affine<BaseField, CurveCfg>;
pub type G1Projective = Projective<BaseField, CurveCfg>;
extern "C" {
fn Eq(point1: *const c_void, point2: *const c_void) -> c_uint;
fn ToAffine(point: *const c_void, point_out: *mut c_void);
fn GenerateProjectivePoints(points: *mut G1Projective, size: usize);
fn GenerateAffinePoints(points: *mut G1Affine, size: usize);
}
impl CurveConfig for CurveCfg {
fn eq_proj(point1: *const c_void, point2: *const c_void) -> c_uint {
unsafe { Eq(point1, point2) }
}
fn to_affine(point: *const c_void, point_out: *mut c_void) {
unsafe { ToAffine(point, point_out) };
}
#[cfg(feature = "arkworks")]
type ArkSWConfig = ArkG1Config;
}
pub(crate) fn generate_random_projective_points(size: usize) -> Vec<G1Projective> {
let mut res = vec![G1Projective::zero(); size];
unsafe { GenerateProjectivePoints(&mut res[..] as *mut _ as *mut G1Projective, size) };
res
}
pub(crate) fn generate_random_affine_points(size: usize) -> Vec<G1Affine> {
let mut res = vec![G1Affine::zero(); size];
unsafe { GenerateAffinePoints(&mut res[..] as *mut _ as *mut G1Affine, size) };
res
}
#[cfg(test)]
mod tests {
use super::{
generate_random_affine_points, generate_random_projective_points, generate_random_scalars, BaseField, G1Affine,
G1Projective, ScalarField, BASE_LIMBS,
};
use icicle_core::traits::ArkConvertible;
use ark_bn254::G1Affine as ArkG1Affine;
#[test]
fn test_scalar_equality() {
let left = ScalarField::zero();
let right = ScalarField::one();
assert_ne!(left, right);
let left = ScalarField::set_limbs(&[1]);
assert_eq!(left, right);
}
#[test]
fn test_ark_scalar_convert() {
let size = 1 << 10;
let scalars = generate_random_scalars(size);
for scalar in scalars {
assert_eq!(ScalarField::from_ark(scalar.to_ark()), scalar)
}
}
#[test]
fn test_affine_projective_convert() {
let size = 1 << 10;
let affine_points = generate_random_affine_points(size);
let projective_points = generate_random_projective_points(size);
for affine_point in affine_points {
let projective_equivalent: G1Projective = affine_point.into();
assert_eq!(affine_point, projective_equivalent.into());
}
for projective_point in projective_points {
let affine_equivalent: G1Affine = projective_point.into();
assert_eq!(projective_point, affine_equivalent.into());
}
}
#[test]
fn test_point_equality() {
let left = G1Projective::zero();
let right = G1Projective::zero();
assert_eq!(left, right);
let right = G1Projective::set_limbs(&[0; BASE_LIMBS], &[2; BASE_LIMBS], &[0; BASE_LIMBS]);
assert_eq!(left, right);
let right = G1Projective::set_limbs(
&[0; BASE_LIMBS],
&[4; BASE_LIMBS],
&BaseField::set_limbs(&[2]).get_limbs(),
);
assert_ne!(left, right);
let left = G1Projective::set_limbs(&[0; BASE_LIMBS], &[2; BASE_LIMBS], &BaseField::one().get_limbs());
assert_eq!(left, right);
}
#[test]
fn test_ark_point_convert() {
let size = 1 << 10;
let affine_points = generate_random_affine_points(size);
for affine_point in affine_points {
let ark_projective = Into::<G1Projective>::into(affine_point).to_ark();
let ark_affine: ArkG1Affine = ark_projective.into();
assert!(ark_affine.is_on_curve());
assert!(ark_affine.is_in_correct_subgroup_assuming_on_curve());
let affine_after_conversion: G1Affine = G1Projective::from_ark(ark_projective).into();
assert_eq!(affine_point, affine_after_conversion);
}
}
}

View File

@@ -0,0 +1,3 @@
pub mod curve;
pub mod msm;
pub mod ntt;

View File

@@ -0,0 +1,101 @@
use crate::curve::{G1Affine, G1Projective, ScalarField};
use icicle_core::msm::MSMConfig;
use icicle_cuda_runtime::error::{CudaError, CudaResult, CudaResultWrap};
extern "C" {
#[link_name = "bn254MSMCuda"]
fn msm_cuda<'a>(
scalars: *const ScalarField,
points: *const G1Affine,
count: usize,
config: MSMConfig<'a>,
out: *mut G1Projective,
) -> CudaError;
#[link_name = "bn254GetDefaultMSMConfig"]
fn GetDefaultMSMConfig() -> MSMConfig<'static>;
}
pub fn get_default_msm_config() -> MSMConfig<'static> {
unsafe { GetDefaultMSMConfig() }
}
pub fn msm<'a>(
scalars: &[ScalarField],
points: &[G1Affine],
cfg: MSMConfig<'a>,
results: &mut [G1Projective],
) -> CudaResult<()> {
if points.len() != scalars.len() {
return Err(CudaError::cudaErrorInvalidValue);
}
unsafe {
msm_cuda(
scalars as *const _ as *const ScalarField,
points as *const _ as *const G1Affine,
points.len(),
cfg,
results as *mut _ as *mut G1Projective,
)
.wrap()
}
}
#[cfg(test)]
pub(crate) mod tests {
use ark_bn254::G1Projective as ArkG1Projective;
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
use crate::{
curve::{generate_random_affine_points, generate_random_scalars, G1Projective},
msm::{get_default_msm_config, msm},
};
use icicle_core::traits::ArkConvertible;
use icicle_cuda_runtime::memory::DeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;
#[test]
fn test_msm() {
let log_test_sizes = [20];
for log_test_size in log_test_sizes {
let count = 1 << log_test_size;
let points = generate_random_affine_points(count);
let scalars = generate_random_scalars(count);
let mut msm_results = DeviceSlice::cuda_malloc(1).unwrap();
let stream = CudaStream::create().unwrap();
let mut cfg = get_default_msm_config();
cfg.ctx
.stream = &stream;
cfg.is_async = true;
cfg.are_results_on_device = true;
msm(&scalars, &points, cfg, &mut msm_results.as_slice()).unwrap();
// this happens on CPU in parallel to the GPU MSM computations
let point_r_ark: Vec<_> = points
.iter()
.map(|x| x.to_ark())
.collect();
let scalars_r_ark: Vec<_> = scalars
.iter()
.map(|x| x.to_ark())
.collect();
let msm_result_ark: ArkG1Projective = VariableBaseMSM::msm(&point_r_ark, &scalars_r_ark).unwrap();
let mut msm_host_result = vec![G1Projective::zero(); 1];
msm_results
.copy_to_host(&mut msm_host_result[..])
.unwrap();
stream
.synchronize()
.unwrap();
stream
.destroy()
.unwrap();
assert_eq!(msm_host_result[0].to_ark(), msm_result_ark);
}
}
}
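For comparison with the asynchronous test above, the simplest blocking call keeps everything on the host (a sketch; per the config docs, the defaults put scalars, points and results on the host with batch_size 1):

use icicle_bn254::curve::{G1Affine, G1Projective, ScalarField};
use icicle_bn254::msm::{get_default_msm_config, msm};
use icicle_cuda_runtime::error::CudaResult;

fn msm_on_host(scalars: &[ScalarField], points: &[G1Affine]) -> CudaResult<G1Projective> {
    let mut result = vec![G1Projective::zero(); 1];
    msm(scalars, points, get_default_msm_config(), &mut result)?;
    Ok(result[0])
}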

View File

@@ -0,0 +1,61 @@
use std::os::raw::c_int;
use crate::curve::*;
use icicle_core::ntt::{Butterfly, Decimation, NTTConfigCuda, Ordering};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
pub(super) type ECNTTConfig<'a> = NTTConfigCuda<'a, G1Projective, ScalarField>;
pub(super) type NTTConfig<'a> = NTTConfigCuda<'a, ScalarField, ScalarField>;
pub(super) fn get_ntt_config<E, S>(size: usize, ctx: DeviceContext) -> NTTConfigCuda<E, S> {
//TODO: implement on CUDA side
NTTConfigCuda::<E, S> {
inout: 0 as _, // inout as *mut _ as *mut ScalarField,
is_input_on_device: false,
is_inverse: false,
ordering: Ordering::kNN,
decimation: Decimation::kDIF,
butterfly: Butterfly::kCooleyTukey,
is_coset: false,
coset_gen: 0 as _, //TODO: ?
twiddles: 0 as _, //TODO: ?,
inv_twiddles: 0 as _, //TODO: ?,
size: size as i32,
batch_size: 0 as i32,
is_preserving_twiddles: true,
is_output_on_device: false,
ctx,
}
}
pub(super) fn get_ntt_default_config<E, S>(size: usize) -> NTTConfigCuda<'static, E, S> {
//TODO: implement on CUDA side
let ctx = get_default_device_context();
// let root_of_unity = S::default(); //TODO: implement on CUDA side
let config = get_ntt_config(size, ctx);
config
}
pub(super) fn get_ntt_config_with_input(ntt_intt_result: &mut [ScalarField], size: usize, batches: usize) -> NTTConfig {
NTTConfig {
inout: ntt_intt_result as *mut _ as *mut ScalarField,
is_input_on_device: false,
is_inverse: false,
ordering: Ordering::kNN,
decimation: Decimation::kDIF,
butterfly: Butterfly::kCooleyTukey,
is_coset: false,
coset_gen: &[ScalarField::zero()] as _, //TODO: ?
twiddles: 0 as *const ScalarField, //TODO: ?,
inv_twiddles: 0 as *const ScalarField, //TODO: ?,
size: size as _,
batch_size: batches as i32,
is_preserving_twiddles: true,
is_output_on_device: true,
ctx: get_default_device_context(),
}
}

View File

@@ -0,0 +1,192 @@
use icicle_core::ntt::{Butterfly, Decimation, NTTConfigCuda, Ordering};
use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::memory::DeviceSlice;
use std::default;
pub(super) type ECNTTDomain<'a> = Domain<'a, G1Projective, ScalarField>;
pub(super) type NTTDomain<'a> = Domain<'a, ScalarField, ScalarField>;
use crate::curve::*;
use super::{config::*, ntt_internal};
/// Represents the NTT domain
pub struct Domain<'a, E, S> {
config: NTTConfigCuda<'a, E, S>,
}
impl<'a, E, S> Domain<'a, E, S> {
pub fn new(size: usize, ctx: DeviceContext<'a>) -> Self {
Domain {
config: get_ntt_config(size, ctx),
}
}
pub fn get_output_on_device(&self) -> Result<*mut E, &'static str> {
if self
.config
.is_output_on_device
{
Ok(self
.config
.inout)
} else {
Err("Output should be on device.")
}
}
pub fn get_input_on_device(&self) -> Result<*mut E, &'static str> {
if self
.config
.is_input_on_device
{
Ok(self
.config
.inout)
} else {
Err("Input should be on device.")
}
}
pub fn get_input(&self) -> Result<*mut E, &'static str> {
if !self
.config
.is_input_on_device
{
Ok(self
.config
.inout)
} else {
Err("Output is on device.")
}
}
pub fn get_output(&self) -> Result<*mut E, &'static str> {
if !self
.config
.is_output_on_device
{
Ok(self
.config
.inout)
} else {
Err("Output is on device.")
}
}
pub(crate) fn new_for_default_context(size: usize) -> Self {
let ctx = get_default_device_context();
// let default_root_of_unity = S::default(); //TODO: implement
let domain = Domain::new(size, ctx);
domain
}
}
// Add implementations for other methods and structs as needed.
impl<'a, E: 'static, S: 'static> Domain<'a, E, S> {
// ... previous methods ...
// NTT methods
pub fn ntt(&mut self, inout: &mut [E]) {
let batch_size = 1;
let size = inout.len();
if size
!= self
.config
.size as _
{
//TODO: test for this error
panic!(
"input lenght: {} does not match domain size: {}",
size,
self.config
.size
)
}
self.config
.inout = inout.as_mut_ptr(); // as *mut _ as *mut E;
self.config
.is_inverse = false;
self.config
.is_input_on_device = false;
self.config
.is_output_on_device = false;
// self.config
// .ordering = Ordering::default(); //TODO: each call?
self.config
.batch_size = batch_size as i32;
ntt_internal(&mut self.config);
}
pub fn ntt_on_device(&mut self, inout: &mut DeviceSlice<E>) {
// Implementation for NTT on device
}
pub fn ntt_batch(&mut self, inout: &mut [E]) {
// Implementation for batched NTT
}
pub fn ntt_batch_on_device(&mut self, inout: &mut DeviceSlice<E>) {
// Implementation for batched NTT on device
}
pub fn ntt_coset(&mut self, inout: &mut [E], coset: &mut [E]) {
// Implementation for NTT with coset
}
pub fn ntt_coset_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
// Implementation for NTT with coset on device
}
pub fn ntt_coset_batch(&mut self, inout: &mut [E], coset: &mut [E]) {
// Implementation for batched NTT with coset
}
pub fn ntt_coset_batch_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
// Implementation for batched NTT with coset on device
}
// iNTT methods
pub fn intt(&mut self, inout: &mut [E]) {
// Implementation for iNTT
}
pub fn intt_on_device(&mut self, inout: &mut DeviceSlice<E>) {
// Implementation for iNTT on device
}
pub fn intt_batch(&mut self, inout: &mut [E]) {
// Implementation for batched iNTT
}
pub fn intt_batch_on_device(&mut self, inout: &mut DeviceSlice<E>) {
// Implementation for batched iNTT on device
}
pub fn intt_coset(&mut self, inout: &mut [E], coset: &mut [E]) {
// Implementation for iNTT with coset
}
pub fn intt_coset_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
// Implementation for iNTT with coset on device
}
pub fn intt_coset_batch(&mut self, inout: &mut [E], coset: &mut [E]) {
// Implementation for batched iNTT with coset
}
pub fn intt_coset_batch_on_device(&mut self, inout: &mut DeviceSlice<E>, coset: &mut DeviceSlice<E>) {
// Implementation for batched iNTT with coset on device
}
// Ordering setter
pub fn set_ordering(&mut self, ordering: Ordering) {
self.config
.ordering = ordering;
}
}
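Only the host-side `new` plus `ntt` path is wired up at this point; the intended call pattern looks roughly like this (a sketch; the domain caches twiddles on the device after the first call because `is_preserving_twiddles` is set in the config helper):

use icicle_bn254::curve::ScalarField;
use icicle_bn254::ntt::domain::Domain;
use icicle_cuda_runtime::device_context::get_default_device_context;

fn forward_ntt_in_place(values: &mut [ScalarField]) {
    // The domain size must match the input length; ntt() panics otherwise.
    let mut domain: Domain<'_, ScalarField, ScalarField> =
        Domain::new(values.len(), get_default_device_context());
    domain.ntt(values);
}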

View File

@@ -0,0 +1,322 @@
mod config;
pub mod domain;
use std::any::TypeId;
use crate::curve::*;
use self::config::*;
use icicle_core::ntt::{Butterfly, Decimation, Ordering};
use icicle_cuda_runtime::error::CudaError;
extern "C" {
#[link_name = "NTTDefaultContextCuda"]
fn ntt_cuda(config: *mut NTTConfig) -> CudaError;
}
pub(crate) fn ntt_wip(
inout: &mut [ScalarField],
is_inverse: bool,
is_input_on_device: bool,
ordering: Ordering,
is_output_on_device: bool,
batch_size: usize,
) {
let mut batch_size = batch_size;
if batch_size == 0 {
batch_size = 1;
}
let size = inout.len() / batch_size;
let mut config = get_ntt_default_config::<ScalarField, ScalarField>(size);
config.inout = inout as *mut _ as *mut ScalarField;
config.is_inverse = is_inverse;
config.is_input_on_device = is_input_on_device;
config.is_output_on_device = is_output_on_device;
config.ordering = ordering;
config.batch_size = batch_size as i32;
ntt_internal(&mut config);
}
pub(self) fn ntt_internal<TConfig>(config: *mut TConfig) -> CudaError {
let result_code = unsafe { ntt_cuda(config as _) };
// let typeid = TypeId::of::<TConfig>();
// if typeid == TypeId::of::<NTTConfig>() {
// result_code = unsafe { ntt_cuda(config as _) };
// } else {
// result_code = CudaError::cudaSuccess; //TODO: unsafe { ecntt_cuda(config as _) };
// }
// if result_code != CudaError::cudaSuccess {
// println!("_result_code = {:?}", result_code);
// }
result_code
}
pub(self) fn ecntt_internal(config: *mut ECNTTConfig) -> u32 {
let result_code = 0; //TODO: unsafe { ecntt_cuda(config) };
if result_code != 0 {
println!("_result_code = {}", result_code);
}
return result_code;
}
#[cfg(test)]
pub(crate) mod tests {
use ark_bn254::{Fr, G1Affine as arkG1Affine, G1Projective as arkG1Projective};
// use ark_bls12_381::{Fr, G1Projective};
use ark_ff::PrimeField;
use ark_poly::EvaluationDomain;
use ark_poly::GeneralEvaluationDomain;
use ark_std::UniformRand;
use std::slice;
use crate::ntt::domain::NTTDomain;
use crate::{curve::*, ntt::*};
use icicle_core::traits::ArkConvertible;
pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
fn is_power_of_two(n: u32) -> bool {
n != 0 && n & (n - 1) == 0
}
assert!(is_power_of_two(order));
let mask = order - 1;
let binary = format!("{:0width$b}", n, width = (32 - mask.leading_zeros()) as usize);
let reversed = binary
.chars()
.rev()
.collect::<String>();
u32::from_str_radix(&reversed, 2).unwrap()
}
pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
l.iter()
.enumerate()
.map(|(i, _)| l[reverse_bit_order(i as u32, l.len() as u32) as usize])
.collect()
}
#[test]
fn test_ntt() {
//NTT
let test_size = 1 << 11;
let batches = 1;
let full_test_size = test_size * batches;
let scalars_batch: Vec<ScalarField> = generate_random_scalars(full_test_size);
// let scalars_batch: Vec<ScalarField> = (0..full_test_size)
// .into_iter()
// .map(|x| {
// // if x % 1 == 0 {
// if x % 2 == 0 {
// ScalarField::one()
// } else {
// ScalarField::zero()
// }
// })
// .collect();
let mut ntt_result = scalars_batch.clone();
let ark_domain = GeneralEvaluationDomain::<Fr>::new(test_size).unwrap();
let mut domain = NTTDomain::new_for_default_context(test_size);
let ark_scalars_batch = scalars_batch
.clone()
.iter()
.map(|v| v.to_ark())
.collect::<Vec<Fr>>();
let mut ark_ntt_result = ark_scalars_batch.clone();
ark_domain.fft_in_place(&mut ark_ntt_result);
assert_ne!(ark_ntt_result, ark_scalars_batch);
// do ntt
// ntt_wip(&mut ntt_result, false, false, Ordering::kNN, false, batches);
domain.ntt(&mut ntt_result); //single ntt
let ntt_result_as_ark = ntt_result
.iter()
.map(|p| p.to_ark())
.collect::<Vec<Fr>>();
assert_ne!(ntt_result, scalars_batch);
assert_eq!(ark_ntt_result, ntt_result_as_ark);
let mut ark_intt_result = ark_ntt_result;
ark_domain.ifft_in_place(&mut ark_intt_result);
assert_eq!(ark_intt_result, ark_scalars_batch);
// check that ntt output is different from input
assert_ne!(ntt_result, scalars_batch);
// do intt
let mut intt_result = ntt_result;
ntt_wip(&mut intt_result, true, false, Ordering::kNN, false, batches);
assert!(ark_intt_result == ark_scalars_batch);
assert!(intt_result == scalars_batch);
let mut ntt_intt_result = intt_result;
ntt_wip(&mut ntt_intt_result, false, false, Ordering::kNR, false, batches);
assert!(ntt_intt_result != scalars_batch);
ntt_wip(&mut ntt_intt_result, true, false, Ordering::kRN, false, batches);
assert!(ntt_intt_result == scalars_batch);
let mut ntt_intt_result = list_to_reverse_bit_order(&ntt_intt_result);
ntt_wip(&mut ntt_intt_result, false, false, Ordering::kRR, false, batches);
assert!(ntt_intt_result != scalars_batch);
ntt_wip(&mut ntt_intt_result, true, false, Ordering::kRN, false, batches);
assert!(ntt_intt_result == scalars_batch);
////
let size = ntt_intt_result.len() / batches;
let mut config = get_ntt_config_with_input(&mut ntt_intt_result, size, batches);
ntt_internal(&mut config);
//host
let mut ntt_result = scalars_batch.clone();
ntt_wip(&mut ntt_result, false, false, Ordering::kNR, false, batches);
// let mut buff1 = DeviceBuffer::from_slice(&scalars_batch[..]).unwrap();
// let dev_ptr1 = buff1
// .as_device_ptr()
// .as_raw_mut();
// let buff_len = buff1.len();
// std::mem::forget(buff1);
// let buff_from_dev_ptr = unsafe { DeviceBuffer::from_raw_parts(DevicePointer::wrap(dev_ptr1), buff_len) };
// let mut from_device = vec![ScalarField::zero(); scalars_batch.len()];
// buff_from_dev_ptr
// .copy_to(&mut from_device)
// .unwrap();
// assert_eq!(from_device, scalars_batch);
// host - device - device - host
let mut ntt_intt_result = scalars_batch.clone();
let mut config = get_ntt_config_with_input(&mut ntt_intt_result, size, batches);
config.is_input_on_device = false;
config.is_output_on_device = true;
// config.is_preserving_twiddles = true; // TODO: same as in get_ntt_config
config.ordering = Ordering::kNR;
ntt_internal(&mut config); //twiddles are preserved after first call
// config.is_preserving_twiddles = true; //TODO: same as in get_ntt_config
config.is_inverse = true;
config.is_input_on_device = false;
config.is_output_on_device = true;
config.ordering = Ordering::kNR;
ntt_internal(&mut config); //inv_twiddles are preserved after first call
let ntt_intt_result = &mut scalars_batch.clone()[..];
let raw_scalars_batch_copy = ntt_intt_result as *mut _ as *mut ScalarField;
let config_inout2: &mut [ScalarField] =
unsafe { std::slice::from_raw_parts_mut(raw_scalars_batch_copy, config.size as usize) };
assert_eq!(config_inout2, scalars_batch);
config.is_preserving_twiddles = true; //TODO: same as in get_ntt_config
config.inout = raw_scalars_batch_copy;
config.is_inverse = false;
config.is_input_on_device = false;
config.is_output_on_device = true;
config.ordering = Ordering::kNR;
ntt_internal(&mut config);
config.is_inverse = true;
config.is_input_on_device = true;
config.is_output_on_device = false;
config.ordering = Ordering::kRN;
ntt_internal(&mut config);
let result_from_device: &mut [ScalarField] =
unsafe { std::slice::from_raw_parts_mut(config.inout, scalars_batch.len()) };
assert_eq!(result_from_device, &scalars_batch);
}
#[test]
fn test_batch_ntt() {
//NTT
let test_size = 1 << 11;
let batches = 2;
let full_test_size = test_size * batches;
let scalars_batch: Vec<ScalarField> = generate_random_scalars(full_test_size);
let mut scalar_vec_of_vec: Vec<Vec<ScalarField>> = Vec::new();
for i in 0..batches {
scalar_vec_of_vec.push(scalars_batch[i * test_size..(i + 1) * test_size].to_vec());
}
let mut ntt_result = scalars_batch.clone();
// do batch ntt
ntt_wip(&mut ntt_result, false, false, Ordering::kNN, false, batches);
let mut ntt_result_vec_of_vec = Vec::new();
// do ntt for every chunk
for i in 0..batches {
ntt_result_vec_of_vec.push(scalar_vec_of_vec[i].clone());
ntt_wip(&mut ntt_result_vec_of_vec[i], false, false, Ordering::kNN, false, 1);
}
// check that the ntt of each vec of scalars is equal to the ntt of the specific batch
for i in 0..batches {
assert_eq!(ntt_result_vec_of_vec[i], ntt_result[i * test_size..(i + 1) * test_size]);
}
// check that ntt output is different from input
assert_ne!(ntt_result, scalars_batch);
let mut intt_result = ntt_result.clone();
// do batch intt
// intt_batch(&mut intt_result, test_size, 0);
ntt_wip(&mut intt_result, true, false, Ordering::kNN, false, batches);
let mut intt_result_vec_of_vec = Vec::new();
// do intt for every chunk
for i in 0..batches {
intt_result_vec_of_vec.push(ntt_result_vec_of_vec[i].clone());
// intt(&mut intt_result_vec_of_vec[i], 0);
ntt_wip(&mut intt_result_vec_of_vec[i], true, false, Ordering::kNN, false, 1);
}
// check that the intt of each vec of scalars is equal to the intt of the specific batch
for i in 0..batches {
assert_eq!(
intt_result_vec_of_vec[i],
intt_result[i * test_size..(i + 1) * test_size]
);
}
assert_eq!(intt_result, scalars_batch);
}
}