This commit is contained in:
DoHoonKim8
2024-10-10 16:58:03 +00:00
parent c7276056ad
commit 685f17814d
4 changed files with 26 additions and 21 deletions

View File

@@ -42,9 +42,17 @@ pub fn inner_product<'a, 'b, F: Field>(
.unwrap_or_default()
}
pub fn barycentric_interpolate<E: ExtensionField>(weights: &[E], points: &[E::BaseField], evals: &[E], x: &E) -> E {
pub fn barycentric_interpolate<E: ExtensionField>(
weights: &[E],
points: &[E::BaseField],
evals: &[E],
x: &E,
) -> E {
let (coeffs, sum_inv) = {
let mut coeffs = points.iter().map(|point| *x - E::from_base(point)).collect_vec();
let mut coeffs = points
.iter()
.map(|point| *x - E::from_base(point))
.collect_vec();
coeffs.batch_invert();
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
*coeff *= weight;

View File

@@ -57,12 +57,8 @@ pub(crate) fn verify_sumcheck<E: ExtensionField>(
if sum != computed_sum {
return false;
}
let mut expected_sum = barycentric_interpolate::<E>(
&weights,
&points_vec,
evals[0],
&challenges[0],
);
let mut expected_sum =
barycentric_interpolate::<E>(&weights, &points_vec, evals[0], &challenges[0]);
// round 1..num_vars
for round_index in 1..num_vars {
if evals[round_index].len() != max_degree + 1 {

View File

@@ -8,7 +8,7 @@ use crate::{FieldBinding, GPUSumcheckProver, QuadraticExtFieldBinding};
impl<E> GPUSumcheckProver<E>
where
E: ExtensionField + From<QuadraticExtFieldBinding> + Into<QuadraticExtFieldBinding>,
E::BaseField: From<FieldBinding> + Into<FieldBinding>
E::BaseField: From<FieldBinding> + Into<FieldBinding>,
{
pub fn prove_sumcheck(
&self,
@@ -176,10 +176,7 @@ mod tests {
use cudarc::{
driver::{
result::{
event::{create, elapsed, record},
stream::wait_event,
},
result::event::{create, elapsed, record},
sys, DriverError,
},
nvrtc::Ptx,
@@ -238,7 +235,8 @@ mod tests {
);
// copy polynomials to device
let gpu_polys = gpu_api_wrapper.copy_exts_to_device(&polys.concat().into_iter().map(|b| b.into()).collect_vec())?;
let gpu_polys = gpu_api_wrapper
.copy_exts_to_device(&polys.concat().into_iter().map(|b| b.into()).collect_vec())?;
let device_ks = (0..max_degree + 1)
.map(|k| {
gpu_api_wrapper
@@ -388,9 +386,14 @@ mod tests {
let now = Instant::now();
let mut gpu_polys = gpu_api_wrapper.copy_exts_to_device(&polys.concat())?;
let sum = (0..1 << num_vars).fold(GoldilocksExt2::ZERO, |acc, index| {
acc + polys.iter().map(|poly| poly[index]).product::<GoldilocksExt2>()
}).to_limbs()[0];
let sum = (0..1 << num_vars)
.fold(GoldilocksExt2::ZERO, |acc, index| {
acc + polys
.iter()
.map(|poly| poly[index])
.product::<GoldilocksExt2>()
})
.to_limbs()[0];
let device_ks = (0..max_degree + 1)
.map(|k| {
gpu_api_wrapper

View File

@@ -4,7 +4,6 @@
use std::{marker::PhantomData, sync::Arc, time::Instant};
use cudarc::driver::{CudaDevice, CudaSlice, CudaView, DeviceRepr, DriverError};
use ff::{Field, PrimeField};
use goldilocks::ExtensionField;
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
@@ -32,8 +31,7 @@ impl Default for QuadraticExtFieldBinding {
const SUMCHECK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
/// Struct for GPU sumcheck prover
pub struct GPUSumcheckProver<E>
{
pub struct GPUSumcheckProver<E> {
gpu: Arc<CudaDevice>,
_marker: PhantomData<E>,
}
@@ -41,7 +39,7 @@ pub struct GPUSumcheckProver<E>
impl<E> GPUSumcheckProver<E>
where
E: ExtensionField + From<QuadraticExtFieldBinding> + Into<QuadraticExtFieldBinding>,
E::BaseField: From<FieldBinding> + Into<FieldBinding>
E::BaseField: From<FieldBinding> + Into<FieldBinding>,
{
pub fn setup() -> Result<Self, DriverError> {
// setup GPU device