mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 04:17:59 -05:00
Implement in-memory transcript
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ Cargo.lock
|
||||
|
||||
*.o
|
||||
|
||||
.vscode
|
||||
|
||||
@@ -14,6 +14,8 @@ itertools = "0.10.5"
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.3", package = "halo2curves" }
|
||||
rand = "0.8"
|
||||
num-integer = "0.1.45"
|
||||
sha3 = "0.10.6"
|
||||
num-bigint = "0.4.3"
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.66.1"
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::cpu::{arithmetic::barycentric_weights, parallel::parallelize};
|
||||
use crate::{
|
||||
cpu::{arithmetic::barycentric_weights, parallel::parallelize},
|
||||
transcript::Keccak256Transcript,
|
||||
};
|
||||
|
||||
use super::arithmetic::barycentric_interpolate;
|
||||
|
||||
@@ -36,6 +39,40 @@ pub(crate) fn fold_into_half_in_place<F: PrimeField>(poly: &mut [F], challenge:
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn verify_sumcheck_transcript<F: PrimeField>(
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
sum: F,
|
||||
transcript: &mut Keccak256Transcript<F>,
|
||||
) -> bool {
|
||||
let points_vec: Vec<F> = (0..max_degree + 1)
|
||||
.map(|i| F::from_u128(i as u128))
|
||||
.collect();
|
||||
let weights = barycentric_weights(&points_vec);
|
||||
let mut expected_sum = sum;
|
||||
for round_index in 0..num_vars {
|
||||
let evals = transcript.read_field_elements(max_degree + 1).unwrap();
|
||||
let round_poly_eval_at_0 = evals[0];
|
||||
let round_poly_eval_at_1 = evals[1];
|
||||
|
||||
let computed_sum = round_poly_eval_at_0 + round_poly_eval_at_1;
|
||||
|
||||
// Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1)
|
||||
if computed_sum != expected_sum {
|
||||
println!("computed_sum : {:?}", computed_sum);
|
||||
println!("expected_sum : {:?}", expected_sum);
|
||||
println!("round index : {}", round_index);
|
||||
return false;
|
||||
}
|
||||
|
||||
let challenge = transcript.squeeze_challenge();
|
||||
|
||||
// Compute r_{i}(α_i) using barycentric interpolation
|
||||
expected_sum = barycentric_interpolate(&weights, &points_vec, &evals, &challenge);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn verify_sumcheck<F: PrimeField>(
|
||||
num_vars: usize,
|
||||
max_degree: usize,
|
||||
|
||||
@@ -32,16 +32,16 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half(
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
|
||||
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* eval_point
|
||||
) {
|
||||
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
|
||||
const int stride = 1 << (num_vars - 1);
|
||||
const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride;
|
||||
const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size;
|
||||
while (tid < stride) {
|
||||
if (*challenge == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid];
|
||||
else if (*challenge == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride];
|
||||
else buf[buf_offset + tid] = (*challenge) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
|
||||
if (*eval_point == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid];
|
||||
else if (*eval_point == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride];
|
||||
else buf[buf_offset + tid] = (*eval_point) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
|
||||
tid += blockDim.x * num_blocks_per_poly;
|
||||
}
|
||||
}
|
||||
@@ -58,10 +58,3 @@ extern "C" __global__ void fold_into_half_in_place(
|
||||
tid += blockDim.x * num_blocks_per_poly;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO : Pass transcript and squeeze random challenge using hash function
|
||||
extern "C" __global__ void squeeze_challenge(fr* challenges, unsigned int index) {
|
||||
if (threadIdx.x == 0) {
|
||||
challenges[index] = fr(1034);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::cell::{RefCell, RefMut};
|
||||
|
||||
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, LaunchAsync, LaunchConfig};
|
||||
use ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||
FieldBinding, GPUApiWrapper,
|
||||
transcript::Keccak256Transcript,
|
||||
FieldBinding, GPUApiWrapper, LibraryError,
|
||||
};
|
||||
|
||||
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
@@ -18,9 +19,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
device_ks: &[CudaView<FieldBinding>],
|
||||
buf: RefCell<CudaViewMut<FieldBinding>>,
|
||||
challenges: &mut CudaViewMut<FieldBinding>,
|
||||
challenge: &mut CudaSlice<FieldBinding>,
|
||||
round_evals: RefCell<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
transcript: &mut Keccak256Transcript<F>,
|
||||
) -> Result<(), LibraryError> {
|
||||
let initial_poly_num_vars = num_vars;
|
||||
for round in 0..num_vars {
|
||||
self.eval_at_k_and_combine(
|
||||
@@ -32,16 +34,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
device_ks,
|
||||
buf.borrow_mut(),
|
||||
round_evals.borrow_mut(),
|
||||
transcript,
|
||||
)?;
|
||||
// squeeze challenge
|
||||
self.squeeze_challenge(round, challenges)?;
|
||||
let alpha = vec![transcript.squeeze_challenge()];
|
||||
self.overwrite_to_device(alpha.as_slice(), challenge)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
// fold_into_half_in_place
|
||||
self.fold_into_half_in_place(
|
||||
initial_poly_num_vars,
|
||||
round,
|
||||
num_polys,
|
||||
polys,
|
||||
&challenges.slice(round..round + 1),
|
||||
&challenge,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
@@ -57,8 +62,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
device_ks: &[CudaView<FieldBinding>],
|
||||
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
||||
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
||||
transcript: &mut Keccak256Transcript<F>,
|
||||
) -> Result<(), LibraryError> {
|
||||
let num_blocks_per_poly = self
|
||||
.max_blocks_per_sm()
|
||||
.map_err(|e| LibraryError::Driver(e))?
|
||||
/ num_polys
|
||||
* self.num_sm().map_err(|e| LibraryError::Driver(e))?;
|
||||
let num_threads_per_block = 1024;
|
||||
for k in 0..max_degree + 1 {
|
||||
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
|
||||
@@ -68,17 +78,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
fold_into_half.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
&mut *buf,
|
||||
&device_ks[k],
|
||||
),
|
||||
)?;
|
||||
fold_into_half
|
||||
.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
&mut *buf,
|
||||
&device_ks[k],
|
||||
),
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
};
|
||||
let size = 1 << (initial_poly_num_vars - round - 1);
|
||||
let combine = self.gpu.get_func("sumcheck", "combine").unwrap();
|
||||
@@ -88,7 +100,9 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
combine.launch(launch_config, (&mut *buf, size, num_polys))?;
|
||||
combine
|
||||
.launch(launch_config, (&mut *buf, size, num_polys))
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
};
|
||||
let sum = self.gpu.get_func("sumcheck", "sum").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
@@ -105,26 +119,14 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
size >> 1,
|
||||
round * (max_degree + 1) + k,
|
||||
),
|
||||
)?;
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn squeeze_challenge(
|
||||
&self,
|
||||
round: usize,
|
||||
challenges: &mut CudaViewMut<FieldBinding>,
|
||||
) -> Result<(), DriverError> {
|
||||
let squeeze_challenge = self.gpu.get_func("sumcheck", "squeeze_challenge").unwrap();
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: (1, 1, 1),
|
||||
block_dim: (1, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
squeeze_challenge.launch(launch_config, (challenges, round))?;
|
||||
}
|
||||
let fes = self
|
||||
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
transcript.write_field_elements(&fes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -134,13 +136,17 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
round: usize,
|
||||
num_polys: usize,
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
challenge: &CudaView<FieldBinding>,
|
||||
) -> Result<(), DriverError> {
|
||||
challenge: &CudaSlice<FieldBinding>,
|
||||
) -> Result<(), LibraryError> {
|
||||
let fold_into_half_in_place = self
|
||||
.gpu
|
||||
.get_func("sumcheck", "fold_into_half_in_place")
|
||||
.unwrap();
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
||||
let num_blocks_per_poly = self
|
||||
.max_blocks_per_sm()
|
||||
.map_err(|e| LibraryError::Driver(e))?
|
||||
/ num_polys
|
||||
* self.num_sm().map_err(|e| LibraryError::Driver(e))?;
|
||||
let num_threads_per_block = 1024;
|
||||
let launch_config = LaunchConfig {
|
||||
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
|
||||
@@ -148,16 +154,18 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe {
|
||||
fold_into_half_in_place.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
challenge,
|
||||
),
|
||||
)?;
|
||||
fold_into_half_in_place
|
||||
.launch(
|
||||
launch_config,
|
||||
(
|
||||
initial_poly_num_vars - round,
|
||||
1 << initial_poly_num_vars,
|
||||
num_blocks_per_poly,
|
||||
polys,
|
||||
challenge,
|
||||
),
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
@@ -167,16 +175,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
mod tests {
|
||||
use std::{cell::RefCell, time::Instant};
|
||||
|
||||
use cudarc::{driver::DriverError, nvrtc::Ptx};
|
||||
use cudarc::nvrtc::Ptx;
|
||||
use ff::Field;
|
||||
use halo2curves::bn256::Fr;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
|
||||
use crate::{
|
||||
cpu, fieldbinding::ToFieldBinding, transcript::Keccak256Transcript, GPUApiWrapper,
|
||||
LibraryError, MULTILINEAR_PTX, SUMCHECK_PTX,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||
fn test_eval_at_k_and_combine() -> Result<(), LibraryError> {
|
||||
let num_vars = 10;
|
||||
let num_polys = 3;
|
||||
let max_degree = 3;
|
||||
@@ -198,17 +209,24 @@ mod tests {
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half", "combine", "sum"],
|
||||
)?;
|
||||
let mut gpu_api_wrapper =
|
||||
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half", "combine", "sum"],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
|
||||
let mut cpu_round_evals = vec![];
|
||||
let now = Instant::now();
|
||||
@@ -227,19 +245,27 @@ mod tests {
|
||||
);
|
||||
|
||||
// copy polynomials to device
|
||||
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let gpu_polys = gpu_api_wrapper
|
||||
.copy_to_device(&polys.concat())
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let device_ks = (0..max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let mut buf = gpu_api_wrapper
|
||||
.malloc_on_device(num_polys << (num_vars - 1))
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
|
||||
let mut round_evals = gpu_api_wrapper
|
||||
.malloc_on_device(max_degree as usize + 1)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
let round = 0;
|
||||
let mut transcript = Keccak256Transcript::<Fr>::new();
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper.eval_at_k_and_combine(
|
||||
num_vars,
|
||||
@@ -253,13 +279,15 @@ mod tests {
|
||||
.collect_vec(),
|
||||
buf_view.borrow_mut(),
|
||||
round_evals_view.borrow_mut(),
|
||||
&mut transcript,
|
||||
)?;
|
||||
println!(
|
||||
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
let gpu_round_evals = gpu_api_wrapper
|
||||
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
|
||||
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
cpu_round_evals
|
||||
.iter()
|
||||
.zip_eq(gpu_round_evals.iter())
|
||||
@@ -271,7 +299,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_into_half_in_place() -> Result<(), DriverError> {
|
||||
fn test_fold_into_half_in_place() -> Result<(), LibraryError> {
|
||||
let num_vars = 6;
|
||||
let num_polys = 4;
|
||||
|
||||
@@ -290,21 +318,33 @@ mod tests {
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half_in_place"],
|
||||
)?;
|
||||
let mut gpu_api_wrapper =
|
||||
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&["fold_into_half_in_place"],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
// copy polynomials to device
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let mut gpu_polys = gpu_api_wrapper
|
||||
.copy_to_device(&polys.concat())
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let challenge = Fr::random(rng);
|
||||
let gpu_challenge = gpu_api_wrapper.copy_to_device(&vec![challenge])?;
|
||||
|
||||
let gpu_challenge = gpu_api_wrapper
|
||||
.copy_to_device(&vec![challenge])
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let round = 0;
|
||||
|
||||
let now = Instant::now();
|
||||
@@ -313,7 +353,7 @@ mod tests {
|
||||
round,
|
||||
num_polys,
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
&gpu_challenge.slice(..),
|
||||
&gpu_challenge,
|
||||
)?;
|
||||
println!(
|
||||
"Time taken to fold_into_half_in_place on gpu: {:.2?}",
|
||||
@@ -327,7 +367,8 @@ mod tests {
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
|
||||
.collect::<Result<Vec<Vec<Fr>>, _>>()
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
|
||||
let now = Instant::now();
|
||||
(0..num_polys)
|
||||
@@ -348,8 +389,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||
let num_vars = 25;
|
||||
fn test_prove_sumcheck() -> Result<(), LibraryError> {
|
||||
let num_vars = 4;
|
||||
let num_polys = 2;
|
||||
let max_degree = 2;
|
||||
|
||||
@@ -368,26 +409,34 @@ mod tests {
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&[
|
||||
"fold_into_half",
|
||||
"fold_into_half_in_place",
|
||||
"combine",
|
||||
"sum",
|
||||
"squeeze_challenge",
|
||||
],
|
||||
)?;
|
||||
|
||||
let mut gpu_api_wrapper =
|
||||
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(MULTILINEAR_PTX),
|
||||
"multilinear",
|
||||
&["convert_to_montgomery_form"],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.load_ptx(
|
||||
Ptx::from_src(SUMCHECK_PTX),
|
||||
"sumcheck",
|
||||
&[
|
||||
"fold_into_half",
|
||||
"fold_into_half_in_place",
|
||||
"combine",
|
||||
"sum",
|
||||
],
|
||||
)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let mut transcript = Keccak256Transcript::<Fr>::new();
|
||||
let now = Instant::now();
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
|
||||
let mut gpu_polys = gpu_api_wrapper
|
||||
.copy_to_device(&polys.concat())
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
|
||||
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
|
||||
});
|
||||
@@ -397,12 +446,19 @@ mod tests {
|
||||
.gpu
|
||||
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let mut buf = gpu_api_wrapper
|
||||
.malloc_on_device(num_polys << (num_vars - 1))
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
|
||||
let mut challenges = gpu_api_wrapper.malloc_on_device(num_vars)?;
|
||||
let mut round_evals = gpu_api_wrapper.malloc_on_device(num_vars * (max_degree + 1))?;
|
||||
let mut challenges = gpu_api_wrapper
|
||||
.malloc_on_device(1)
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let mut round_evals = gpu_api_wrapper
|
||||
.malloc_on_device(num_vars * (max_degree + 1))
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
|
||||
println!("Time taken to copy data to device : {:.2?}", now.elapsed());
|
||||
|
||||
@@ -418,35 +474,34 @@ mod tests {
|
||||
.map(|device_k| device_k.slice(..))
|
||||
.collect_vec(),
|
||||
buf_view,
|
||||
&mut challenges.slice_mut(..),
|
||||
&mut challenges,
|
||||
round_evals_view,
|
||||
&mut transcript,
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
gpu_api_wrapper
|
||||
.gpu
|
||||
.synchronize()
|
||||
.map_err(|e| LibraryError::Driver(e))?;
|
||||
println!(
|
||||
"Time taken to prove sumcheck on gpu : {:.2?}",
|
||||
now.elapsed()
|
||||
);
|
||||
|
||||
let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
|
||||
let round_evals = (0..num_vars)
|
||||
.map(|i| {
|
||||
gpu_api_wrapper.dtoh_sync_copy(
|
||||
&round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||
true,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
|
||||
let round_evals = round_evals
|
||||
.iter()
|
||||
.map(|round_evals| round_evals.as_slice())
|
||||
.collect_vec();
|
||||
let result = cpu::sumcheck::verify_sumcheck(
|
||||
num_vars,
|
||||
max_degree,
|
||||
sum,
|
||||
&challenges[..],
|
||||
&round_evals[..],
|
||||
);
|
||||
// let challenges = gpu_api_wrapper
|
||||
// .dtoh_sync_copy(&challenges.slice(..), true)
|
||||
// .map_err(|e| LibraryError::Driver(e))?;
|
||||
// let round_evals = (0..num_vars)
|
||||
// .map(|i| {
|
||||
// gpu_api_wrapper.dtoh_sync_copy(
|
||||
// &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
|
||||
// true,
|
||||
// )
|
||||
// })
|
||||
// .collect::<Result<Vec<Vec<Fr>>, _>>()
|
||||
// .map_err(|e| LibraryError::Driver(e))?;
|
||||
|
||||
let result =
|
||||
cpu::sumcheck::verify_sumcheck_transcript(num_vars, max_degree, sum, &mut transcript);
|
||||
assert!(result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ use std::time::Instant;
|
||||
mod cpu;
|
||||
pub mod fieldbinding;
|
||||
pub mod gpu;
|
||||
pub mod transcript;
|
||||
pub mod utils;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
@@ -71,6 +73,21 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
Ok(device_data)
|
||||
}
|
||||
|
||||
pub fn overwrite_to_device(
|
||||
&self,
|
||||
host_data: &[F],
|
||||
dst: &mut CudaSlice<FieldBinding>,
|
||||
) -> Result<(), DriverError> {
|
||||
self.gpu.htod_copy_into(
|
||||
host_data
|
||||
.into_iter()
|
||||
.map(|&eval| F::to_canonical_form(eval))
|
||||
.collect(),
|
||||
dst,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn malloc_on_device(&self, len: usize) -> Result<CudaSlice<FieldBinding>, DriverError> {
|
||||
let device_ptr = unsafe { self.gpu.alloc::<FieldBinding>(len)? };
|
||||
Ok(device_ptr)
|
||||
@@ -123,3 +140,9 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)? as usize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum LibraryError {
|
||||
Driver(DriverError),
|
||||
Transcript(std::io::ErrorKind, String),
|
||||
}
|
||||
|
||||
80
sumcheck/src/transcript.rs
Normal file
80
sumcheck/src/transcript.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use crate::{
|
||||
fieldbinding::{FromFieldBinding, ToFieldBinding},
|
||||
utils::{arithmetic::fe_mod_from_le_bytes, hash::Hash},
|
||||
GPUApiWrapper, LibraryError,
|
||||
};
|
||||
use ff::PrimeField;
|
||||
use halo2curves::serde::SerdeObject;
|
||||
use sha3::Keccak256;
|
||||
use std::{
|
||||
io::{Cursor, Read, Write},
|
||||
marker::PhantomData,
|
||||
};
|
||||
|
||||
pub type Keccak256Transcript<F> = CudaTranscript<Keccak256, F>;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CudaTranscript<H, F> {
|
||||
pub stream: Cursor<Vec<u8>>,
|
||||
pub state: H,
|
||||
marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<H: Hash, F: PrimeField> CudaTranscript<H, F> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn squeeze_challenge(&mut self) -> F {
|
||||
let hash = self.state.finalize_fixed_reset();
|
||||
self.state.update(&hash);
|
||||
fe_mod_from_le_bytes(hash)
|
||||
}
|
||||
|
||||
fn common_field_element(&mut self, fe: &F) -> Result<(), LibraryError> {
|
||||
self.state.update_field_element(fe);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_field_element(&mut self, fe: &F) -> Result<(), LibraryError> {
|
||||
self.common_field_element(fe)?;
|
||||
let mut repr = fe.to_repr();
|
||||
repr.as_mut().reverse();
|
||||
self.stream
|
||||
.write_all(repr.as_ref())
|
||||
.map_err(|err| LibraryError::Transcript(err.kind(), err.to_string()))
|
||||
}
|
||||
|
||||
pub fn write_field_elements<'a>(
|
||||
&mut self,
|
||||
fes: impl IntoIterator<Item = &'a F>,
|
||||
) -> Result<(), LibraryError>
|
||||
where
|
||||
F: 'a,
|
||||
{
|
||||
for fe in fes.into_iter() {
|
||||
self.write_field_element(fe)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_field_element(&mut self) -> Result<F, LibraryError> {
|
||||
let mut repr = <F as PrimeField>::Repr::default();
|
||||
self.stream
|
||||
.read_exact(repr.as_mut())
|
||||
.map_err(|err| LibraryError::Transcript(err.kind(), err.to_string()))?;
|
||||
repr.as_mut().reverse();
|
||||
let fe = F::from_repr_vartime(repr).ok_or_else(|| {
|
||||
LibraryError::Transcript(
|
||||
std::io::ErrorKind::Other,
|
||||
"Invalid field element encoding in proof".to_string(),
|
||||
)
|
||||
})?;
|
||||
self.common_field_element(&fe)?;
|
||||
Ok(fe)
|
||||
}
|
||||
|
||||
pub fn read_field_elements(&mut self, n: usize) -> Result<Vec<F>, LibraryError> {
|
||||
(0..n).map(|_| self.read_field_element()).collect()
|
||||
}
|
||||
}
|
||||
2
sumcheck/src/utils.rs
Normal file
2
sumcheck/src/utils.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod arithmetic;
|
||||
pub mod hash;
|
||||
18
sumcheck/src/utils/arithmetic.rs
Normal file
18
sumcheck/src/utils/arithmetic.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use ff::PrimeField;
|
||||
use num_bigint::BigUint;
|
||||
|
||||
pub fn modulus<F: PrimeField>() -> BigUint {
|
||||
BigUint::from_bytes_le((-F::ONE).to_repr().as_ref()) + 1u64
|
||||
}
|
||||
|
||||
pub fn fe_mod_from_le_bytes<F: PrimeField>(bytes: impl AsRef<[u8]>) -> F {
|
||||
fe_from_le_bytes((BigUint::from_bytes_le(bytes.as_ref()) % modulus::<F>()).to_bytes_le())
|
||||
}
|
||||
|
||||
pub fn fe_from_le_bytes<F: PrimeField>(bytes: impl AsRef<[u8]>) -> F {
|
||||
let bytes = bytes.as_ref();
|
||||
let mut repr = F::Repr::default();
|
||||
assert!(bytes.len() <= repr.as_ref().len());
|
||||
repr.as_mut()[..bytes.len()].copy_from_slice(bytes);
|
||||
F::from_repr(repr).unwrap()
|
||||
}
|
||||
31
sumcheck/src/utils/hash.rs
Normal file
31
sumcheck/src/utils/hash.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use ff::PrimeField;
|
||||
use sha3::digest::{Digest, HashMarker};
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub use sha3::{
|
||||
digest::{FixedOutputReset, Output, Update},
|
||||
Keccak256,
|
||||
};
|
||||
|
||||
pub trait Hash:
|
||||
'static + Sized + Clone + Debug + FixedOutputReset + Default + Update + HashMarker
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn update_field_element(&mut self, field: &impl PrimeField) {
|
||||
Digest::update(self, field.to_repr());
|
||||
}
|
||||
|
||||
fn digest(data: impl AsRef<[u8]>) -> Output<Self> {
|
||||
let mut hasher = Self::default();
|
||||
hasher.update(data.as_ref());
|
||||
hasher.finalize()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static + Sized + Clone + Debug + FixedOutputReset + Default + Update + HashMarker> Hash
|
||||
for T
|
||||
{
|
||||
}
|
||||
Reference in New Issue
Block a user