mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 20:37:55 -05:00
Impl sumcheck prover
This commit is contained in:
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -1,3 +1,3 @@
|
||||
[submodule "sumcheck/src/cuda/includes/barretenberg"]
|
||||
path = sumcheck/src/cuda/includes/barretenberg
|
||||
[submodule "sumcheck/src/gpu/cuda/includes/barretenberg"]
|
||||
path = sumcheck/src/gpu/cuda/includes/barretenberg
|
||||
url = https://github.com/pseXperiments/barretenberg_cuda.git
|
||||
|
||||
@@ -8,9 +8,8 @@ use regex::Regex;
|
||||
|
||||
fn main() {
|
||||
// Tell cargo to invalidate the built crate whenever files of interest changes.
|
||||
println!("cargo:rerun-if-changed=src/cuda/kernels/multilinear.cu");
|
||||
println!("cargo:rerun-if-changed=src/cuda/kernels/sumcheck.cu");
|
||||
println!("cargo:rerun-if-changed=src/cuda/kernels/scalar_multiplication.cu");
|
||||
println!("cargo:rerun-if-changed=src/gpu/cuda/kernels/multilinear.cu");
|
||||
println!("cargo:rerun-if-changed=src/gpu/cuda/kernels/sumcheck.cu");
|
||||
|
||||
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
|
||||
@@ -22,11 +21,11 @@ fn main() {
|
||||
|
||||
// build the cuda kernels
|
||||
let cuda_src = [
|
||||
"src/cuda/kernels/multilinear.cu",
|
||||
"src/cuda/kernels/scalar_multiplication.cu",
|
||||
"src/gpu/cuda/kernels/multilinear.cu",
|
||||
"src/gpu/cuda/kernels/sumcheck.cu",
|
||||
]
|
||||
.map(|path| PathBuf::from(path));
|
||||
let ptx_file = ["multilinear.ptx", "scalar_multiplication.ptx"].map(|file| out_dir.join(file));
|
||||
let ptx_file = ["multilinear.ptx", "sumcheck.ptx"].map(|file| out_dir.join(file));
|
||||
|
||||
for (cuda_src, ptx_file) in cuda_src.into_iter().zip(ptx_file) {
|
||||
let nvcc_status = Command::new("nvcc")
|
||||
@@ -55,7 +54,7 @@ fn main() {
|
||||
let bindings = bindgen::Builder::default()
|
||||
// The input header we would like to generate
|
||||
// bindings for.
|
||||
.header("src/cuda/includes/wrapper.h")
|
||||
.header("src/gpu/cuda/includes/wrapper.h")
|
||||
// Tell cargo to invalidate the built crate whenever any of the
|
||||
// included header files changed.
|
||||
.parse_callbacks(Box::new(CargoCallbacks))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use ff::{BatchInvert, Field};
|
||||
use itertools::Itertools;
|
||||
use num_integer::Integer;
|
||||
|
||||
pub fn usize_from_bits_le(bits: &[bool]) -> usize {
|
||||
@@ -9,3 +11,45 @@ pub fn usize_from_bits_le(bits: &[bool]) -> usize {
|
||||
pub fn div_ceil(dividend: usize, divisor: usize) -> usize {
|
||||
Integer::div_ceil(÷nd, &divisor)
|
||||
}
|
||||
|
||||
pub fn barycentric_weights<F: Field>(points: &[F]) -> Vec<F> {
|
||||
let mut weights = points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, point_j)| {
|
||||
points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| i != &j)
|
||||
.map(|(_, point_i)| *point_j - point_i)
|
||||
.reduce(|acc, value| acc * &value)
|
||||
.unwrap_or(F::ONE)
|
||||
})
|
||||
.collect_vec();
|
||||
weights.batch_invert();
|
||||
weights
|
||||
}
|
||||
|
||||
pub fn inner_product<'a, 'b, F: Field>(
|
||||
lhs: impl IntoIterator<Item = &'a F>,
|
||||
rhs: impl IntoIterator<Item = &'b F>,
|
||||
) -> F {
|
||||
lhs.into_iter()
|
||||
.zip_eq(rhs)
|
||||
.map(|(lhs, rhs)| *lhs * rhs)
|
||||
.reduce(|acc, product| acc + product)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn barycentric_interpolate<F: Field>(weights: &[F], points: &[F], evals: &[F], x: &F) -> F {
|
||||
let (coeffs, sum_inv) = {
|
||||
let mut coeffs = points.iter().map(|point| *x - point).collect_vec();
|
||||
coeffs.batch_invert();
|
||||
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
|
||||
*coeff *= weight;
|
||||
});
|
||||
let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff);
|
||||
(coeffs, sum_inv.invert().unwrap())
|
||||
};
|
||||
inner_product(&coeffs, evals) * &sum_inv
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// https://github.com/han0110/plonkish/blob/main/plonkish_backend/src/poly/multilinear.rs
|
||||
mod arithmetic;
|
||||
pub mod multilinear;
|
||||
mod parallel;
|
||||
pub mod parallel;
|
||||
pub mod sumcheck;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::arithmetic::div_ceil;
|
||||
|
||||
fn num_threads() -> usize {
|
||||
return rayon::current_num_threads();
|
||||
rayon::current_num_threads()
|
||||
}
|
||||
|
||||
fn parallelize_iter<I, T, F>(iter: I, f: F)
|
||||
|
||||
76
sumcheck/src/cpu/sumcheck.rs
Normal file
76
sumcheck/src/cpu/sumcheck.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::cpu::{arithmetic::barycentric_weights, parallel::parallelize};
|
||||
|
||||
use super::arithmetic::barycentric_interpolate;
|
||||
|
||||
pub(crate) fn eval_at_k_and_combine<F: PrimeField>(
|
||||
num_vars: usize,
|
||||
polys: &[&[F]],
|
||||
combine_function: &impl Fn(&Vec<F>) -> F,
|
||||
k: F,
|
||||
) -> F {
|
||||
let evals = (0..1 << (num_vars - 1))
|
||||
.map(|idx| {
|
||||
let args = polys
|
||||
.iter()
|
||||
.map(|poly| k * (poly[idx + (1 << (num_vars - 1))] - poly[idx]) + poly[idx])
|
||||
.collect_vec();
|
||||
combine_function(&args)
|
||||
})
|
||||
.collect_vec();
|
||||
evals.into_iter().sum()
|
||||
}
|
||||
|
||||
pub(crate) fn fold_into_half_in_place<F: PrimeField>(poly: &mut [F], challenge: F) {
|
||||
let (poly0, poly1) = poly.split_at_mut(poly.len() >> 1);
|
||||
let poly1 = &*poly1;
|
||||
parallelize(poly0, |(poly0, start)| {
|
||||
poly0
|
||||
.iter_mut()
|
||||
.zip(poly1.iter().skip(start))
|
||||
.for_each(|(eval0, eval1)| {
|
||||
*eval0 = challenge * (*eval1 - *eval0) + *eval0;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn verify_sumcheck<F: PrimeField>(
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
sum: F,
|
||||
challenges: &[F],
|
||||
evals: &[&[F]],
|
||||
) -> bool {
|
||||
let points_vec: Vec<F> = (0..max_degree + 1)
|
||||
.map(|i| F::from_u128(i as u128))
|
||||
.collect();
|
||||
let weights = barycentric_weights(&points_vec);
|
||||
let mut expected_sum = sum;
|
||||
for round_index in 0..num_vars {
|
||||
if evals[round_index].len() != max_degree + 1 {
|
||||
return false;
|
||||
}
|
||||
let round_poly_eval_at_0 = evals[round_index][0];
|
||||
let round_poly_eval_at_1 = evals[round_index][1];
|
||||
let computed_sum = round_poly_eval_at_0 + round_poly_eval_at_1;
|
||||
|
||||
// Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1)
|
||||
if computed_sum != expected_sum {
|
||||
println!("computed_sum : {:?}", computed_sum);
|
||||
println!("expected_sum : {:?}", expected_sum);
|
||||
println!("round index : {}", round_index);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute r_{i}(α_i) using barycentric interpolation
|
||||
expected_sum = barycentric_interpolate(
|
||||
&weights,
|
||||
&points_vec,
|
||||
evals[round_index],
|
||||
&challenges[round_index],
|
||||
);
|
||||
}
|
||||
true
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
#include "prime_field.h"
|
||||
@@ -1,47 +0,0 @@
|
||||
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
|
||||
#include <stdio.h>
|
||||
|
||||
using namespace bb;
|
||||
|
||||
__device__ fr merge(fr* evals, fr* point, uint8_t point_index, u_int32_t chunk_size) {
|
||||
const int start = chunk_size * (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
auto step = 2;
|
||||
while (chunk_size > 1) {
|
||||
for (int i = 0; i < chunk_size / 2; i++) {
|
||||
auto fst = start + step * i;
|
||||
auto snd = fst + step / 2;
|
||||
evals[fst] = point[point_index] * (evals[snd] - evals[fst]) + evals[fst];
|
||||
}
|
||||
chunk_size >>= 1;
|
||||
step <<= 1;
|
||||
point_index++;
|
||||
}
|
||||
return evals[start];
|
||||
}
|
||||
|
||||
extern "C" __global__ void eval(fr* evals, fr* buf, fr* point, u_int32_t size, u_int32_t chunk_size, uint8_t offset) {
|
||||
const int tid = threadIdx.x;
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint8_t log2_chunk_size = log2(chunk_size);
|
||||
u_int32_t num_threads = ceil(size / chunk_size);
|
||||
auto i = offset;
|
||||
while (num_threads > 0) {
|
||||
if (tid < num_threads) {
|
||||
buf[idx] = merge(evals, point, i, chunk_size);
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
memcpy(&evals[chunk_size * blockIdx.x * blockDim.x], &buf[blockIdx.x * blockDim.x], num_threads * 32);
|
||||
}
|
||||
i += log2_chunk_size;
|
||||
num_threads >>= log2_chunk_size;
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void convert_to_montgomery(fr* evals, u_int32_t size, u_int32_t chunk_size) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = 0; i < chunk_size; i++) {
|
||||
evals[chunk_size * idx + i].self_to_montgomery_form();
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
|
||||
|
||||
using namespace bb;
|
||||
|
||||
extern "C" __global__ void mul(fr* elems, fr* results) {
|
||||
fr temp = elems[0] * elems[1];
|
||||
results[threadIdx.x] = temp.from_montgomery_form();
|
||||
return;
|
||||
}
|
||||
@@ -65,4 +65,3 @@ macro_rules! field_binding_conversion {
|
||||
}
|
||||
|
||||
field_binding_conversion!(Fr);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef __PRIME_FIELD_H__
|
||||
#define __PRIME_FIELD_H__
|
||||
#ifndef __FIELD_BINDING_H__
|
||||
#define __FIELD_BINDING_H__
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
1
sumcheck/src/gpu/cuda/includes/wrapper.h
Normal file
1
sumcheck/src/gpu/cuda/includes/wrapper.h
Normal file
@@ -0,0 +1 @@
|
||||
#include "field_binding.h"
|
||||
37
sumcheck/src/gpu/cuda/kernels/multilinear.cu
Normal file
37
sumcheck/src/gpu/cuda/kernels/multilinear.cu
Normal file
@@ -0,0 +1,37 @@
|
||||
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
|
||||
#include <stdio.h>
|
||||
using namespace bb;
|
||||
|
||||
__device__ fr merge(fr* evals, fr x, const int start) {
|
||||
return x * (evals[start + 1] - evals[start]) + evals[start];
|
||||
}
|
||||
|
||||
extern "C" __global__ void eval(fr* evals, fr* point, u_int32_t size, uint8_t offset, int num_blocks) {
|
||||
extern __shared__ fr evals_shared[];
|
||||
volatile int num_block_processed = 0;
|
||||
const int tid = threadIdx.x;
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
evals_shared[tid] = evals[idx];
|
||||
__syncthreads();
|
||||
auto num_threads = size >> 1;
|
||||
auto i = offset;
|
||||
while (num_threads > 0) {
|
||||
if (tid < num_threads) {
|
||||
evals_shared[tid] = point[i] * (evals_shared[2 * tid + 1] - evals_shared[2 * tid]) + evals_shared[2 * tid];
|
||||
}
|
||||
__syncthreads();
|
||||
i++;
|
||||
num_threads >>= 1;
|
||||
__syncthreads();
|
||||
}
|
||||
if (tid == 0) {
|
||||
num_block_processed++;
|
||||
while (num_block_processed != num_blocks);
|
||||
evals[idx >> 10] = evals_shared[tid];
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void convert_to_montgomery_form(fr* evals, u_int32_t size) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
evals[idx].self_to_montgomery_form();
|
||||
}
|
||||
70
sumcheck/src/gpu/cuda/kernels/sumcheck.cu
Normal file
70
sumcheck/src/gpu/cuda/kernels/sumcheck.cu
Normal file
@@ -0,0 +1,70 @@
|
||||
#include "../includes/barretenberg/ecc/curves/bn254/fr.cuh"
|
||||
#include "./multilinear.cu"
|
||||
|
||||
using namespace bb;
|
||||
|
||||
__device__ void sum(fr* data, const int stride) {
|
||||
const int tid = threadIdx.x;
|
||||
for (unsigned int s = stride; s > 0; s >>= 1) {
|
||||
int idx = tid;
|
||||
while (idx < s) {
|
||||
data[idx] += data[idx + s];
|
||||
idx += blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
__device__ fr combine_function(fr* evals, unsigned int start, unsigned int stride, unsigned int num_args) {
|
||||
fr result = fr::zero();
|
||||
for (int i = 0; i < num_args; i++) result += evals[start + i * stride];
|
||||
return result;
|
||||
}
|
||||
|
||||
extern "C" __global__ void combine_and_sum(
|
||||
fr* buf, fr* result, unsigned int size, unsigned int num_args, unsigned int index
|
||||
) {
|
||||
const int tid = threadIdx.x;
|
||||
int idx = tid;
|
||||
while (idx < size) {
|
||||
buf[idx] = combine_function(buf, idx, size, num_args);
|
||||
idx += blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
sum(buf, size >> 1);
|
||||
if (tid == 0) result[index] = buf[0];
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half(
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, fr* challenge
|
||||
) {
|
||||
int tid = threadIdx.x;
|
||||
const int stride = 1 << (num_vars - 1);
|
||||
const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
|
||||
const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
|
||||
while (tid < stride && buf_offset < (blockIdx.x / num_blocks_per_poly + 1) * stride) {
|
||||
buf[buf_offset + tid] = (*challenge) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
|
||||
tid += blockDim.x * num_blocks_per_poly;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half_in_place(
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* challenge
|
||||
) {
|
||||
int tid = threadIdx.x;
|
||||
const int stride = 1 << (num_vars - 1);
|
||||
const int offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size + (blockIdx.x % num_blocks_per_poly) * blockDim.x;
|
||||
while (tid < stride) {
|
||||
int idx = offset + tid;
|
||||
polys[idx] = (*challenge) * (polys[idx + stride] - polys[idx]) + polys[idx];
|
||||
tid += blockDim.x * num_blocks_per_poly;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO : Pass transcript and squeeze random challenge using hash function
|
||||
extern "C" __global__ void squeeze_challenge(fr* challenges, unsigned int index) {
|
||||
if (threadIdx.x == 0) {
|
||||
challenges[index] = fr(1034);
|
||||
}
|
||||
}
|
||||
2
sumcheck/src/gpu/mod.rs
Normal file
2
sumcheck/src/gpu/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod multilinear;
|
||||
pub mod sumcheck;
|
||||
109
sumcheck/src/gpu/multilinear.rs
Normal file
109
sumcheck/src/gpu/multilinear.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use cudarc::{
|
||||
driver::{DriverError, LaunchAsync, LaunchConfig},
|
||||
nvrtc::Ptx,
|
||||
};
|
||||
use ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||
GPUApiWrapper, MULTILINEAR_PTX,
|
||||
};
|
||||
|
||||
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
pub fn eval(&mut self, num_vars: usize, evals: &[F], point: &[F]) -> Result<F, DriverError> {
|
||||
self.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form", "eval"],
|
||||
)?;
|
||||
let now = Instant::now();
|
||||
// copy to GPU
|
||||
let gpu_eval_point = self.copy_to_device(point)?;
|
||||
let evals = self.copy_to_device(evals)?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let mut num_vars = num_vars;
|
||||
let mut results = vec![];
|
||||
let mut offset = 0;
|
||||
while num_vars > 0 {
|
||||
let log2_data_size_per_block = 10;
|
||||
let (data_size_per_block, num_blocks) = if num_vars < log2_data_size_per_block {
|
||||
(1 << num_vars, 1)
|
||||
} else {
|
||||
(
|
||||
1 << log2_data_size_per_block,
|
||||
1 << (num_vars - log2_data_size_per_block),
|
||||
)
|
||||
};
|
||||
let now = Instant::now();
|
||||
let eval = self.gpu.get_func("multilinear", "eval").unwrap();
|
||||
// (number of field elements processed per thread block) * 32
|
||||
let shared_mem_bytes = data_size_per_block << 5;
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes,
|
||||
};
|
||||
unsafe {
|
||||
eval.launch(
|
||||
launch_config,
|
||||
(
|
||||
&evals,
|
||||
&gpu_eval_point,
|
||||
data_size_per_block,
|
||||
offset,
|
||||
num_blocks,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
if num_blocks == 1 {
|
||||
let now = Instant::now();
|
||||
results = self.gpu.sync_reclaim(evals)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
break;
|
||||
} else {
|
||||
num_vars -= log2_data_size_per_block;
|
||||
offset += log2_data_size_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(F::from_montgomery_form(results[0]))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use cudarc::driver::DriverError;
|
||||
use ff::Field;
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{cpu, GPUApiWrapper};
|
||||
|
||||
#[test]
|
||||
fn test_eval() -> Result<(), DriverError> {
|
||||
let num_vars = 20;
|
||||
let rng = OsRng::default();
|
||||
let evals = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
|
||||
let now = Instant::now();
|
||||
let eval_poly_result_by_cpu = cpu::multilinear::evaluate(&evals, &point);
|
||||
println!("Time taken to evaluate on cpu: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let eval_poly_result_by_gpu = gpu_api_wrapper.eval(num_vars, &evals, &point)?;
|
||||
println!("Time taken to evaluate on gpu: {:.2?}", now.elapsed());
|
||||
|
||||
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
389
sumcheck/src/gpu/sumcheck.rs
Normal file
389
sumcheck/src/gpu/sumcheck.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use std::cell::{RefCell, RefMut};
|
||||
|
||||
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||
FieldBinding, GPUApiWrapper,
|
||||
};
|
||||
|
||||
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
pub fn prove_sumcheck(
|
||||
&self,
|
||||
num_vars: usize,
|
||||
num_polys: usize,
|
||||
max_degree: usize,
|
||||
sum: F,
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
buf: RefCell<CudaViewMut<FieldBinding>>,
|
||||
challenges: &mut CudaViewMut<FieldBinding>,
|
||||
round_evals: RefCell<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let initial_poly_num_vars = num_vars;
|
||||
for round in 0..num_vars {
|
||||
self.eval_at_k_and_combine(
|
||||
initial_poly_num_vars,
|
||||
round,
|
||||
max_degree,
|
||||
num_polys,
|
||||
&polys.slice(..),
|
||||
buf.borrow_mut(),
|
||||
round_evals.borrow_mut(),
|
||||
)?;
|
||||
// squeeze challenge
|
||||
self.squeeze_challenge(round, challenges)?;
|
||||
// fold_into_half_in_place
|
||||
self.fold_into_half_in_place(
|
||||
initial_poly_num_vars,
|
||||
round,
|
||||
num_polys,
|
||||
polys,
|
||||
&challenges.slice(round..round + 1),
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn eval_at_k_and_combine(
|
||||
&self,
|
||||
initial_poly_num_vars: usize,
|
||||
round: usize,
|
||||
max_degree: usize,
|
||||
num_polys: usize,
|
||||
polys: &CudaView<FieldBinding>,
|
||||
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
||||
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()?;
|
||||
for k in 0..max_degree + 1 {
|
||||
let device_k = self
|
||||
.gpu
|
||||
.htod_copy(vec![F::to_montgomery_form(F::from(k as u64))])?;
|
||||
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
fold_into_half.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
&mut *buf,
|
||||
&device_k,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
let combine_and_sum = self.gpu.get_func("sumcheck", "combine_and_sum").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
combine_and_sum.launch(
|
||||
launch_config,
|
||||
(
|
||||
&mut *buf,
|
||||
&mut *round_evals,
|
||||
1 << (initial_poly_num_vars - round - 1),
|
||||
num_polys,
|
||||
round * (max_degree + 1) + k,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn squeeze_challenge(
|
||||
&self,
|
||||
round: usize,
|
||||
challenges: &mut CudaViewMut<FieldBinding>,
|
||||
) -> Result<(), DriverError> {
|
||||
let squeeze_challenge = self.gpu.get_func("sumcheck", "squeeze_challenge").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (1, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
squeeze_challenge.launch(launch_config, (challenges, round))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn fold_into_half_in_place(
|
||||
&self,
|
||||
initial_poly_num_vars: usize,
|
||||
round: usize,
|
||||
num_polys: usize,
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
challenge: &CudaView<FieldBinding>,
|
||||
) -> Result<(), DriverError> {
|
||||
let fold_into_half_in_place = self
|
||||
.gpu
|
||||
.get_func("sumcheck", "fold_into_half_in_place")
|
||||
.unwrap();
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()?;
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
fold_into_half_in_place.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
challenge,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{cell::RefCell, time::Instant};
|
||||
|
||||
use cudarc::{driver::DriverError, nvrtc::Ptx};
|
||||
use ff::Field;
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{cpu, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||
let num_vars = 20;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
let rng = OsRng::default();
|
||||
|
||||
let combine_function = |args: &Vec<Fr>| args.iter().product();
|
||||
|
||||
let polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half", "combine_and_sum"],
|
||||
)?;
|
||||
|
||||
let mut cpu_round_evals = vec![];
|
||||
let now = Instant::now();
|
||||
let polys = polys.iter().map(|poly| poly.as_slice()).collect_vec();
|
||||
for k in 0..max_degree + 1 {
|
||||
cpu_round_evals.push(cpu::sumcheck::eval_at_k_and_combine(
|
||||
num_vars,
|
||||
polys.as_slice(),
|
||||
&combine_function,
|
||||
Fr::from(k),
|
||||
));
|
||||
}
|
||||
println!(
|
||||
"Time taken to eval_at_k_and_combine on cpu: {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
|
||||
// copy polynomials to device
|
||||
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let mut buf = gpu_api_wrapper.copy_to_device(&vec![Fr::ZERO; num_polys << num_vars])?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
let mut round_evals =
|
||||
gpu_api_wrapper.copy_to_device(&vec![Fr::ZERO; max_degree as usize + 1])?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
let round = 0;
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper.eval_at_k_and_combine(
|
||||
num_vars,
|
||||
round,
|
||||
max_degree as usize,
|
||||
num_polys,
|
||||
&gpu_polys.slice(..),
|
||||
buf_view.borrow_mut(),
|
||||
round_evals_view.borrow_mut(),
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
let gpu_round_evals = gpu_api_wrapper.dtoh_sync_copy(round_evals.slice(..), true)?;
|
||||
cpu_round_evals
|
||||
.iter()
|
||||
.zip_eq(gpu_round_evals.iter())
|
||||
.for_each(|(a, b)| {
|
||||
assert_eq!(a, b);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_into_half_in_place() -> Result<(), DriverError> {
|
||||
let num_vars = 20;
|
||||
let num_polys = 4;
|
||||
|
||||
let rng = OsRng::default();
|
||||
let mut polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half_in_place"],
|
||||
)?;
|
||||
// copy polynomials to device
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let challenge = Fr::random(rng);
|
||||
let gpu_challenge = gpu_api_wrapper.copy_to_device(&vec![challenge])?;
|
||||
let round = 0;
|
||||
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper.fold_into_half_in_place(
|
||||
num_vars,
|
||||
round,
|
||||
num_polys,
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
&gpu_challenge.slice(..),
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to fold_into_half_in_place on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
|
||||
let gpu_result = (0..num_polys)
|
||||
.map(|i| {
|
||||
gpu_api_wrapper.dtoh_sync_copy(
|
||||
gpu_polys.slice(i << num_vars..(i * 2 + 1) << (num_vars - 1)),
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
|
||||
|
||||
let now = Instant::now();
|
||||
(0..num_polys)
|
||||
.for_each(|i| cpu::sumcheck::fold_into_half_in_place(&mut polys[i], challenge));
|
||||
println!("Time taken to fold_into_half on cpu: {:.2?}", now.elapsed());
|
||||
polys
|
||||
.iter_mut()
|
||||
.for_each(|poly| poly.truncate(1 << (num_vars - 1)));
|
||||
|
||||
gpu_result
|
||||
.into_iter()
|
||||
.zip_eq(polys)
|
||||
.for_each(|(gpu_result, cpu_result)| {
|
||||
assert_eq!(gpu_result, cpu_result);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||
let num_vars = 19;
|
||||
let num_polys = 9;
|
||||
let max_degree = 1;
|
||||
|
||||
let rng = OsRng::default();
|
||||
let polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&[
|
||||
"fold_into_half",
|
||||
"fold_into_half_in_place",
|
||||
"combine_and_sum",
|
||||
"squeeze_challenge",
|
||||
],
|
||||
)?;
|
||||
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
||||
acc + polys.iter().map(|poly| poly[index]).sum::<Fr>()
|
||||
});
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
|
||||
let mut challenges = gpu_api_wrapper.malloc_on_device(num_vars)?;
|
||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(num_vars * (max_degree + 1))?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper.prove_sumcheck(
|
||||
num_vars,
|
||||
num_polys,
|
||||
max_degree,
|
||||
sum,
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
buf_view,
|
||||
&mut challenges.slice_mut(..),
|
||||
round_evals_view,
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to prove sumcheck on gpu : {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
|
||||
let challenges = gpu_api_wrapper.dtoh_sync_copy(challenges.slice(..), true)?;
|
||||
let round_evals = (0..num_vars)
|
||||
.map(|i| {
|
||||
gpu_api_wrapper.dtoh_sync_copy(
|
||||
round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
|
||||
let round_evals = round_evals
|
||||
.iter()
|
||||
.map(|round_evals| round_evals.as_slice())
|
||||
.collect_vec();
|
||||
let result = cpu::sumcheck::verify_sumcheck(
|
||||
num_vars,
|
||||
max_degree,
|
||||
sum,
|
||||
&challenges[..],
|
||||
&round_evals[..],
|
||||
);
|
||||
assert!(result);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,20 @@
|
||||
// silence warnings due to bindgen
|
||||
#![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]
|
||||
|
||||
use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, DriverError, LaunchAsync, LaunchConfig};
|
||||
use cudarc::nvrtc::Ptx;
|
||||
use cpu::parallel::parallelize;
|
||||
use cudarc::driver::{
|
||||
CudaDevice, CudaSlice, CudaView, DeviceRepr, DriverError, LaunchAsync, LaunchConfig,
|
||||
};
|
||||
use ff::PrimeField;
|
||||
use field::{FromFieldBinding, ToFieldBinding};
|
||||
use itertools::Itertools;
|
||||
use fieldbinding::{FromFieldBinding, ToFieldBinding};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
mod cpu;
|
||||
pub mod field;
|
||||
pub mod fieldbinding;
|
||||
pub mod gpu;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
@@ -24,9 +26,8 @@ impl Default for FieldBinding {
|
||||
}
|
||||
|
||||
// include the compiled PTX code as string
|
||||
const MULTILINEAR_POLY_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
|
||||
const SCALAR_MULTIPLICATION_KERNEL: &str =
|
||||
include_str!(concat!(env!("OUT_DIR"), "/scalar_multiplication.ptx"));
|
||||
const MULTILINEAR_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/multilinear.ptx"));
|
||||
const SUMCHECK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
|
||||
|
||||
/// Wrapper struct for APIs using GPU
|
||||
pub struct GPUApiWrapper<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> {
|
||||
@@ -46,197 +47,69 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_ptx(
|
||||
&self,
|
||||
ptx: &str,
|
||||
module_name: &str,
|
||||
func_names: &[&'static str],
|
||||
) -> Result<(), DriverError> {
|
||||
// compile ptx
|
||||
let now = Instant::now();
|
||||
let ptx = Ptx::from_src(ptx);
|
||||
self.gpu.load_ptx(ptx, module_name, &func_names)?;
|
||||
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn convert_to_montgomery(
|
||||
&self,
|
||||
values: &[F],
|
||||
size: usize,
|
||||
chunk_size: usize,
|
||||
pub fn copy_to_device(
|
||||
&mut self,
|
||||
host_data: &[F],
|
||||
) -> Result<CudaSlice<FieldBinding>, DriverError> {
|
||||
let now = Instant::now();
|
||||
let values = self.gpu.htod_copy(
|
||||
values
|
||||
let device_data = self.gpu.htod_copy(
|
||||
host_data
|
||||
.into_par_iter()
|
||||
.map(|&eval| F::to_canonical_form(eval))
|
||||
.collect(),
|
||||
)?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
let convert_to_montgomery = self
|
||||
let convert_to_montgomery_form = self
|
||||
.gpu
|
||||
.get_func("multilinear", "convert_to_montgomery")
|
||||
.get_func("multilinear", "convert_to_montgomery_form")
|
||||
.unwrap();
|
||||
let size = host_data.len();
|
||||
unsafe {
|
||||
convert_to_montgomery.launch(
|
||||
LaunchConfig::for_num_elems((size / chunk_size) as u32),
|
||||
(&values, size, chunk_size),
|
||||
convert_to_montgomery_form.launch(
|
||||
LaunchConfig::for_num_elems(size as u32),
|
||||
(&device_data, size),
|
||||
)?;
|
||||
};
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
self.gpu.synchronize()?;
|
||||
Ok(values)
|
||||
Ok(device_data)
|
||||
}
|
||||
|
||||
pub fn eval(&self, num_vars: usize, evals: &[F], point: &[F]) -> Result<F, DriverError> {
|
||||
let now = Instant::now();
|
||||
let point = point
|
||||
.into_iter()
|
||||
.map(|f| F::to_montgomery_form(*f))
|
||||
.collect_vec();
|
||||
|
||||
// copy to GPU
|
||||
let gpu_eval_point = self.gpu.htod_copy(point)?;
|
||||
let mut evals = self.convert_to_montgomery(evals, 1 << num_vars, 1 << 5)?;
|
||||
|
||||
let mut num_vars = num_vars;
|
||||
let mut results = vec![];
|
||||
let mut offset = 0;
|
||||
while num_vars > 0 {
|
||||
let log2_chunk_size = 2;
|
||||
let chunk_size = 1 << log2_chunk_size;
|
||||
let (data_size_per_block, block_num) = if num_vars < 10 + log2_chunk_size {
|
||||
(1 << num_vars, 1)
|
||||
} else {
|
||||
(
|
||||
1 << (10 + log2_chunk_size),
|
||||
1 << (num_vars - 10 - log2_chunk_size),
|
||||
)
|
||||
};
|
||||
// each block produces single result and store to `buf`
|
||||
let buf = self.gpu.htod_copy(vec![
|
||||
FieldBinding::default();
|
||||
1 << (num_vars - log2_chunk_size)
|
||||
])?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
let eval = self.gpu.get_func("multilinear", "eval").unwrap();
|
||||
unsafe {
|
||||
eval.launch(
|
||||
LaunchConfig::for_num_elems((block_num << 10) as u32),
|
||||
(
|
||||
&evals,
|
||||
&buf,
|
||||
&gpu_eval_point,
|
||||
data_size_per_block,
|
||||
chunk_size,
|
||||
offset,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
results = self.gpu.sync_reclaim(buf)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
if block_num == 1 {
|
||||
break;
|
||||
} else {
|
||||
evals = self
|
||||
.gpu
|
||||
.htod_copy(results.iter().cloned().step_by(1 << 10).collect_vec())?;
|
||||
num_vars -= 10 + log2_chunk_size;
|
||||
offset += 10 + log2_chunk_size;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(F::from_montgomery_form(results[0]))
|
||||
pub fn malloc_on_device(&self, len: usize) -> Result<CudaSlice<FieldBinding>, DriverError> {
|
||||
let device_ptr = unsafe { self.gpu.alloc(len << 5)? };
|
||||
Ok(device_ptr)
|
||||
}
|
||||
|
||||
pub fn mul(&self, values: &[F; 2]) -> Result<F, DriverError> {
|
||||
let now = Instant::now();
|
||||
let gpu_values = self
|
||||
.gpu
|
||||
.htod_copy(values.map(|v| F::to_montgomery_form(v)).to_vec())?;
|
||||
let results = self.gpu.htod_copy(vec![FieldBinding::default(); 1])?;
|
||||
println!("Time taken to initialise data: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let mul = self.gpu.get_func("scalar_multiplication", "mul").unwrap();
|
||||
unsafe {
|
||||
mul.launch(
|
||||
LaunchConfig::for_num_elems(1 as u32),
|
||||
(&gpu_values, &results),
|
||||
)?;
|
||||
pub fn dtoh_sync_copy(
|
||||
&self,
|
||||
device_data: CudaView<FieldBinding>,
|
||||
convert_to_montgomery_form: bool,
|
||||
) -> Result<Vec<F>, DriverError> {
|
||||
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
|
||||
let mut target = vec![F::ZERO; host_data.len()];
|
||||
if convert_to_montgomery_form {
|
||||
parallelize(&mut target, |(target, start)| {
|
||||
target
|
||||
.iter_mut()
|
||||
.zip(host_data.iter().skip(start))
|
||||
.for_each(|(target, &host_data)| {
|
||||
*target = F::from_montgomery_form(host_data);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
parallelize(&mut target, |(target, start)| {
|
||||
target
|
||||
.iter_mut()
|
||||
.zip(host_data.iter().skip(start))
|
||||
.for_each(|(target, &host_data)| {
|
||||
*target = F::from_canonical_form(host_data);
|
||||
});
|
||||
});
|
||||
}
|
||||
println!("Time taken to call kernel: {:.2?}", now.elapsed());
|
||||
Ok(target)
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let results = self.gpu.sync_reclaim(results)?;
|
||||
println!("Time taken to synchronize: {:.2?}", now.elapsed());
|
||||
Ok(F::from_canonical_form(results[0]))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use cudarc::driver::DriverError;
|
||||
use ff::Field;
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{cpu, MULTILINEAR_POLY_KERNEL, SCALAR_MULTIPLICATION_KERNEL};
|
||||
|
||||
use super::GPUApiWrapper;
|
||||
|
||||
#[test]
|
||||
fn test_eval() -> Result<(), DriverError> {
|
||||
let num_vars = 22;
|
||||
let rng = OsRng::default();
|
||||
let evals = (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
let point = (0..num_vars).map(|_| Fr::random(rng)).collect_vec();
|
||||
|
||||
let gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.load_ptx(
|
||||
MULTILINEAR_POLY_KERNEL,
|
||||
"multilinear",
|
||||
&["convert_to_montgomery", "eval"],
|
||||
)?;
|
||||
|
||||
let now = Instant::now();
|
||||
let eval_poly_result_by_cpu = cpu::multilinear::evaluate(&evals, &point);
|
||||
println!("Time taken to evaluate on cpu: {:.2?}", now.elapsed());
|
||||
|
||||
let now = Instant::now();
|
||||
let eval_poly_result_by_gpu = gpu_api_wrapper.eval(num_vars, &evals, &point)?;
|
||||
println!("Time taken to evaluate on gpu: {:.2?}", now.elapsed());
|
||||
|
||||
assert_eq!(eval_poly_result_by_cpu, eval_poly_result_by_gpu);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_multiplication() -> Result<(), DriverError> {
|
||||
let rng = OsRng::default();
|
||||
let values = [(); 2].map(|_| Fr::random(rng));
|
||||
let expected = values[0] * values[1];
|
||||
|
||||
let gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.load_ptx(
|
||||
SCALAR_MULTIPLICATION_KERNEL,
|
||||
"scalar_multiplication",
|
||||
&["mul"],
|
||||
)?;
|
||||
|
||||
let now = Instant::now();
|
||||
let actual = gpu_api_wrapper.mul(&values)?;
|
||||
println!("Time taken to evaluate on gpu: {:.2?}", now.elapsed());
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
Ok(())
|
||||
pub fn max_blocks_per_sm(&self) -> Result<usize, DriverError> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR)? as usize)
|
||||
}
|
||||
|
||||
pub fn max_threads_per_sm(&self) -> Result<usize, DriverError> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)? as usize)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user