mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-10 12:58:02 -05:00
Fix error in update state of transcript
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@ Cargo.lock
|
||||
|
||||
*.o
|
||||
|
||||
.vscode/
|
||||
|
||||
@@ -97,10 +97,10 @@ pub(crate) fn verify_sumcheck_transcript<F: PrimeField + halo2curves::serde::Ser
|
||||
let mut expected_sum = sum;
|
||||
for round_index in 0..num_vars {
|
||||
let evals: Vec<F> = transcript
|
||||
.read_field_elements((max_degree + 1) * 32)
|
||||
.read_field_elements(max_degree + 1)
|
||||
.unwrap()
|
||||
.chunks(32)
|
||||
.map(|e| from_u8_to_f::<F>(e.to_vec())[0])
|
||||
.map(|e| from_u8_to_f::<F>(e))
|
||||
.collect_vec();
|
||||
if evals.len() != max_degree + 1 {
|
||||
return false;
|
||||
|
||||
@@ -8,8 +8,8 @@ class Transcript {
|
||||
private:
|
||||
uint8_t* start;
|
||||
uint8_t* cursor;
|
||||
fr state;
|
||||
public:
|
||||
fr state;
|
||||
__device__ void init_transcript(uint8_t* start, uint8_t* cursor, fr* state);
|
||||
__device__ void write_field_element(fr fe);
|
||||
__device__ fr read_field_element();
|
||||
|
||||
@@ -37,6 +37,7 @@ extern "C" __global__ void sum(
|
||||
__syncthreads();
|
||||
}
|
||||
if (tid == 0) t.write_field_element(data[0]);
|
||||
state[0] = data[0];
|
||||
}
|
||||
|
||||
extern "C" __global__ void fold_into_half(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::cell::{RefCell, RefMut};
|
||||
|
||||
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
|
||||
use ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
@@ -19,7 +19,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
buf: RefCell<CudaViewMut<FieldBinding>>,
|
||||
transcript: &mut TranscriptInner,
|
||||
transcript_state: &CudaSlice<FieldBinding>,
|
||||
transcript_state: RefCell<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let initial_poly_num_vars = num_vars;
|
||||
for round in 0..num_vars {
|
||||
@@ -31,7 +31,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
&polys.slice(..),
|
||||
buf.borrow_mut(),
|
||||
transcript,
|
||||
transcript_state,
|
||||
transcript_state.borrow_mut(),
|
||||
)?;
|
||||
// fold_into_half_in_place
|
||||
self.fold_into_half_in_place(
|
||||
@@ -40,7 +40,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
num_polys,
|
||||
polys,
|
||||
transcript,
|
||||
transcript_state,
|
||||
transcript_state.borrow_mut(),
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
@@ -55,7 +55,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
polys: &CudaView<FieldBinding>,
|
||||
mut buf: RefMut<CudaViewMut<FieldBinding>>,
|
||||
transcript: &mut TranscriptInner,
|
||||
transcript_state: &CudaSlice<FieldBinding>,
|
||||
mut transcript_state: RefMut<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
|
||||
let num_threads_per_block = 1024;
|
||||
@@ -110,7 +110,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
round * (max_degree + 1) + k,
|
||||
&mut start_view,
|
||||
&mut cursor,
|
||||
transcript_state,
|
||||
&mut *transcript_state,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
@@ -126,7 +126,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
num_polys: usize,
|
||||
polys: &mut CudaViewMut<FieldBinding>,
|
||||
transcript: &mut TranscriptInner,
|
||||
state: &CudaSlice<FieldBinding>,
|
||||
mut state: RefMut<CudaViewMut<FieldBinding>>,
|
||||
) -> Result<(), DriverError> {
|
||||
let fold_into_half_in_place = self
|
||||
.gpu
|
||||
@@ -152,7 +152,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
polys,
|
||||
&mut start_view,
|
||||
&mut cursor,
|
||||
state,
|
||||
&mut *state,
|
||||
),
|
||||
)?;
|
||||
};
|
||||
@@ -255,9 +255,10 @@ mod tests {
|
||||
transcript.get_cuda_slice(&mut gpu_api_wrapper, count, add_len)?;
|
||||
|
||||
let state = vec![transcript.state];
|
||||
let state_slice = gpu_api_wrapper
|
||||
let mut state_slice = gpu_api_wrapper
|
||||
.copy_to_device(&state.as_slice())
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let state_view = RefCell::new(state_slice.slice_mut(..));
|
||||
let round = 0;
|
||||
let now = Instant::now();
|
||||
gpu_api_wrapper
|
||||
@@ -269,7 +270,7 @@ mod tests {
|
||||
&gpu_polys.slice(..),
|
||||
buf_view.borrow_mut(),
|
||||
&mut transcript_inner,
|
||||
&state_slice,
|
||||
state_view.borrow_mut(),
|
||||
)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
gpu_api_wrapper
|
||||
@@ -332,9 +333,10 @@ mod tests {
|
||||
.copy_to_device(&polys.concat())
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let challenge = Fr::random(rng);
|
||||
let gpu_challenge = gpu_api_wrapper
|
||||
let mut gpu_challenge = gpu_api_wrapper
|
||||
.copy_to_device(&vec![challenge])
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let challenge_view = RefCell::new(gpu_challenge.slice_mut(..));
|
||||
let round = 0;
|
||||
|
||||
let count = 0;
|
||||
@@ -351,7 +353,7 @@ mod tests {
|
||||
num_polys,
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
&mut transcript_inner,
|
||||
&gpu_challenge,
|
||||
challenge_view.borrow_mut(),
|
||||
)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
gpu_api_wrapper
|
||||
@@ -421,7 +423,6 @@ mod tests {
|
||||
"fold_into_half_in_place",
|
||||
"combine",
|
||||
"sum",
|
||||
"squeeze_challenge",
|
||||
],
|
||||
)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
@@ -438,9 +439,6 @@ mod tests {
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let buf_view = RefCell::new(buf.slice_mut(..));
|
||||
|
||||
let mut challenges = gpu_api_wrapper
|
||||
.malloc_on_device(num_vars)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let mut round_evals = gpu_api_wrapper
|
||||
.malloc_on_device(num_vars * (max_degree + 1))
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
@@ -451,9 +449,11 @@ mod tests {
|
||||
let add_len = num_vars * (max_degree + 1) * 32;
|
||||
let challenge = vec![Fr::zero()];
|
||||
let mut transcript = CudaKeccakTranscript::<Fr>::new(&Fr::zero());
|
||||
let gpu_challenge = gpu_api_wrapper
|
||||
let mut gpu_challenge = gpu_api_wrapper
|
||||
.copy_to_device(&challenge)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
let challenge_view = RefCell::new(gpu_challenge.slice_mut(..));
|
||||
|
||||
let mut transcript_inner =
|
||||
transcript.get_cuda_slice(&mut gpu_api_wrapper, count, add_len)?;
|
||||
|
||||
@@ -467,7 +467,7 @@ mod tests {
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
buf_view,
|
||||
&mut transcript_inner,
|
||||
&gpu_challenge,
|
||||
challenge_view,
|
||||
)
|
||||
.map_err(|err| LibraryError::Driver(err))?;
|
||||
gpu_api_wrapper
|
||||
|
||||
@@ -5,7 +5,6 @@ use crate::{
|
||||
use cudarc::driver::CudaSlice;
|
||||
use ff::PrimeField;
|
||||
use halo2curves::serde::SerdeObject;
|
||||
use itertools::Itertools;
|
||||
use std::io::{Cursor, Read, Write};
|
||||
|
||||
pub enum Hash {
|
||||
@@ -34,14 +33,8 @@ impl TranscriptInner {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_u8_to_f<F: PrimeField + SerdeObject>(v: Vec<u8>) -> Vec<F> {
|
||||
let src: Vec<&[u8]> = v.chunks(32).collect();
|
||||
src.into_iter()
|
||||
.map(|l| {
|
||||
let data = l.chunks(8).collect_vec();
|
||||
F::from_raw_bytes_unchecked(data.concat().as_slice()) * F::ONE
|
||||
})
|
||||
.collect_vec()
|
||||
pub fn from_u8_to_f<F: PrimeField + SerdeObject>(v: &[u8]) -> F {
|
||||
F::from_raw_bytes_unchecked(v) * F::ONE
|
||||
}
|
||||
|
||||
pub struct CudaKeccakTranscript<F> {
|
||||
@@ -79,13 +72,12 @@ impl<F: PrimeField + SerdeObject> CudaKeccakTranscript<F> {
|
||||
Ok(new_t)
|
||||
}
|
||||
|
||||
pub fn read_field_element(&mut self) -> Result<Vec<u8>, LibraryError> {
|
||||
let mut repr: Vec<u8> = vec![];
|
||||
pub fn read_field_element(&mut self) -> Result<[u8; 32], LibraryError> {
|
||||
let mut repr: [u8; 32] = [0; 32];
|
||||
self.stream
|
||||
.read_exact(repr.as_mut())
|
||||
.map_err(|_| LibraryError::Transcript)?;
|
||||
|
||||
self.state = from_u8_to_f::<F>(repr.clone())[0];
|
||||
self.state = from_u8_to_f::<F>(&repr);
|
||||
Ok(repr)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user