Fix error in updating the transcript state

This commit is contained in:
Soowon Jeong
2024-09-30 04:36:49 +00:00
parent 7470db88e7
commit 4f3a6c55a2
6 changed files with 28 additions and 34 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ Cargo.lock
*.o
.vscode/

View File

@@ -97,10 +97,10 @@ pub(crate) fn verify_sumcheck_transcript<F: PrimeField + halo2curves::serde::Ser
let mut expected_sum = sum;
for round_index in 0..num_vars {
let evals: Vec<F> = transcript
.read_field_elements((max_degree + 1) * 32)
.read_field_elements(max_degree + 1)
.unwrap()
.chunks(32)
.map(|e| from_u8_to_f::<F>(e.to_vec())[0])
.map(|e| from_u8_to_f::<F>(e))
.collect_vec();
if evals.len() != max_degree + 1 {
return false;

View File

@@ -8,8 +8,8 @@ class Transcript {
private:
uint8_t* start;
uint8_t* cursor;
fr state;
public:
fr state;
__device__ void init_transcript(uint8_t* start, uint8_t* cursor, fr* state);
__device__ void write_field_element(fr fe);
__device__ fr read_field_element();

View File

@@ -37,6 +37,7 @@ extern "C" __global__ void sum(
__syncthreads();
}
if (tid == 0) t.write_field_element(data[0]);
state[0] = data[0];
}
extern "C" __global__ void fold_into_half(

View File

@@ -1,6 +1,6 @@
use std::cell::{RefCell, RefMut};
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
use ff::PrimeField;
use crate::{
@@ -19,7 +19,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
polys: &mut CudaViewMut<FieldBinding>,
buf: RefCell<CudaViewMut<FieldBinding>>,
transcript: &mut TranscriptInner,
transcript_state: &CudaSlice<FieldBinding>,
transcript_state: RefCell<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
let initial_poly_num_vars = num_vars;
for round in 0..num_vars {
@@ -31,7 +31,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
&polys.slice(..),
buf.borrow_mut(),
transcript,
transcript_state,
transcript_state.borrow_mut(),
)?;
// fold_into_half_in_place
self.fold_into_half_in_place(
@@ -40,7 +40,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
num_polys,
polys,
transcript,
transcript_state,
transcript_state.borrow_mut(),
)?;
}
Ok(())
@@ -55,7 +55,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
polys: &CudaView<FieldBinding>,
mut buf: RefMut<CudaViewMut<FieldBinding>>,
transcript: &mut TranscriptInner,
transcript_state: &CudaSlice<FieldBinding>,
mut transcript_state: RefMut<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
let num_threads_per_block = 1024;
@@ -110,7 +110,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
round * (max_degree + 1) + k,
&mut start_view,
&mut cursor,
transcript_state,
&mut *transcript_state,
),
)?;
};
@@ -126,7 +126,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
num_polys: usize,
polys: &mut CudaViewMut<FieldBinding>,
transcript: &mut TranscriptInner,
state: &CudaSlice<FieldBinding>,
mut state: RefMut<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
let fold_into_half_in_place = self
.gpu
@@ -152,7 +152,7 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
polys,
&mut start_view,
&mut cursor,
state,
&mut *state,
),
)?;
};
@@ -255,9 +255,10 @@ mod tests {
transcript.get_cuda_slice(&mut gpu_api_wrapper, count, add_len)?;
let state = vec![transcript.state];
let state_slice = gpu_api_wrapper
let mut state_slice = gpu_api_wrapper
.copy_to_device(&state.as_slice())
.map_err(|err| LibraryError::Driver(err))?;
let state_view = RefCell::new(state_slice.slice_mut(..));
let round = 0;
let now = Instant::now();
gpu_api_wrapper
@@ -269,7 +270,7 @@ mod tests {
&gpu_polys.slice(..),
buf_view.borrow_mut(),
&mut transcript_inner,
&state_slice,
state_view.borrow_mut(),
)
.map_err(|err| LibraryError::Driver(err))?;
gpu_api_wrapper
@@ -332,9 +333,10 @@ mod tests {
.copy_to_device(&polys.concat())
.map_err(|err| LibraryError::Driver(err))?;
let challenge = Fr::random(rng);
let gpu_challenge = gpu_api_wrapper
let mut gpu_challenge = gpu_api_wrapper
.copy_to_device(&vec![challenge])
.map_err(|err| LibraryError::Driver(err))?;
let challenge_view = RefCell::new(gpu_challenge.slice_mut(..));
let round = 0;
let count = 0;
@@ -351,7 +353,7 @@ mod tests {
num_polys,
&mut gpu_polys.slice_mut(..),
&mut transcript_inner,
&gpu_challenge,
challenge_view.borrow_mut(),
)
.map_err(|err| LibraryError::Driver(err))?;
gpu_api_wrapper
@@ -421,7 +423,6 @@ mod tests {
"fold_into_half_in_place",
"combine",
"sum",
"squeeze_challenge",
],
)
.map_err(|err| LibraryError::Driver(err))?;
@@ -438,9 +439,6 @@ mod tests {
.map_err(|err| LibraryError::Driver(err))?;
let buf_view = RefCell::new(buf.slice_mut(..));
let mut challenges = gpu_api_wrapper
.malloc_on_device(num_vars)
.map_err(|err| LibraryError::Driver(err))?;
let mut round_evals = gpu_api_wrapper
.malloc_on_device(num_vars * (max_degree + 1))
.map_err(|err| LibraryError::Driver(err))?;
@@ -451,9 +449,11 @@ mod tests {
let add_len = num_vars * (max_degree + 1) * 32;
let challenge = vec![Fr::zero()];
let mut transcript = CudaKeccakTranscript::<Fr>::new(&Fr::zero());
let gpu_challenge = gpu_api_wrapper
let mut gpu_challenge = gpu_api_wrapper
.copy_to_device(&challenge)
.map_err(|err| LibraryError::Driver(err))?;
let challenge_view = RefCell::new(gpu_challenge.slice_mut(..));
let mut transcript_inner =
transcript.get_cuda_slice(&mut gpu_api_wrapper, count, add_len)?;
@@ -467,7 +467,7 @@ mod tests {
&mut gpu_polys.slice_mut(..),
buf_view,
&mut transcript_inner,
&gpu_challenge,
challenge_view,
)
.map_err(|err| LibraryError::Driver(err))?;
gpu_api_wrapper

View File

@@ -5,7 +5,6 @@ use crate::{
use cudarc::driver::CudaSlice;
use ff::PrimeField;
use halo2curves::serde::SerdeObject;
use itertools::Itertools;
use std::io::{Cursor, Read, Write};
pub enum Hash {
@@ -34,14 +33,8 @@ impl TranscriptInner {
}
}
pub fn from_u8_to_f<F: PrimeField + SerdeObject>(v: Vec<u8>) -> Vec<F> {
let src: Vec<&[u8]> = v.chunks(32).collect();
src.into_iter()
.map(|l| {
let data = l.chunks(8).collect_vec();
F::from_raw_bytes_unchecked(data.concat().as_slice()) * F::ONE
})
.collect_vec()
/// Builds a single field element of type `F` from the bytes in `v`.
///
/// The slice is passed unchanged to `F::from_raw_bytes_unchecked`, and the
/// resulting value is multiplied by `F::ONE` before being returned.
// NOTE(review): the callers visible in this change pass 32-byte chunks/arrays;
// no length check happens here — confirm how `from_raw_bytes_unchecked`
// behaves for slices of any other length.
// NOTE(review): the `* F::ONE` presumably normalizes the unchecked value into
// the field's canonical internal form — confirm against the field implementation.
pub fn from_u8_to_f<F: PrimeField + SerdeObject>(v: &[u8]) -> F {
F::from_raw_bytes_unchecked(v) * F::ONE
}
pub struct CudaKeccakTranscript<F> {
@@ -79,13 +72,12 @@ impl<F: PrimeField + SerdeObject> CudaKeccakTranscript<F> {
Ok(new_t)
}
pub fn read_field_element(&mut self) -> Result<Vec<u8>, LibraryError> {
let mut repr: Vec<u8> = vec![];
pub fn read_field_element(&mut self) -> Result<[u8; 32], LibraryError> {
let mut repr: [u8; 32] = [0; 32];
self.stream
.read_exact(repr.as_mut())
.map_err(|_| LibraryError::Transcript)?;
self.state = from_u8_to_f::<F>(repr.clone())[0];
self.state = from_u8_to_f::<F>(&repr);
Ok(repr)
}