diff --git a/examples/toy_bn254.rs b/examples/toy_bn254.rs index 0271bd3..55a5811 100644 --- a/examples/toy_bn254.rs +++ b/examples/toy_bn254.rs @@ -1,14 +1,10 @@ use std::{collections::HashMap, env::current_dir, time::Instant}; use nova_scotia::{ - circom::reader::load_r1cs, create_public_params, create_recursive_circuit, FileLocation, F, S, + circom::reader::load_r1cs, continue_recursive_circuit, create_public_params, + create_recursive_circuit, FileLocation, F, S, }; -use nova_snark::{ - provider, - traits::{circuit::StepCircuit, Group}, - CompressedSNARK, PublicParams, -}; -use pasta_curves::group::ff::Field; +use nova_snark::{provider, CompressedSNARK, PublicParams}; use serde_json::json; fn run_test(circuit_filepath: String, witness_gen_filepath: String) { @@ -58,9 +54,9 @@ fn run_test(circuit_filepath: String, witness_gen_filepath: String) { println!("Creating a RecursiveSNARK..."); let start = Instant::now(); - let recursive_snark = create_recursive_circuit( - FileLocation::PathBuf(witness_generator_file), - r1cs, + let mut recursive_snark = create_recursive_circuit( + FileLocation::PathBuf(witness_generator_file.clone()), + r1cs.clone(), private_inputs, start_public_input.to_vec(), &pp, @@ -82,6 +78,11 @@ fn run_test(circuit_filepath: String, witness_gen_filepath: String) { ); assert!(res.is_ok()); + let z_last = res.unwrap().0; + + assert_eq!(z_last[0], F::::from(20)); + assert_eq!(z_last[1], F::::from(70)); + // produce a compressed SNARK println!("Generating a CompressedSNARK using Spartan with IPA-PC..."); let start = Instant::now(); @@ -110,6 +111,48 @@ fn run_test(circuit_filepath: String, witness_gen_filepath: String) { start.elapsed() ); assert!(res.is_ok()); + + // continue recursive circuit by adding 2 further steps + println!("Adding steps to our RecursiveSNARK..."); + let start = Instant::now(); + + let iteration_count_continue = 2; + + let mut private_inputs_continue = Vec::new(); + for i in 0..iteration_count_continue { + let mut private_input = HashMap::new(); + private_input.insert("adder".to_string(), json!(5 + i)); + private_inputs_continue.push(private_input); + } + + let res = continue_recursive_circuit( + &mut recursive_snark, + z_last, + FileLocation::PathBuf(witness_generator_file), + r1cs, + private_inputs_continue, + start_public_input.to_vec(), + &pp, + ); + assert!(res.is_ok()); + println!( + "Adding 2 steps to our RecursiveSNARK took {:?}", + start.elapsed() + ); + + // verify the recursive SNARK with the added steps + println!("Verifying a RecursiveSNARK..."); + let start = Instant::now(); + let res = recursive_snark.verify(&pp, iteration_count + iteration_count_continue, &start_public_input, &z0_secondary); + println!( + "RecursiveSNARK::verify: {:?}, took {:?}", + res, + start.elapsed() + ); + assert!(res.is_ok()); + + assert_eq!(res.clone().unwrap().0[0], F::::from(31)); + assert_eq!(res.unwrap().0[1], F::::from(115)); } fn main() { diff --git a/src/lib.rs b/src/lib.rs index be27caa..4e167fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -309,3 +309,128 @@ where Ok(recursive_snark) } + +#[cfg(not(target_family = "wasm"))] +pub fn continue_recursive_circuit( + recursive_snark: &mut RecursiveSNARK, C2>, + last_zi: Vec>, + witness_generator_file: FileLocation, + r1cs: R1CS>, + private_inputs: Vec>, + start_public_input: Vec>, + pp: &PublicParams, C2>, +) -> Result<(), std::io::Error> +where + G1: Group::Scalar>, + G2: Group::Scalar>, +{ + let root = current_dir().unwrap(); + let witness_generator_output = root.join("circom_witness.wtns"); + + let iteration_count = private_inputs.len(); + + let mut current_public_input = last_zi + .iter() + .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string()) + .collect::>(); + + let circuit_secondary = TrivialTestCircuit::default(); + let z0_secondary = vec![G2::Scalar::ZERO]; + + for i in 0..iteration_count { + let witness = compute_witness::( + current_public_input.clone(), + private_inputs[i].clone(), + witness_generator_file.clone(), + &witness_generator_output, + ); + + let circuit = CircomCircuit { + r1cs: r1cs.clone(), + witness: Some(witness), + }; + + let current_public_output = circuit.get_public_outputs(); + current_public_input = current_public_output + .iter() + .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string()) + .collect(); + + let res = recursive_snark.prove_step( + pp, + &circuit, + &circuit_secondary, + start_public_input.clone(), + z0_secondary.clone(), + ); + + assert!(res.is_ok()); + } + + fs::remove_file(witness_generator_output)?; + + Ok(()) +} + +#[cfg(target_family = "wasm")] +pub async fn continue_recursive_circuit( + recursive_snark: &mut RecursiveSNARK, C2>, + last_zi: Vec>, + witness_generator_file: FileLocation, + r1cs: R1CS>, + private_inputs: Vec>, + start_public_input: Vec>, + pp: &PublicParams, C2>, +) -> Result<(), std::io::Error> +where + G1: Group::Scalar>, + G2: Group::Scalar>, +{ + let root = current_dir().unwrap(); + let witness_generator_output = root.join("circom_witness.wtns"); + + let iteration_count = private_inputs.len(); + + let mut current_public_input = last_zi + .iter() + .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string()) + .collect::>(); + + let circuit_secondary = TrivialTestCircuit::default(); + let z0_secondary = vec![G2::Scalar::ZERO]; + + for i in 0..iteration_count { + let witness = compute_witness::( + current_public_input.clone(), + private_inputs[i].clone(), + witness_generator_file.clone(), + &witness_generator_output, + ) + .await; + + let circuit = CircomCircuit { + r1cs: r1cs.clone(), + witness: Some(witness), + }; + + let current_public_output = circuit.get_public_outputs(); + current_public_input = current_public_output + .iter() + .map(|&x| format!("{:?}", x).strip_prefix("0x").unwrap().to_string()) + .collect(); + + let res = recursive_snark.prove_step( + pp, + &circuit, + &circuit_secondary, + start_public_input.clone(), + z0_secondary.clone(), + ); + + assert!(res.is_ok()); + } + + fs::remove_file(witness_generator_output)?; + + Ok(()) +}