make API stateful

This commit is contained in:
Kevaundray Wedderburn
2025-05-14 14:04:18 +01:00
parent cd63d58511
commit 99ad510dae
6 changed files with 84 additions and 62 deletions

View File

@@ -79,9 +79,9 @@ let guest = std::path::Path::new("guest/hello");
let elf = RV32_IM_SUCCINCT_ZKVM_ELF::compile(guest)?; // compile
let mut io = Input::new();
io.write(&42u32)?;
let (proof, _report) = EreSP1::prove(&elf, &io)?; // prove
EreSP1::verify(&elf, &proof)?; // verify
let zkvm = EreSP1::new(elf);
let (proof, _report) = zkvm.prove(&io)?; // prove
zkvm.verify(&elf, &proof)?; // verify
```
### 4. Run the Test Suite

View File

@@ -38,17 +38,24 @@ impl Compiler for JOLT_TARGET {
}
}
pub struct EreJolt;
pub struct EreJolt {
program: <JOLT_TARGET as Compiler>::Program,
}
impl zkVM<JOLT_TARGET> for EreJolt {
type Error = JoltError;
fn new(program: <JOLT_TARGET as Compiler>::Program) -> Self {
EreJolt { program: program }
}
fn execute(
program_bytes: &<JOLT_TARGET as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<zkvm_interface::ProgramExecutionReport, Self::Error> {
// TODO: check ProgramSummary
let summary = program_bytes
let summary = self
.program
.clone()
.trace_analyze::<jolt::F>(inputs.bytes());
let trace_len = summary.trace_len();
@@ -57,14 +64,14 @@ impl zkVM<JOLT_TARGET> for EreJolt {
}
fn prove(
program: &<JOLT_TARGET as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), Self::Error> {
// TODO: make this stateful and do in setup since its expensive and should be done once per program;
let preprocessed_key = preprocess_prover(&program);
let preprocessed_key = preprocess_prover(&self.program);
let now = std::time::Instant::now();
let (output_bytes, proof) = prove_generic(program, preprocessed_key, inputs);
let (output_bytes, proof) = prove_generic(&self.program, preprocessed_key, inputs);
let elapsed = now.elapsed();
let proof_with_public_inputs =
@@ -73,11 +80,8 @@ impl zkVM<JOLT_TARGET> for EreJolt {
Ok((proof_with_public_inputs, ProgramProvingReport::new(elapsed)))
}
fn verify(
program: &<JOLT_TARGET as Compiler>::Program,
proof_with_public_inputs: &[u8],
) -> Result<(), Self::Error> {
let preprocessed_verifier = preprocess_verifier(program);
fn verify(&self, proof_with_public_inputs: &[u8]) -> Result<(), Self::Error> {
let preprocessed_verifier = preprocess_verifier(&self.program);
let (public_inputs, proof) =
deserialize_public_input_with_proof(proof_with_public_inputs).unwrap();
@@ -131,7 +135,8 @@ mod tests {
let mut inputs = Input::new();
inputs.write(&(1 as u32)).unwrap();
let _execution = EreJolt::execute(&program, &inputs).unwrap();
let zkvm = EreJolt::new(program);
let _execution = zkvm.execute(&inputs).unwrap();
}
// #[test]
// fn test_prove_verify() {

View File

@@ -41,13 +41,19 @@ impl Compiler for OPENVM_TARGET {
}
}
pub struct EreOpenVM;
pub struct EreOpenVM {
program: <OPENVM_TARGET as Compiler>::Program,
}
impl zkVM<OPENVM_TARGET> for EreOpenVM {
type Error = OpenVMError;
fn new(program: <OPENVM_TARGET as Compiler>::Program) -> Self {
Self { program }
}
fn execute(
program: &<OPENVM_TARGET as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<zkvm_interface::ProgramExecutionReport, Self::Error> {
let sdk = Sdk::new();
@@ -59,7 +65,7 @@ impl zkVM<OPENVM_TARGET> for EreOpenVM {
.build();
let exe = sdk
.transpile(program.clone(), vm_cfg.transpiler())
.transpile(self.program.clone(), vm_cfg.transpiler())
.map_err(|e| CompileError::Client(e.into()))?;
let mut stdin = StdIn::default();
@@ -75,7 +81,7 @@ impl zkVM<OPENVM_TARGET> for EreOpenVM {
}
fn prove(
program: &<OPENVM_TARGET as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), Self::Error> {
// TODO: We need a stateful version in order to not spend a lot of time
@@ -90,7 +96,7 @@ impl zkVM<OPENVM_TARGET> for EreOpenVM {
.build();
let app_exe = sdk
.transpile(program.clone(), vm_cfg.transpiler())
.transpile(self.program.clone(), vm_cfg.transpiler())
.map_err(|e| CompileError::Client(e.into()))?;
let mut stdin = StdIn::default();
@@ -119,10 +125,7 @@ impl zkVM<OPENVM_TARGET> for EreOpenVM {
Ok((proof_bytes, ProgramProvingReport::new(elapsed)))
}
fn verify(
_program: &<OPENVM_TARGET as Compiler>::Program,
mut proof: &[u8],
) -> Result<(), Self::Error> {
fn verify(&self, mut proof: &[u8]) -> Result<(), Self::Error> {
let sdk = Sdk::new();
let vm_cfg = SdkVmConfig::builder()
.system(Default::default())
@@ -183,8 +186,9 @@ mod tests {
let test_guest_path = get_compile_test_guest_program_path();
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
let empty_input = zkvm_interface::Input::new();
let zkvm = EreOpenVM::new(elf);
EreOpenVM::execute(&elf, &empty_input).unwrap();
zkvm.execute(&empty_input).unwrap();
}
#[test]
@@ -193,7 +197,9 @@ mod tests {
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
let mut input = zkvm_interface::Input::new();
input.write(&10u64).unwrap();
EreOpenVM::execute(&elf, &input).unwrap();
let zkvm = EreOpenVM::new(elf);
zkvm.execute(&input).unwrap();
}
#[test]
@@ -202,7 +208,9 @@ mod tests {
let elf = OPENVM_TARGET::compile(&test_guest_path).expect("compilation failed");
let mut input = zkvm_interface::Input::new();
input.write(&10u64).unwrap();
let (proof, _) = EreOpenVM::prove(&elf, &input).unwrap();
EreOpenVM::verify(&elf, &proof).expect("proof should verify");
let zkvm = EreOpenVM::new(elf);
let (proof, _) = zkvm.prove(&input).unwrap();
zkvm.verify(&proof).expect("proof should verify");
}
}

View File

@@ -47,23 +47,31 @@ impl Compiler for PICO_TARGET {
}
}
pub struct ErePico;
pub struct ErePico {
program: <PICO_TARGET as Compiler>::Program,
}
impl zkVM<PICO_TARGET> for ErePico {
type Error = PicoError;
fn new(program_bytes: <PICO_TARGET as Compiler>::Program) -> Self {
ErePico {
program: program_bytes,
}
}
fn execute(
_program_bytes: &<PICO_TARGET as Compiler>::Program,
&self,
_inputs: &zkvm_interface::Input,
) -> Result<zkvm_interface::ProgramExecutionReport, Self::Error> {
todo!("pico currently does not have an execute method exposed via the SDK")
}
fn prove(
program_bytes: &<PICO_TARGET as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), Self::Error> {
let client = DefaultProverClient::new(program_bytes);
let client = DefaultProverClient::new(&self.program);
let mut stdin = client.new_stdin_builder();
for input in inputs.chunked_iter() {
@@ -91,11 +99,8 @@ impl zkVM<PICO_TARGET> for ErePico {
Ok((proof_serialized, ProgramProvingReport::new(elapsed)))
}
fn verify(
program_bytes: &<PICO_TARGET as Compiler>::Program,
_proof: &[u8],
) -> Result<(), Self::Error> {
let client = DefaultProverClient::new(program_bytes);
fn verify(&self, _proof: &[u8]) -> Result<(), Self::Error> {
let client = DefaultProverClient::new(&self.program);
let _vk = client.riscv_vk();

View File

@@ -13,7 +13,9 @@ use error::{ExecuteError, ProveError, SP1Error, VerifyError};
#[allow(non_camel_case_types)]
pub struct RV32_IM_SUCCINCT_ZKVM_ELF;
pub struct EreSP1;
pub struct EreSP1 {
program: <RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
}
impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF {
type Error = SP1Error;
@@ -28,8 +30,12 @@ impl Compiler for RV32_IM_SUCCINCT_ZKVM_ELF {
impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
type Error = SP1Error;
fn new(program: <RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program) -> Self {
Self { program }
}
fn execute(
program_bytes: &<RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<zkvm_interface::ProgramExecutionReport, Self::Error> {
// TODO: This is expensive, should move it out and make the struct stateful
@@ -41,7 +47,7 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
}
let (_, exec_report) = client
.execute(&program_bytes, &stdin)
.execute(&self.program, &stdin)
.run()
.map_err(|e| ExecuteError::Client(e.into()))?;
@@ -51,7 +57,7 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
}
fn prove(
program_bytes: &<RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
&self,
inputs: &zkvm_interface::Input,
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), Self::Error> {
info!("Generating proof…");
@@ -59,7 +65,7 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
// TODO: This is expensive, should move it out and make the struct stateful
let client = ProverClient::builder().cpu().build();
// TODO: This can also be cached
let (pk, _vk) = client.setup(&program_bytes);
let (pk, _vk) = client.setup(&self.program);
let mut stdin = SP1Stdin::new();
for input in inputs.chunked_iter() {
@@ -80,14 +86,11 @@ impl zkVM<RV32_IM_SUCCINCT_ZKVM_ELF> for EreSP1 {
Ok((bytes, ProgramProvingReport::new(proving_time)))
}
fn verify(
program_bytes: &<RV32_IM_SUCCINCT_ZKVM_ELF as Compiler>::Program,
proof: &[u8],
) -> Result<(), Self::Error> {
fn verify(&self, proof: &[u8]) -> Result<(), Self::Error> {
info!("Verifying proof…");
let client = ProverClient::from_env();
let (_pk, vk) = client.setup(&program_bytes);
let (_pk, vk) = client.setup(&self.program);
let proof: SP1ProofWithPublicValues = bincode::deserialize(proof)
.map_err(|err| SP1Error::Verify(VerifyError::Bincode(err)))?;
@@ -132,7 +135,9 @@ mod execute_tests {
input_builder.write(&n).unwrap();
input_builder.write(&a).unwrap();
let result = EreSP1::execute(&elf_bytes, &input_builder);
let zkvm = EreSP1::new(elf_bytes);
let result = zkvm.execute(&input_builder);
if let Err(e) = &result {
panic!("Execution error: {:?}", e);
@@ -146,7 +151,8 @@ mod execute_tests {
let empty_input = Input::new();
let result = EreSP1::execute(&elf_bytes, &empty_input);
let zkvm = EreSP1::new(elf_bytes);
let result = zkvm.execute(&empty_input);
assert!(
result.is_err(),
@@ -189,7 +195,9 @@ mod prove_tests {
input_builder.write(&n).unwrap();
input_builder.write(&a).unwrap();
let proof_bytes = match EreSP1::prove(&elf_bytes, &input_builder) {
let zkvm = EreSP1::new(elf_bytes);
let proof_bytes = match zkvm.prove(&input_builder) {
Ok((prove_result, _)) => prove_result,
Err(err) => {
panic!("Proving error in test: {:?}", err);
@@ -198,7 +206,7 @@ mod prove_tests {
assert!(!proof_bytes.is_empty(), "Proof bytes should not be empty.");
let verify_results = EreSP1::verify(&elf_bytes, &proof_bytes).is_ok();
let verify_results = zkvm.verify(&proof_bytes).is_ok();
assert!(verify_results);
// TODO: Check public inputs
@@ -212,7 +220,8 @@ mod prove_tests {
let empty_input = Input::new();
let prove_result = EreSP1::prove(&elf_bytes, &empty_input);
let zkvm = EreSP1::new(elf_bytes);
let prove_result = zkvm.prove(&empty_input);
assert!(prove_result.is_err())
}
}

View File

@@ -8,7 +8,6 @@ pub use input::Input;
/// Compiler trait for compiling programs into an opaque sequence of bytes.
pub trait Compiler {
type Error: std::error::Error + Send + Sync + 'static;
// TODO: check if this can be removed and we just use bytes
type Program: Clone + Send + Sync;
/// Compiles the program and returns the program
@@ -20,25 +19,21 @@ pub trait Compiler {
pub trait zkVM<C: Compiler> {
type Error: std::error::Error + Send + Sync + 'static;
fn new(program_bytes: C::Program) -> Self;
/// Executes the given program with the inputs accumulated in the Input struct.
/// For RISCV programs, `program_bytes` will be the ELF binary
fn execute(
program_bytes: &C::Program,
inputs: &Input,
) -> Result<ProgramExecutionReport, Self::Error>;
fn execute(&self, inputs: &Input) -> Result<ProgramExecutionReport, Self::Error>;
/// Creates a proof for a given program
fn prove(
program_bytes: &C::Program,
inputs: &Input,
) -> Result<(Vec<u8>, ProgramProvingReport), Self::Error>;
fn prove(&self, inputs: &Input) -> Result<(Vec<u8>, ProgramProvingReport), Self::Error>;
/// Verifies a proof for the given program
/// TODO: Pass public inputs too and check that they match if they come with the
/// TODO: proof, or append them if they do not.
/// TODO: We can also just have this return the public inputs, but then the user needs
/// TODO: ensure they check it for correct #[must_use]
fn verify(program_bytes: &C::Program, proof: &[u8]) -> Result<(), Self::Error>;
fn verify(&self, proof: &[u8]) -> Result<(), Self::Error>;
}
/// ProgramExecutionReport produces information about a particular program