Implement in-memory transcript

This commit is contained in:
Soowon Jeong
2024-10-30 10:20:15 +00:00
parent bc8340d1d2
commit ffaf9c5f47
10 changed files with 387 additions and 145 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@ Cargo.lock
*.o
.vscode

View File

@@ -14,6 +14,8 @@ itertools = "0.10.5"
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.3", package = "halo2curves" }
rand = "0.8"
num-integer = "0.1.45"
sha3 = "0.10.6"
num-bigint = "0.4.3"
[build-dependencies]
bindgen = "0.66.1"

View File

@@ -1,7 +1,10 @@
use ff::PrimeField;
use itertools::Itertools;
use crate::cpu::{arithmetic::barycentric_weights, parallel::parallelize};
use crate::{
cpu::{arithmetic::barycentric_weights, parallel::parallelize},
transcript::Keccak256Transcript,
};
use super::arithmetic::barycentric_interpolate;
@@ -36,6 +39,40 @@ pub(crate) fn fold_into_half_in_place<F: PrimeField>(poly: &mut [F], challenge:
});
}
/// Verifies a sumcheck proof by replaying it from `transcript`.
///
/// For each of the `num_vars` rounds this reads `max_degree + 1` field elements from the
/// transcript, treats them as the round polynomial's evaluations at the points
/// `0, 1, ..., max_degree`, and checks that the evaluations at `0` and `1` add up to the
/// running claim (initially `sum`). It then squeezes a challenge from the transcript and
/// replaces the claim with the round polynomial interpolated at that challenge.
///
/// Returns `true` only if every round check passes. Returns `false` when a round check
/// fails, when the transcript cannot supply a full round of evaluations, or when a round
/// contains fewer than two evaluations (i.e. `max_degree == 0`), instead of panicking.
pub(crate) fn verify_sumcheck_transcript<F: PrimeField>(
    num_vars: usize,
    max_degree: usize,
    sum: F,
    transcript: &mut Keccak256Transcript<F>,
) -> bool {
    // Interpolation nodes 0..=max_degree and their barycentric weights are round-independent.
    let points_vec: Vec<F> = (0..=max_degree).map(|i| F::from_u128(i as u128)).collect();
    let weights = barycentric_weights(&points_vec);
    let mut expected_sum = sum;
    for round_index in 0..num_vars {
        // A truncated or malformed proof must be rejected, not turned into a panic.
        let evals = match transcript.read_field_elements(max_degree + 1) {
            Ok(evals) => evals,
            Err(err) => {
                println!("failed to read evaluations : {:?}", err);
                println!("round index : {}", round_index);
                return false;
            }
        };
        // The round check needs the evaluations at both 0 and 1.
        let computed_sum = match evals.as_slice() {
            [eval_at_0, eval_at_1, ..] => *eval_at_0 + *eval_at_1,
            _ => return false,
        };
        // Check r_{i}(α_i) == r_{i+1}(0) + r_{i+1}(1)
        if computed_sum != expected_sum {
            println!("computed_sum : {:?}", computed_sum);
            println!("expected_sum : {:?}", expected_sum);
            println!("round index : {}", round_index);
            return false;
        }
        let challenge = transcript.squeeze_challenge();
        // Compute r_{i}(α_i) using barycentric interpolation
        expected_sum = barycentric_interpolate(&weights, &points_vec, &evals, &challenge);
    }
    true
}
pub(crate) fn verify_sumcheck<F: PrimeField>(
num_vars: usize,
max_degree: usize,

View File

@@ -32,16 +32,16 @@ extern "C" __global__ void sum(fr* data, fr* result, unsigned int stride, unsign
}
extern "C" __global__ void fold_into_half(
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* challenge
unsigned int num_vars, unsigned int initial_poly_size, unsigned int num_blocks_per_poly, fr* polys, fr* buf, const fr* eval_point
) {
int tid = (blockIdx.x % num_blocks_per_poly) * blockDim.x + threadIdx.x;
const int stride = 1 << (num_vars - 1);
const int buf_offset = (blockIdx.x / num_blocks_per_poly) * stride;
const int poly_offset = (blockIdx.x / num_blocks_per_poly) * initial_poly_size;
while (tid < stride) {
if (*challenge == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid];
else if (*challenge == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride];
else buf[buf_offset + tid] = (*challenge) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
if (*eval_point == fr::zero()) buf[buf_offset + tid] = polys[poly_offset + tid];
else if (*eval_point == fr::one()) buf[buf_offset + tid] = polys[poly_offset + tid + stride];
else buf[buf_offset + tid] = (*eval_point) * (polys[poly_offset + tid + stride] - polys[poly_offset + tid]) + polys[poly_offset + tid];
tid += blockDim.x * num_blocks_per_poly;
}
}
@@ -58,10 +58,3 @@ extern "C" __global__ void fold_into_half_in_place(
tid += blockDim.x * num_blocks_per_poly;
}
}
// TODO : Pass transcript and squeeze random challenge using hash function
extern "C" __global__ void squeeze_challenge(fr* challenges, unsigned int index) {
if (threadIdx.x == 0) {
challenges[index] = fr(1034);
}
}

View File

@@ -1,11 +1,12 @@
use std::cell::{RefCell, RefMut};
use cudarc::driver::{CudaView, CudaViewMut, DriverError, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaSlice, CudaView, CudaViewMut, LaunchAsync, LaunchConfig};
use ff::PrimeField;
use crate::{
fieldbinding::{FromFieldBinding, ToFieldBinding},
FieldBinding, GPUApiWrapper,
transcript::Keccak256Transcript,
FieldBinding, GPUApiWrapper, LibraryError,
};
impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
@@ -18,9 +19,10 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
polys: &mut CudaViewMut<FieldBinding>,
device_ks: &[CudaView<FieldBinding>],
buf: RefCell<CudaViewMut<FieldBinding>>,
challenges: &mut CudaViewMut<FieldBinding>,
challenge: &mut CudaSlice<FieldBinding>,
round_evals: RefCell<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
transcript: &mut Keccak256Transcript<F>,
) -> Result<(), LibraryError> {
let initial_poly_num_vars = num_vars;
for round in 0..num_vars {
self.eval_at_k_and_combine(
@@ -32,16 +34,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
device_ks,
buf.borrow_mut(),
round_evals.borrow_mut(),
transcript,
)?;
// squeeze challenge
self.squeeze_challenge(round, challenges)?;
let alpha = vec![transcript.squeeze_challenge()];
self.overwrite_to_device(alpha.as_slice(), challenge)
.map_err(|e| LibraryError::Driver(e))?;
// fold_into_half_in_place
self.fold_into_half_in_place(
initial_poly_num_vars,
round,
num_polys,
polys,
&challenges.slice(round..round + 1),
&challenge,
)?;
}
Ok(())
@@ -57,8 +62,13 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
device_ks: &[CudaView<FieldBinding>],
mut buf: RefMut<CudaViewMut<FieldBinding>>,
mut round_evals: RefMut<CudaViewMut<FieldBinding>>,
) -> Result<(), DriverError> {
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
transcript: &mut Keccak256Transcript<F>,
) -> Result<(), LibraryError> {
let num_blocks_per_poly = self
.max_blocks_per_sm()
.map_err(|e| LibraryError::Driver(e))?
/ num_polys
* self.num_sm().map_err(|e| LibraryError::Driver(e))?;
let num_threads_per_block = 1024;
for k in 0..max_degree + 1 {
let fold_into_half = self.gpu.get_func("sumcheck", "fold_into_half").unwrap();
@@ -68,17 +78,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
shared_mem_bytes: 0,
};
unsafe {
fold_into_half.launch(
launch_config,
(
initial_poly_num_vars - round,
1 << initial_poly_num_vars,
num_blocks_per_poly,
polys,
&mut *buf,
&device_ks[k],
),
)?;
fold_into_half
.launch(
launch_config,
(
initial_poly_num_vars - round,
1 << initial_poly_num_vars,
num_blocks_per_poly,
polys,
&mut *buf,
&device_ks[k],
),
)
.map_err(|e| LibraryError::Driver(e))?;
};
let size = 1 << (initial_poly_num_vars - round - 1);
let combine = self.gpu.get_func("sumcheck", "combine").unwrap();
@@ -88,7 +100,9 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
shared_mem_bytes: 0,
};
unsafe {
combine.launch(launch_config, (&mut *buf, size, num_polys))?;
combine
.launch(launch_config, (&mut *buf, size, num_polys))
.map_err(|e| LibraryError::Driver(e))?;
};
let sum = self.gpu.get_func("sumcheck", "sum").unwrap();
let launch_config = LaunchConfig {
@@ -105,26 +119,14 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
size >> 1,
round * (max_degree + 1) + k,
),
)?;
)
.map_err(|e| LibraryError::Driver(e))?;
};
}
Ok(())
}
pub(crate) fn squeeze_challenge(
&self,
round: usize,
challenges: &mut CudaViewMut<FieldBinding>,
) -> Result<(), DriverError> {
let squeeze_challenge = self.gpu.get_func("sumcheck", "squeeze_challenge").unwrap();
let launch_config = LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (1, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
squeeze_challenge.launch(launch_config, (challenges, round))?;
}
let fes = self
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)
.map_err(|e| LibraryError::Driver(e))?;
transcript.write_field_elements(&fes)?;
Ok(())
}
@@ -134,13 +136,17 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
round: usize,
num_polys: usize,
polys: &mut CudaViewMut<FieldBinding>,
challenge: &CudaView<FieldBinding>,
) -> Result<(), DriverError> {
challenge: &CudaSlice<FieldBinding>,
) -> Result<(), LibraryError> {
let fold_into_half_in_place = self
.gpu
.get_func("sumcheck", "fold_into_half_in_place")
.unwrap();
let num_blocks_per_poly = self.max_blocks_per_sm()? / num_polys * self.num_sm()?;
let num_blocks_per_poly = self
.max_blocks_per_sm()
.map_err(|e| LibraryError::Driver(e))?
/ num_polys
* self.num_sm().map_err(|e| LibraryError::Driver(e))?;
let num_threads_per_block = 1024;
let launch_config = LaunchConfig {
grid_dim: ((num_blocks_per_poly * num_polys) as u32, 1, 1),
@@ -148,16 +154,18 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
shared_mem_bytes: 0,
};
unsafe {
fold_into_half_in_place.launch(
launch_config,
(
initial_poly_num_vars - round,
1 << initial_poly_num_vars,
num_blocks_per_poly,
polys,
challenge,
),
)?;
fold_into_half_in_place
.launch(
launch_config,
(
initial_poly_num_vars - round,
1 << initial_poly_num_vars,
num_blocks_per_poly,
polys,
challenge,
),
)
.map_err(|e| LibraryError::Driver(e))?;
};
Ok(())
}
@@ -167,16 +175,19 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
mod tests {
use std::{cell::RefCell, time::Instant};
use cudarc::{driver::DriverError, nvrtc::Ptx};
use cudarc::nvrtc::Ptx;
use ff::Field;
use halo2curves::bn256::Fr;
use itertools::Itertools;
use rand::rngs::OsRng;
use crate::{cpu, fieldbinding::ToFieldBinding, GPUApiWrapper, MULTILINEAR_PTX, SUMCHECK_PTX};
use crate::{
cpu, fieldbinding::ToFieldBinding, transcript::Keccak256Transcript, GPUApiWrapper,
LibraryError, MULTILINEAR_PTX, SUMCHECK_PTX,
};
#[test]
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
fn test_eval_at_k_and_combine() -> Result<(), LibraryError> {
let num_vars = 10;
let num_polys = 3;
let max_degree = 3;
@@ -198,17 +209,24 @@ mod tests {
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&["fold_into_half", "combine", "sum"],
)?;
let mut gpu_api_wrapper =
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)
.map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&["fold_into_half", "combine", "sum"],
)
.map_err(|e| LibraryError::Driver(e))?;
let mut cpu_round_evals = vec![];
let now = Instant::now();
@@ -227,19 +245,27 @@ mod tests {
);
// copy polynomials to device
let gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
let gpu_polys = gpu_api_wrapper
.copy_to_device(&polys.concat())
.map_err(|e| LibraryError::Driver(e))?;
let device_ks = (0..max_degree + 1)
.map(|k| {
gpu_api_wrapper
.gpu
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
})
.collect::<Result<Vec<_>, _>>()?;
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
.collect::<Result<Vec<_>, _>>()
.map_err(|e| LibraryError::Driver(e))?;
let mut buf = gpu_api_wrapper
.malloc_on_device(num_polys << (num_vars - 1))
.map_err(|e| LibraryError::Driver(e))?;
let buf_view = RefCell::new(buf.slice_mut(..));
let mut round_evals = gpu_api_wrapper.malloc_on_device(max_degree as usize + 1)?;
let mut round_evals = gpu_api_wrapper
.malloc_on_device(max_degree as usize + 1)
.map_err(|e| LibraryError::Driver(e))?;
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
let round = 0;
let mut transcript = Keccak256Transcript::<Fr>::new();
let now = Instant::now();
gpu_api_wrapper.eval_at_k_and_combine(
num_vars,
@@ -253,13 +279,15 @@ mod tests {
.collect_vec(),
buf_view.borrow_mut(),
round_evals_view.borrow_mut(),
&mut transcript,
)?;
println!(
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
now.elapsed()
);
let gpu_round_evals = gpu_api_wrapper
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)?;
.dtoh_sync_copy(&round_evals.slice(0..(max_degree + 1) as usize), true)
.map_err(|e| LibraryError::Driver(e))?;
cpu_round_evals
.iter()
.zip_eq(gpu_round_evals.iter())
@@ -271,7 +299,7 @@ mod tests {
}
#[test]
fn test_fold_into_half_in_place() -> Result<(), DriverError> {
fn test_fold_into_half_in_place() -> Result<(), LibraryError> {
let num_vars = 6;
let num_polys = 4;
@@ -290,21 +318,33 @@ mod tests {
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&["fold_into_half_in_place"],
)?;
let mut gpu_api_wrapper =
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)
.map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&["fold_into_half_in_place"],
)
.map_err(|e| LibraryError::Driver(e))?;
// copy polynomials to device
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
let mut gpu_polys = gpu_api_wrapper
.copy_to_device(&polys.concat())
.map_err(|e| LibraryError::Driver(e))?;
let challenge = Fr::random(rng);
let gpu_challenge = gpu_api_wrapper.copy_to_device(&vec![challenge])?;
let gpu_challenge = gpu_api_wrapper
.copy_to_device(&vec![challenge])
.map_err(|e| LibraryError::Driver(e))?;
let round = 0;
let now = Instant::now();
@@ -313,7 +353,7 @@ mod tests {
round,
num_polys,
&mut gpu_polys.slice_mut(..),
&gpu_challenge.slice(..),
&gpu_challenge,
)?;
println!(
"Time taken to fold_into_half_in_place on gpu: {:.2?}",
@@ -327,7 +367,8 @@ mod tests {
true,
)
})
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
.collect::<Result<Vec<Vec<Fr>>, _>>()
.map_err(|e| LibraryError::Driver(e))?;
let now = Instant::now();
(0..num_polys)
@@ -348,8 +389,8 @@ mod tests {
}
#[test]
fn test_prove_sumcheck() -> Result<(), DriverError> {
let num_vars = 25;
fn test_prove_sumcheck() -> Result<(), LibraryError> {
let num_vars = 4;
let num_polys = 2;
let max_degree = 2;
@@ -368,26 +409,34 @@ mod tests {
})
.collect_vec();
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)?;
gpu_api_wrapper.gpu.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&[
"fold_into_half",
"fold_into_half_in_place",
"combine",
"sum",
"squeeze_challenge",
],
)?;
let mut gpu_api_wrapper =
GPUApiWrapper::<Fr>::setup().map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(MULTILINEAR_PTX),
"multilinear",
&["convert_to_montgomery_form"],
)
.map_err(|e| LibraryError::Driver(e))?;
gpu_api_wrapper
.gpu
.load_ptx(
Ptx::from_src(SUMCHECK_PTX),
"sumcheck",
&[
"fold_into_half",
"fold_into_half_in_place",
"combine",
"sum",
],
)
.map_err(|e| LibraryError::Driver(e))?;
let mut transcript = Keccak256Transcript::<Fr>::new();
let now = Instant::now();
let mut gpu_polys = gpu_api_wrapper.copy_to_device(&polys.concat())?;
let mut gpu_polys = gpu_api_wrapper
.copy_to_device(&polys.concat())
.map_err(|e| LibraryError::Driver(e))?;
let sum = (0..1 << num_vars).fold(Fr::ZERO, |acc, index| {
acc + polys.iter().map(|poly| poly[index]).product::<Fr>()
});
@@ -397,12 +446,19 @@ mod tests {
.gpu
.htod_copy(vec![Fr::to_montgomery_form(Fr::from(k as u64))])
})
.collect::<Result<Vec<_>, _>>()?;
let mut buf = gpu_api_wrapper.malloc_on_device(num_polys << (num_vars - 1))?;
.collect::<Result<Vec<_>, _>>()
.map_err(|e| LibraryError::Driver(e))?;
let mut buf = gpu_api_wrapper
.malloc_on_device(num_polys << (num_vars - 1))
.map_err(|e| LibraryError::Driver(e))?;
let buf_view = RefCell::new(buf.slice_mut(..));
let mut challenges = gpu_api_wrapper.malloc_on_device(num_vars)?;
let mut round_evals = gpu_api_wrapper.malloc_on_device(num_vars * (max_degree + 1))?;
let mut challenges = gpu_api_wrapper
.malloc_on_device(1)
.map_err(|e| LibraryError::Driver(e))?;
let mut round_evals = gpu_api_wrapper
.malloc_on_device(num_vars * (max_degree + 1))
.map_err(|e| LibraryError::Driver(e))?;
let round_evals_view = RefCell::new(round_evals.slice_mut(..));
println!("Time taken to copy data to device : {:.2?}", now.elapsed());
@@ -418,35 +474,34 @@ mod tests {
.map(|device_k| device_k.slice(..))
.collect_vec(),
buf_view,
&mut challenges.slice_mut(..),
&mut challenges,
round_evals_view,
&mut transcript,
)?;
gpu_api_wrapper.gpu.synchronize()?;
gpu_api_wrapper
.gpu
.synchronize()
.map_err(|e| LibraryError::Driver(e))?;
println!(
"Time taken to prove sumcheck on gpu : {:.2?}",
now.elapsed()
);
let challenges = gpu_api_wrapper.dtoh_sync_copy(&challenges.slice(..), true)?;
let round_evals = (0..num_vars)
.map(|i| {
gpu_api_wrapper.dtoh_sync_copy(
&round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
true,
)
})
.collect::<Result<Vec<Vec<Fr>>, _>>()?;
let round_evals = round_evals
.iter()
.map(|round_evals| round_evals.as_slice())
.collect_vec();
let result = cpu::sumcheck::verify_sumcheck(
num_vars,
max_degree,
sum,
&challenges[..],
&round_evals[..],
);
// let challenges = gpu_api_wrapper
// .dtoh_sync_copy(&challenges.slice(..), true)
// .map_err(|e| LibraryError::Driver(e))?;
// let round_evals = (0..num_vars)
// .map(|i| {
// gpu_api_wrapper.dtoh_sync_copy(
// &round_evals.slice(i * (max_degree + 1)..(i + 1) * (max_degree + 1)),
// true,
// )
// })
// .collect::<Result<Vec<Vec<Fr>>, _>>()
// .map_err(|e| LibraryError::Driver(e))?;
let result =
cpu::sumcheck::verify_sumcheck_transcript(num_vars, max_degree, sum, &mut transcript);
assert!(result);
Ok(())
}

View File

@@ -15,6 +15,8 @@ use std::time::Instant;
mod cpu;
pub mod fieldbinding;
pub mod gpu;
pub mod transcript;
pub mod utils;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
@@ -71,6 +73,21 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
Ok(device_data)
}
/// Overwrites the existing device buffer `dst` with the contents of `host_data`.
///
/// Every element is first mapped through `F::to_canonical_form`, and the resulting
/// vector is handed to the device's `htod_copy_into`.
///
/// NOTE(review): the elements are uploaded exactly as `to_canonical_form` produces them;
/// confirm that this is the representation the consuming kernels expect.
///
/// # Errors
/// Propagates the `DriverError` reported by the host-to-device copy.
pub fn overwrite_to_device(
    &self,
    host_data: &[F],
    dst: &mut CudaSlice<FieldBinding>,
) -> Result<(), DriverError> {
    let converted: Vec<FieldBinding> = host_data
        .iter()
        .map(|&fe| F::to_canonical_form(fe))
        .collect();
    self.gpu.htod_copy_into(converted, dst)?;
    Ok(())
}
pub fn malloc_on_device(&self, len: usize) -> Result<CudaSlice<FieldBinding>, DriverError> {
let device_ptr = unsafe { self.gpu.alloc::<FieldBinding>(len)? };
Ok(device_ptr)
@@ -123,3 +140,9 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
Ok(self.gpu.attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)? as usize)
}
}
#[derive(Debug)]
pub enum LibraryError {
Driver(DriverError),
Transcript(std::io::ErrorKind, String),
}

View File

@@ -0,0 +1,80 @@
use crate::{
fieldbinding::{FromFieldBinding, ToFieldBinding},
utils::{arithmetic::fe_mod_from_le_bytes, hash::Hash},
GPUApiWrapper, LibraryError,
};
use ff::PrimeField;
use halo2curves::serde::SerdeObject;
use sha3::Keccak256;
use std::{
io::{Cursor, Read, Write},
marker::PhantomData,
};
/// A [`CudaTranscript`] whose hash state is `Keccak256`.
pub type Keccak256Transcript<F> = CudaTranscript<Keccak256, F>;
/// In-memory transcript: a byte stream of serialized field elements plus a running hash
/// state from which challenges are squeezed.
///
/// `H` is the hash function and `F` the field whose elements are written and read.
#[derive(Debug, Default)]
pub struct CudaTranscript<H, F> {
/// Byte buffer holding the serialized field elements. Reads and writes both go through
/// this one cursor, so they share a single position.
pub stream: Cursor<Vec<u8>>,
/// Hash state that absorbs every field element written or read.
pub state: H,
// Ties the transcript to the field type `F` without storing a value of it.
marker: PhantomData<F>,
}
impl<H: Hash, F: PrimeField> CudaTranscript<H, F> {
pub fn new() -> Self {
Self::default()
}
pub fn squeeze_challenge(&mut self) -> F {
let hash = self.state.finalize_fixed_reset();
self.state.update(&hash);
fe_mod_from_le_bytes(hash)
}
fn common_field_element(&mut self, fe: &F) -> Result<(), LibraryError> {
self.state.update_field_element(fe);
Ok(())
}
fn write_field_element(&mut self, fe: &F) -> Result<(), LibraryError> {
self.common_field_element(fe)?;
let mut repr = fe.to_repr();
repr.as_mut().reverse();
self.stream
.write_all(repr.as_ref())
.map_err(|err| LibraryError::Transcript(err.kind(), err.to_string()))
}
pub fn write_field_elements<'a>(
&mut self,
fes: impl IntoIterator<Item = &'a F>,
) -> Result<(), LibraryError>
where
F: 'a,
{
for fe in fes.into_iter() {
self.write_field_element(fe)?;
}
Ok(())
}
fn read_field_element(&mut self) -> Result<F, LibraryError> {
let mut repr = <F as PrimeField>::Repr::default();
self.stream
.read_exact(repr.as_mut())
.map_err(|err| LibraryError::Transcript(err.kind(), err.to_string()))?;
repr.as_mut().reverse();
let fe = F::from_repr_vartime(repr).ok_or_else(|| {
LibraryError::Transcript(
std::io::ErrorKind::Other,
"Invalid field element encoding in proof".to_string(),
)
})?;
self.common_field_element(&fe)?;
Ok(fe)
}
pub fn read_field_elements(&mut self, n: usize) -> Result<Vec<F>, LibraryError> {
(0..n).map(|_| self.read_field_element()).collect()
}
}

2
sumcheck/src/utils.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod arithmetic;
pub mod hash;

View File

@@ -0,0 +1,18 @@
use ff::PrimeField;
use num_bigint::BigUint;
/// Returns the modulus of `F` as a `BigUint`.
///
/// Computed as the integer value of `-1` in the field plus one. The representation of
/// `-1` is read as little-endian bytes.
/// NOTE(review): this assumes `F::Repr` is little-endian — confirm for the fields in use.
pub fn modulus<F: PrimeField>() -> BigUint {
    let minus_one = (-F::ONE).to_repr();
    let largest_element = BigUint::from_bytes_le(minus_one.as_ref());
    largest_element + 1u64
}
/// Interprets `bytes` as a little-endian integer, reduces it modulo the modulus of `F`,
/// and returns the result as a field element.
pub fn fe_mod_from_le_bytes<F: PrimeField>(bytes: impl AsRef<[u8]>) -> F {
    let value = BigUint::from_bytes_le(bytes.as_ref());
    let reduced = value % modulus::<F>();
    fe_from_le_bytes(reduced.to_bytes_le())
}
/// Builds a field element from `bytes` by copying them into the low end of a
/// zero-initialized `F::Repr` and decoding it with `F::from_repr`.
///
/// # Panics
/// Panics if `bytes` is longer than the representation, or if `F::from_repr` rejects the
/// resulting representation.
pub fn fe_from_le_bytes<F: PrimeField>(bytes: impl AsRef<[u8]>) -> F {
    let raw = bytes.as_ref();
    let len = raw.len();
    let mut buffer = F::Repr::default();
    assert!(len <= buffer.as_ref().len());
    // Bytes beyond `len` keep their default value.
    buffer.as_mut()[..len].copy_from_slice(raw);
    F::from_repr(buffer).unwrap()
}

View File

@@ -0,0 +1,31 @@
use ff::PrimeField;
use sha3::digest::{Digest, HashMarker};
use std::fmt::Debug;
pub use sha3::{
digest::{FixedOutputReset, Output, Update},
Keccak256,
};
pub trait Hash:
'static + Sized + Clone + Debug + FixedOutputReset + Default + Update + HashMarker
{
fn new() -> Self {
Self::default()
}
fn update_field_element(&mut self, field: &impl PrimeField) {
Digest::update(self, field.to_repr());
}
fn digest(data: impl AsRef<[u8]>) -> Output<Self> {
let mut hasher = Self::default();
hasher.update(data.as_ref());
hasher.finalize()
}
}
/// Blanket implementation: every type that satisfies the supertrait bounds of [`Hash`]
/// implements it automatically.
impl<T> Hash for T where
    T: 'static + Sized + Clone + Debug + FixedOutputReset + Default + Update + HashMarker
{
}