mirror of
https://github.com/pseXperiments/cuda-sumcheck.git
synced 2026-01-09 15:38:01 -05:00
Remove unnecessary synchronization and time consuming data generation from test code
This commit is contained in:
@@ -178,14 +178,24 @@ mod tests {
|
||||
#[test]
|
||||
fn test_eval_at_k_and_combine() -> Result<(), DriverError> {
|
||||
let num_vars = 10;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
let num_polys = 3;
|
||||
let max_degree = 3;
|
||||
let rng = OsRng::default();
|
||||
|
||||
let combine_function = |args: &Vec<Fr>| args.iter().product();
|
||||
|
||||
let polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.map(|_| {
|
||||
(0..1 << num_vars)
|
||||
.map(|i| {
|
||||
if i < 1024 {
|
||||
Fr::random(rng)
|
||||
} else {
|
||||
Fr::from(i)
|
||||
}
|
||||
})
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
@@ -233,7 +243,6 @@ mod tests {
|
||||
buf_view.borrow_mut(),
|
||||
round_evals_view.borrow_mut(),
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to eval_at_k_and_combine on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
@@ -257,7 +266,17 @@ mod tests {
|
||||
|
||||
let rng = OsRng::default();
|
||||
let mut polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.map(|_| {
|
||||
(0..1 << num_vars)
|
||||
.map(|i| {
|
||||
if i < 1024 {
|
||||
Fr::random(rng)
|
||||
} else {
|
||||
Fr::from(i)
|
||||
}
|
||||
})
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
@@ -285,7 +304,6 @@ mod tests {
|
||||
&mut gpu_polys.slice_mut(..),
|
||||
&gpu_challenge.slice(..),
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to fold_into_half_in_place on gpu: {:.2?}",
|
||||
now.elapsed()
|
||||
@@ -320,13 +338,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_prove_sumcheck() -> Result<(), DriverError> {
|
||||
let num_vars = 12;
|
||||
let num_vars = 23;
|
||||
let num_polys = 4;
|
||||
let max_degree = 4;
|
||||
|
||||
let rng = OsRng::default();
|
||||
let polys = (0..num_polys)
|
||||
.map(|_| (0..1 << num_vars).map(|_| Fr::random(rng)).collect_vec())
|
||||
.map(|_| {
|
||||
(0..1 << num_vars)
|
||||
.map(|i| {
|
||||
if i < 1024 {
|
||||
Fr::random(rng)
|
||||
} else {
|
||||
Fr::from(i)
|
||||
}
|
||||
})
|
||||
.collect_vec()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let mut gpu_api_wrapper = GPUApiWrapper::<Fr>::setup()?;
|
||||
@@ -371,7 +399,6 @@ mod tests {
|
||||
&mut challenges.slice_mut(..),
|
||||
round_evals_view,
|
||||
)?;
|
||||
gpu_api_wrapper.gpu.synchronize()?;
|
||||
println!(
|
||||
"Time taken to prove sumcheck on gpu : {:.2?}",
|
||||
now.elapsed()
|
||||
|
||||
@@ -79,11 +79,11 @@ impl<F: PrimeField + FromFieldBinding<F> + ToFieldBinding<F>> GPUApiWrapper<F> {
|
||||
pub fn dtoh_sync_copy(
|
||||
&self,
|
||||
device_data: CudaView<FieldBinding>,
|
||||
convert_to_montgomery_form: bool,
|
||||
convert_from_montgomery_form: bool,
|
||||
) -> Result<Vec<F>, DriverError> {
|
||||
let host_data = self.gpu.dtoh_sync_copy(&device_data)?;
|
||||
let mut target = vec![F::ZERO; host_data.len()];
|
||||
if convert_to_montgomery_form {
|
||||
if convert_from_montgomery_form {
|
||||
parallelize(&mut target, |(target, start)| {
|
||||
target
|
||||
.iter_mut()
|
||||
|
||||
Reference in New Issue
Block a user