mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-10 12:58:02 -05:00
Fmt
This commit is contained in:
@@ -42,9 +42,17 @@ pub fn inner_product<'a, 'b, F: Field>(
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn barycentric_interpolate<E: ExtensionField>(weights: &[E], points: &[E::BaseField], evals: &[E], x: &E) -> E {
|
||||
pub fn barycentric_interpolate<E: ExtensionField>(
|
||||
weights: &[E],
|
||||
points: &[E::BaseField],
|
||||
evals: &[E],
|
||||
x: &E,
|
||||
) -> E {
|
||||
let (coeffs, sum_inv) = {
|
||||
let mut coeffs = points.iter().map(|point| *x - E::from_base(point)).collect_vec();
|
||||
let mut coeffs = points
|
||||
.iter()
|
||||
.map(|point| *x - E::from_base(point))
|
||||
.collect_vec();
|
||||
coeffs.batch_invert();
|
||||
coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| {
|
||||
*coeff *= weight;
|
||||
|
||||
@@ -57,12 +57,8 @@ pub(crate) fn verify_sumcheck<E: ExtensionField>(
|
||||
if sum != computed_sum {
|
||||
return false;
|
||||
}
|
||||
let mut expected_sum = barycentric_interpolate::<E>(
|
||||
&weights,
|
||||
&points_vec,
|
||||
evals[0],
|
||||
&challenges[0],
|
||||
);
|
||||
let mut expected_sum =
|
||||
barycentric_interpolate::<E>(&weights, &points_vec, evals[0], &challenges[0]);
|
||||
// round 1..num_vars
|
||||
for round_index in 1..num_vars {
|
||||
if evals[round_index].len() != max_degree + 1 {
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{FieldBinding, GPUSumcheckProver, QuadraticExtFieldBinding};
|
||||
impl<E> GPUSumcheckProver<E>
|
||||
where
|
||||
E: ExtensionField + From<QuadraticExtFieldBinding> + Into<QuadraticExtFieldBinding>,
|
||||
E::BaseField: From<FieldBinding> + Into<FieldBinding>
|
||||
E::BaseField: From<FieldBinding> + Into<FieldBinding>,
|
||||
{
|
||||
pub fn prove_sumcheck(
|
||||
&self,
|
||||
@@ -176,10 +176,7 @@ mod tests {
|
||||
|
||||
use cudarc::{
|
||||
driver::{
|
||||
result::{
|
||||
event::{create, elapsed, record},
|
||||
stream::wait_event,
|
||||
},
|
||||
result::event::{create, elapsed, record},
|
||||
sys, DriverError,
|
||||
},
|
||||
nvrtc::Ptx,
|
||||
@@ -238,7 +235,8 @@ mod tests {
|
||||
);
|
||||
|
||||
// copy polynomials to device
|
||||
let gpu_polys = gpu_api_wrapper.copy_exts_to_device(&polys.concat().into_iter().map(|b| b.into()).collect_vec())?;
|
||||
let gpu_polys = gpu_api_wrapper
|
||||
.copy_exts_to_device(&polys.concat().into_iter().map(|b| b.into()).collect_vec())?;
|
||||
let device_ks = (0..max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
@@ -388,9 +386,14 @@ mod tests {
|
||||
|
||||
let now = Instant::now();
|
||||
let mut gpu_polys = gpu_api_wrapper.copy_exts_to_device(&polys.concat())?;
|
||||
let sum = (0..1 << num_vars).fold(GoldilocksExt2::ZERO, |acc, index| {
|
||||
acc + polys.iter().map(|poly| poly[index]).product::<GoldilocksExt2>()
|
||||
}).to_limbs()[0];
|
||||
let sum = (0..1 << num_vars)
|
||||
.fold(GoldilocksExt2::ZERO, |acc, index| {
|
||||
acc + polys
|
||||
.iter()
|
||||
.map(|poly| poly[index])
|
||||
.product::<GoldilocksExt2>()
|
||||
})
|
||||
.to_limbs()[0];
|
||||
let device_ks = (0..max_degree + 1)
|
||||
.map(|k| {
|
||||
gpu_api_wrapper
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
use std::{marker::PhantomData, sync::Arc, time::Instant};
|
||||
|
||||
use cudarc::driver::{CudaDevice, CudaSlice, CudaView, DeviceRepr, DriverError};
|
||||
use ff::{Field, PrimeField};
|
||||
use goldilocks::ExtensionField;
|
||||
use itertools::Itertools;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
@@ -32,8 +31,7 @@ impl Default for QuadraticExtFieldBinding {
|
||||
const SUMCHECK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/sumcheck.ptx"));
|
||||
|
||||
/// Struct for GPU sumcheck prover
|
||||
pub struct GPUSumcheckProver<E>
|
||||
{
|
||||
pub struct GPUSumcheckProver<E> {
|
||||
gpu: Arc<CudaDevice>,
|
||||
_marker: PhantomData<E>,
|
||||
}
|
||||
@@ -41,7 +39,7 @@ pub struct GPUSumcheckProver<E>
|
||||
impl<E> GPUSumcheckProver<E>
|
||||
where
|
||||
E: ExtensionField + From<QuadraticExtFieldBinding> + Into<QuadraticExtFieldBinding>,
|
||||
E::BaseField: From<FieldBinding> + Into<FieldBinding>
|
||||
E::BaseField: From<FieldBinding> + Into<FieldBinding>,
|
||||
{
|
||||
pub fn setup() -> Result<Self, DriverError> {
|
||||
// setup GPU device
|
||||
|
||||
Reference in New Issue
Block a user