From 536ebe374b0e89a50b93fe08f35874a6f5b12970 Mon Sep 17 00:00:00 2001 From: Kevaundray Wedderburn Date: Mon, 19 May 2025 14:16:37 +0100 Subject: [PATCH 1/2] modify interface to add ProverResourceType --- crates/zkvm-interface/src/lib.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/zkvm-interface/src/lib.rs b/crates/zkvm-interface/src/lib.rs index 398e5ff..367d9fb 100644 --- a/crates/zkvm-interface/src/lib.rs +++ b/crates/zkvm-interface/src/lib.rs @@ -15,12 +15,20 @@ pub trait Compiler { fn compile(path_to_program: &Path) -> Result; } +/// ResourceType specifies what resource will be used to create the proofs. +#[derive(Debug, Copy, Clone, Default)] +pub enum ProverResourceType { + #[default] + Cpu, + Gpu, +} + #[allow(non_camel_case_types)] /// zkVM trait to abstract away the differences between each zkVM pub trait zkVM { type Error: std::error::Error + Send + Sync + 'static; - fn new(program_bytes: C::Program) -> Self; + fn new(program_bytes: C::Program, resource: ProverResourceType) -> Self; /// Executes the given program with the inputs accumulated in the Input struct. /// For RISCV programs, `program_bytes` will be the ELF binary From c28df50599e5d5aa0739b52da5440b8b47748d1e Mon Sep 17 00:00:00 2001 From: Kevaundray Wedderburn Date: Mon, 19 May 2025 14:16:57 +0100 Subject: [PATCH 2/2] refactor zkVM impls -- only SP1 currently uses it --- crates/ere-jolt/src/lib.rs | 13 +++-- crates/ere-openvm/src/lib.rs | 15 ++++-- crates/ere-pico/src/lib.rs | 7 ++- crates/ere-sp1/src/lib.rs | 95 ++++++++++++++++++++++++++++-------- 4 files changed, 98 insertions(+), 32 deletions(-) diff --git a/crates/ere-jolt/src/lib.rs b/crates/ere-jolt/src/lib.rs index 6a5b7ae..03c3906 100644 --- a/crates/ere-jolt/src/lib.rs +++ b/crates/ere-jolt/src/lib.rs @@ -5,7 +5,9 @@ use utils::{ deserialize_public_input_with_proof, package_name_from_manifest, serialize_public_input_with_proof, }; -use zkvm_interface::{Compiler, Input, ProgramExecutionReport, ProgramProvingReport, zkVM}; +use zkvm_interface::{ + Compiler, Input, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM, +}; mod jolt_methods; mod utils; @@ -45,7 +47,10 @@ pub struct EreJolt { impl zkVM for EreJolt { type Error = JoltError; - fn new(program: ::Program) -> Self { + fn new( + program: ::Program, + _resource_type: ProverResourceType, + ) -> Self { EreJolt { program: program } } @@ -105,7 +110,7 @@ impl zkVM for EreJolt { mod tests { use crate::{EreJolt, JOLT_TARGET}; use std::path::PathBuf; - use zkvm_interface::{Compiler, Input, zkVM}; + use zkvm_interface::{Compiler, Input, ProverResourceType, zkVM}; // TODO: for now, we just get one test file // TODO: but this should get the whole directory and compile each test @@ -135,7 +140,7 @@ mod tests { let mut inputs = Input::new(); inputs.write(&(1 as u32)).unwrap(); - let zkvm = EreJolt::new(program); + let zkvm = EreJolt::new(program, ProverResourceType::Cpu); let _execution = zkvm.execute(&inputs).unwrap(); } // #[test] diff --git a/crates/ere-openvm/src/lib.rs b/crates/ere-openvm/src/lib.rs index 73e85be..ec04237 100644 --- a/crates/ere-openvm/src/lib.rs +++ b/crates/ere-openvm/src/lib.rs @@ -11,7 +11,9 @@ use openvm_stark_sdk::config::{ baby_bear_poseidon2::BabyBearPoseidon2Engine, }; use openvm_transpiler::elf::Elf; -use zkvm_interface::{Compiler, ProgramExecutionReport, ProgramProvingReport, zkVM}; +use zkvm_interface::{ + Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM, +}; mod error; use error::{CompileError, OpenVMError, VerifyError}; @@ -48,7 +50,10 @@ pub struct EreOpenVM { impl zkVM for EreOpenVM { type Error = OpenVMError; - fn new(program: ::Program) -> Self { + fn new( + program: ::Program, + _resource_type: ProverResourceType, + ) -> Self { Self { program } } @@ -186,7 +191,7 @@ mod tests { let test_guest_path = get_compile_test_guest_program_path(); let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed"); let empty_input = zkvm_interface::Input::new(); - let zkvm = EreOpenVM::new(elf); + let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu); zkvm.execute(&empty_input).unwrap(); } @@ -198,7 +203,7 @@ mod tests { let mut input = zkvm_interface::Input::new(); input.write(&10u64).unwrap(); - let zkvm = EreOpenVM::new(elf); + let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu); zkvm.execute(&input).unwrap(); } @@ -209,7 +214,7 @@ mod tests { let mut input = zkvm_interface::Input::new(); input.write(&10u64).unwrap(); - let zkvm = EreOpenVM::new(elf); + let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu); let (proof, _) = zkvm.prove(&input).unwrap(); zkvm.verify(&proof).expect("proof should verify"); } diff --git a/crates/ere-pico/src/lib.rs b/crates/ere-pico/src/lib.rs index e17c56e..9bb00e7 100644 --- a/crates/ere-pico/src/lib.rs +++ b/crates/ere-pico/src/lib.rs @@ -1,6 +1,6 @@ use pico_sdk::client::DefaultProverClient; use std::process::Command; -use zkvm_interface::{Compiler, ProgramProvingReport, zkVM}; +use zkvm_interface::{Compiler, ProgramProvingReport, ProverResourceType, zkVM}; mod error; use error::PicoError; @@ -54,7 +54,10 @@ pub struct ErePico { impl zkVM for ErePico { type Error = PicoError; - fn new(program_bytes: ::Program) -> Self { + fn new( + program_bytes: ::Program, + _resource_type: ProverResourceType, + ) -> Self { ErePico { program: program_bytes, } diff --git a/crates/ere-sp1/src/lib.rs b/crates/ere-sp1/src/lib.rs index 47a3926..4e2c7af 100644 --- a/crates/ere-sp1/src/lib.rs +++ b/crates/ere-sp1/src/lib.rs @@ -2,17 +2,74 @@ use compile::compile_sp1_program; use sp1_sdk::{ - CpuProver, Prover, ProverClient, SP1ProofWithPublicValues, SP1ProvingKey, SP1Stdin, + CpuProver, CudaProver, Prover, ProverClient, SP1ProofWithPublicValues, SP1ProvingKey, SP1Stdin, SP1VerifyingKey, }; use tracing::info; -use zkvm_interface::{Compiler, ProgramExecutionReport, ProgramProvingReport, zkVM}; +use zkvm_interface::{ + Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM, +}; mod compile; mod error; use error::{ExecuteError, ProveError, SP1Error, VerifyError}; +enum ProverType { + Cpu(CpuProver), + Gpu(CudaProver), +} + +impl ProverType { + fn setup( + &self, + program: &::Program, + ) -> (SP1ProvingKey, SP1VerifyingKey) { + match self { + ProverType::Cpu(cpu_prover) => cpu_prover.setup(program), + ProverType::Gpu(cuda_prover) => cuda_prover.setup(program), + } + } + + fn execute( + &self, + program: &::Program, + input: &SP1Stdin, + ) -> Result<(sp1_sdk::SP1PublicValues, sp1_sdk::ExecutionReport), ExecuteError> { + let cpu_executor_builder = match self { + ProverType::Cpu(cpu_prover) => cpu_prover.execute(program, input), + ProverType::Gpu(cuda_prover) => cuda_prover.execute(program, input), + }; + + cpu_executor_builder + .run() + .map_err(|e| ExecuteError::Client(e.into())) + } + fn prove( + &self, + pk: &SP1ProvingKey, + input: &SP1Stdin, + ) -> Result { + match self { + ProverType::Cpu(cpu_prover) => cpu_prover.prove(pk, input).core().run(), + ProverType::Gpu(cuda_prover) => cuda_prover.prove(pk, input).core().run(), + } + .map_err(|e| ProveError::Client(e.into())) + } + + fn verify( + &self, + proof: &SP1ProofWithPublicValues, + vk: &SP1VerifyingKey, + ) -> Result<(), error::SP1Error> { + match self { + ProverType::Cpu(cpu_prover) => cpu_prover.verify(proof, vk), + ProverType::Gpu(cuda_prover) => cuda_prover.verify(proof, vk), + } + .map_err(|e| SP1Error::Verify(VerifyError::Client(e.into()))) + } +} + #[allow(non_camel_case_types)] pub struct RV32_IM_SUCCINCT_ZKVM_ELF; pub struct EreSP1 { @@ -22,7 +79,7 @@ pub struct EreSP1 { /// Verification key vk: SP1VerifyingKey, /// Proof and Verification orchestrator - client: CpuProver, + client: ProverType, } impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF { @@ -38,8 +95,14 @@ impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF { impl zkVM for EreSP1 { type Error = SP1Error; - fn new(program: ::Program) -> Self { - let client = ProverClient::builder().cpu().build(); + fn new( + program: ::Program, + resource: ProverResourceType, + ) -> Self { + let client = match resource { + ProverResourceType::Cpu => ProverType::Cpu(ProverClient::builder().cpu().build()), + ProverResourceType::Gpu => ProverType::Gpu(ProverClient::builder().cuda().build()), + }; let (pk, vk) = client.setup(&program); Self { @@ -59,12 +122,7 @@ impl zkVM for EreSP1 { stdin.write_slice(input); } - let (_, exec_report) = self - .client - .execute(&self.program, &stdin) - .run() - .map_err(|e| ExecuteError::Client(e.into()))?; - + let (_, exec_report) = self.client.execute(&self.program, &stdin)?; let total_num_cycles = exec_report.total_instruction_count(); let region_cycles: indexmap::IndexMap<_, _> = exec_report.cycle_tracker.into_iter().collect(); @@ -87,12 +145,7 @@ impl zkVM for EreSP1 { } let start = std::time::Instant::now(); - let proof_with_inputs = self - .client - .prove(&self.pk, &stdin) - .core() - .run() - .map_err(|e| ProveError::Client(e.into()))?; + let proof_with_inputs = self.client.prove(&self.pk, &stdin)?; let proving_time = start.elapsed(); let bytes = bincode::serialize(&proof_with_inputs) @@ -147,7 +200,7 @@ mod execute_tests { input_builder.write(&n).unwrap(); input_builder.write(&a).unwrap(); - let zkvm = EreSP1::new(elf_bytes); + let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu); let result = zkvm.execute(&input_builder); @@ -163,7 +216,7 @@ mod execute_tests { let empty_input = Input::new(); - let zkvm = EreSP1::new(elf_bytes); + let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu); let result = zkvm.execute(&empty_input); assert!( @@ -207,7 +260,7 @@ mod prove_tests { input_builder.write(&n).unwrap(); input_builder.write(&a).unwrap(); - let zkvm = EreSP1::new(elf_bytes); + let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu); let proof_bytes = match zkvm.prove(&input_builder) { Ok((prove_result, _)) => prove_result, @@ -232,7 +285,7 @@ mod prove_tests { let empty_input = Input::new(); - let zkvm = EreSP1::new(elf_bytes); + let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu); let prove_result = zkvm.prove(&empty_input); assert!(prove_result.is_err()) }