Merge pull request #10 from eth-applied-research-group/kw/add-enum-for-gpu

feat: allow callers to specify GPU proving
This commit is contained in:
kevaundray
2025-05-19 14:36:41 +01:00
committed by GitHub
5 changed files with 107 additions and 33 deletions

View File

@@ -5,7 +5,9 @@ use utils::{
deserialize_public_input_with_proof, package_name_from_manifest,
serialize_public_input_with_proof,
};
use zkvm_interface::{Compiler, Input, ProgramExecutionReport, ProgramProvingReport, zkVM};
use zkvm_interface::{
Compiler, Input, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM,
};
mod jolt_methods;
mod utils;
@@ -45,7 +47,10 @@ pub struct EreJolt {
impl zkVM<JOLT_TARGET> for EreJolt {
type Error = JoltError;
fn new(program: <JOLT_TARGET as Compiler>::Program) -> Self {
fn new(
program: <JOLT_TARGET as Compiler>::Program,
_resource_type: ProverResourceType,
) -> Self {
EreJolt { program: program }
}
@@ -105,7 +110,7 @@ impl zkVM<JOLT_TARGET> for EreJolt {
mod tests {
use crate::{EreJolt, JOLT_TARGET};
use std::path::PathBuf;
use zkvm_interface::{Compiler, Input, zkVM};
use zkvm_interface::{Compiler, Input, ProverResourceType, zkVM};
// TODO: for now, we just get one test file
// TODO: but this should get the whole directory and compile each test
@@ -135,7 +140,7 @@ mod tests {
let mut inputs = Input::new();
inputs.write(&(1 as u32)).unwrap();
let zkvm = EreJolt::new(program);
let zkvm = EreJolt::new(program, ProverResourceType::Cpu);
let _execution = zkvm.execute(&inputs).unwrap();
}
// #[test]

View File

@@ -11,7 +11,9 @@ use openvm_stark_sdk::config::{
baby_bear_poseidon2::BabyBearPoseidon2Engine,
};
use openvm_transpiler::elf::Elf;
use zkvm_interface::{Compiler, ProgramExecutionReport, ProgramProvingReport, zkVM};
use zkvm_interface::{
Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM,
};
mod error;
use error::{CompileError, OpenVMError, VerifyError};
@@ -48,7 +50,10 @@ pub struct EreOpenVM {
impl zkVM<OPENVM_TARGET> for EreOpenVM {
type Error = OpenVMError;
fn new(program: <OPENVM_TARGET as Compiler>::Program) -> Self {
fn new(
program: <OPENVM_TARGET as Compiler>::Program,
_resource_type: ProverResourceType,
) -> Self {
Self { program }
}
@@ -186,7 +191,7 @@ mod tests {
let test_guest_path = get_compile_test_guest_program_path();
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
let empty_input = zkvm_interface::Input::new();
let zkvm = EreOpenVM::new(elf);
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
zkvm.execute(&empty_input).unwrap();
}
@@ -198,7 +203,7 @@ mod tests {
let mut input = zkvm_interface::Input::new();
input.write(&10u64).unwrap();
let zkvm = EreOpenVM::new(elf);
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
zkvm.execute(&input).unwrap();
}
@@ -209,7 +214,7 @@ mod tests {
let mut input = zkvm_interface::Input::new();
input.write(&10u64).unwrap();
let zkvm = EreOpenVM::new(elf);
let zkvm = EreOpenVM::new(elf, ProverResourceType::Cpu);
let (proof, _) = zkvm.prove(&input).unwrap();
zkvm.verify(&proof).expect("proof should verify");
}

View File

@@ -1,6 +1,6 @@
use pico_sdk::client::DefaultProverClient;
use std::process::Command;
use zkvm_interface::{Compiler, ProgramProvingReport, zkVM};
use zkvm_interface::{Compiler, ProgramProvingReport, ProverResourceType, zkVM};
mod error;
use error::PicoError;
@@ -54,7 +54,10 @@ pub struct ErePico {
impl zkVM<PICO_TARGET> for ErePico {
type Error = PicoError;
fn new(program_bytes: <PICO_TARGET as Compiler>::Program) -> Self {
fn new(
program_bytes: <PICO_TARGET as Compiler>::Program,
_resource_type: ProverResourceType,
) -> Self {
ErePico {
program: program_bytes,
}

View File

@@ -2,17 +2,74 @@
use compile::compile_sp1_program;
use sp1_sdk::{
CpuProver, Prover, ProverClient, SP1ProofWithPublicValues, SP1ProvingKey, SP1Stdin,
CpuProver, CudaProver, Prover, ProverClient, SP1ProofWithPublicValues, SP1ProvingKey, SP1Stdin,
SP1VerifyingKey,
};
use tracing::info;
use zkvm_interface::{Compiler, ProgramExecutionReport, ProgramProvingReport, zkVM};
use zkvm_interface::{
Compiler, ProgramExecutionReport, ProgramProvingReport, ProverResourceType, zkVM,
};
mod compile;
mod error;
use error::{ExecuteError, ProveError, SP1Error, VerifyError};
enum ProverType {
Cpu(CpuProver),
Gpu(CudaProver),
}
impl ProverType {
fn setup(
&self,
program: &<RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
) -> (SP1ProvingKey, SP1VerifyingKey) {
match self {
ProverType::Cpu(cpu_prover) => cpu_prover.setup(program),
ProverType::Gpu(cuda_prover) => cuda_prover.setup(program),
}
}
fn execute(
&self,
program: &<RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
input: &SP1Stdin,
) -> Result<(sp1_sdk::SP1PublicValues, sp1_sdk::ExecutionReport), ExecuteError> {
let cpu_executor_builder = match self {
ProverType::Cpu(cpu_prover) => cpu_prover.execute(program, input),
ProverType::Gpu(cuda_prover) => cuda_prover.execute(program, input),
};
cpu_executor_builder
.run()
.map_err(|e| ExecuteError::Client(e.into()))
}
fn prove(
&self,
pk: &SP1ProvingKey,
input: &SP1Stdin,
) -> Result<SP1ProofWithPublicValues, ProveError> {
match self {
ProverType::Cpu(cpu_prover) => cpu_prover.prove(pk, input).core().run(),
ProverType::Gpu(cuda_prover) => cuda_prover.prove(pk, input).core().run(),
}
.map_err(|e| ProveError::Client(e.into()))
}
fn verify(
&self,
proof: &SP1ProofWithPublicValues,
vk: &SP1VerifyingKey,
) -> Result<(), error::SP1Error> {
match self {
ProverType::Cpu(cpu_prover) => cpu_prover.verify(proof, vk),
ProverType::Gpu(cuda_prover) => cuda_prover.verify(proof, vk),
}
.map_err(|e| SP1Error::Verify(VerifyError::Client(e.into())))
}
}
#[allow(non_camel_case_types)]
pub struct RV32_IM_SUCCINCT_ZKVM_ELF;
pub struct EreSP1 {
@@ -22,7 +79,7 @@ pub struct EreSP1 {
/// Verification key
vk: SP1VerifyingKey,
/// Proof and Verification orchestrator
client: CpuProver,
client: ProverType,
}
impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF {
@@ -38,8 +95,14 @@ impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF {
impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
type Error = SP1Error;
fn new(program: <RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program) -> Self {
let client = ProverClient::builder().cpu().build();
fn new(
program: <RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
resource: ProverResourceType,
) -> Self {
let client = match resource {
ProverResourceType::Cpu => ProverType::Cpu(ProverClient::builder().cpu().build()),
ProverResourceType::Gpu => ProverType::Gpu(ProverClient::builder().cuda().build()),
};
let (pk, vk) = client.setup(&program);
Self {
@@ -59,12 +122,7 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
stdin.write_slice(input);
}
let (_, exec_report) = self
.client
.execute(&self.program, &stdin)
.run()
.map_err(|e| ExecuteError::Client(e.into()))?;
let (_, exec_report) = self.client.execute(&self.program, &stdin)?;
let total_num_cycles = exec_report.total_instruction_count();
let region_cycles: indexmap::IndexMap<_, _> =
exec_report.cycle_tracker.into_iter().collect();
@@ -87,12 +145,7 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
}
let start = std::time::Instant::now();
let proof_with_inputs = self
.client
.prove(&self.pk, &stdin)
.core()
.run()
.map_err(|e| ProveError::Client(e.into()))?;
let proof_with_inputs = self.client.prove(&self.pk, &stdin)?;
let proving_time = start.elapsed();
let bytes = bincode::serialize(&proof_with_inputs)
@@ -147,7 +200,7 @@ mod execute_tests {
input_builder.write(&n).unwrap();
input_builder.write(&a).unwrap();
let zkvm = EreSP1::new(elf_bytes);
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
let result = zkvm.execute(&input_builder);
@@ -163,7 +216,7 @@ mod execute_tests {
let empty_input = Input::new();
let zkvm = EreSP1::new(elf_bytes);
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
let result = zkvm.execute(&empty_input);
assert!(
@@ -207,7 +260,7 @@ mod prove_tests {
input_builder.write(&n).unwrap();
input_builder.write(&a).unwrap();
let zkvm = EreSP1::new(elf_bytes);
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
let proof_bytes = match zkvm.prove(&input_builder) {
Ok((prove_result, _)) => prove_result,
@@ -232,7 +285,7 @@ mod prove_tests {
let empty_input = Input::new();
let zkvm = EreSP1::new(elf_bytes);
let zkvm = EreSP1::new(elf_bytes, ProverResourceType::Cpu);
let prove_result = zkvm.prove(&empty_input);
assert!(prove_result.is_err())
}

View File

@@ -15,12 +15,20 @@ pub trait Compiler {
fn compile(path_to_program: &Path) -> Result<Self::Program, Self::Error>;
}
/// ResourceType specifies what resource will be used to create the proofs.
#[derive(Debug, Copy, Clone, Default)]
pub enum ProverResourceType {
#[default]
Cpu,
Gpu,
}
#[allow(non_camel_case_types)]
/// zkVM trait to abstract away the differences between each zkVM
pub trait zkVM<C: Compiler> {
type Error: std::error::Error + Send + Sync + 'static;
fn new(program_bytes: C::Program) -> Self;
fn new(program_bytes: C::Program, resource: ProverResourceType) -> Self;
/// Executes the given program with the inputs accumulated in the Input struct.
/// For RISCV programs, `program_bytes` will be the ELF binary