add initial jolt code

This commit is contained in:
Kevaundray Wedderburn
2025-05-13 17:22:45 +01:00
parent 810471711f
commit c290578665
6 changed files with 1192 additions and 32 deletions

869
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,10 @@ members = [
"crates/zkvm-interface",
# zkVMs
"crates/ere-sp1",
"crates/ere-risczero", "crates/ere-openvm", "crates/ere-pico",
"crates/ere-risczero",
"crates/ere-openvm",
"crates/ere-pico",
"crates/ere-jolt",
]
resolver = "2"
@@ -19,3 +22,9 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
# local dependencies
zkvm-interface = { path = "crates/zkvm-interface" }
[patch.crates-io]
# These patches are only needed by Jolt
ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "v0.5.0-optimize-mul-u64" }
ark-ec = { git = "https://github.com/a16z/arkworks-algebra", branch = "v0.5.0-optimize-mul-u64" }
ark-serialize = { git = "https://github.com/a16z/arkworks-algebra", branch = "v0.5.0-optimize-mul-u64" }

View File

@@ -0,0 +1,24 @@
[package]
name = "ere-jolt"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license.workspace = true
[dependencies]
zkvm-interface = { workspace = true }
jolt-sdk = { git = "https://github.com/kevaundray/jolt", branch = "kw/ere-fork", features = [
"host",
] }
jolt-core = { git = "https://github.com/kevaundray/jolt", branch = "kw/ere-fork", features = [
"host",
] }
jolt = { git = "https://github.com/kevaundray/jolt", branch = "kw/ere-fork", features = [
"host",
] }
thiserror = "2"
toml = "0.8"
ark-serialize = "0.5.0"
[lints]
workspace = true

View File

@@ -0,0 +1,91 @@
use zkvm_interface::Input;
pub fn preprocess_prover(
program: &jolt::host::Program,
) -> jolt::JoltProverPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> {
use jolt::{Jolt, JoltProverPreprocessing, MemoryLayout, RV32IJoltVM};
let (bytecode, memory_init) = program.decode();
let memory_layout = MemoryLayout::new(4096, 4096);
let preprocessing: JoltProverPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> =
RV32IJoltVM::prover_preprocess(
bytecode,
memory_layout,
memory_init,
1 << 20,
1 << 20,
1 << 24,
);
preprocessing
}
pub fn preprocess_verifier(
program: &jolt::host::Program,
) -> jolt::JoltVerifierPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> {
use jolt::{Jolt, JoltVerifierPreprocessing, MemoryLayout, RV32IJoltVM};
let (bytecode, memory_init) = program.decode();
let memory_layout = MemoryLayout::new(4096, 4096);
let preprocessing: JoltVerifierPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> =
RV32IJoltVM::verifier_preprocess(
bytecode,
memory_layout,
memory_init,
1 << 20,
1 << 20,
1 << 24,
);
preprocessing
}
pub fn verify_generic(
proof: jolt::JoltHyperKZGProof,
// TODO: input should be private input
inputs: Input,
outputs: Input,
preprocessing: jolt::JoltVerifierPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript>,
) -> bool {
use jolt::{Jolt, RV32IJoltVM, tracer};
let preprocessing = std::sync::Arc::new(preprocessing);
let preprocessing = (*preprocessing).clone();
let mut io_device = tracer::JoltDevice::new(
preprocessing.memory_layout.max_input_size,
preprocessing.memory_layout.max_output_size,
);
io_device.inputs = inputs.bytes().to_vec();
io_device.outputs = outputs.bytes().to_vec();
RV32IJoltVM::verify(
preprocessing,
proof.proof,
proof.commitments,
io_device,
None,
)
.is_ok()
}
pub fn prove_generic(
program: &jolt::host::Program,
preprocessing: jolt::JoltProverPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript>,
inputs: &Input,
) -> (Vec<u8>, jolt::JoltHyperKZGProof) {
use jolt::{Jolt, RV32IJoltVM};
let mut program = program.clone();
// Convert inputs to a flat vector
let input_bytes = inputs.bytes().to_vec();
let (io_device, trace) = program.trace(&input_bytes);
let (jolt_proof, jolt_commitments, output_io_device, _) =
RV32IJoltVM::prove(io_device, trace, preprocessing);
let proof = jolt::JoltHyperKZGProof {
proof: jolt_proof,
commitments: jolt_commitments,
};
(output_io_device.outputs.clone(), proof)
}

154
crates/ere-jolt/src/lib.rs Normal file
View File

@@ -0,0 +1,154 @@
use jolt_core::host::Program;
use jolt_methods::{preprocess_prover, preprocess_verifier, prove_generic, verify_generic};
use jolt_sdk::host::DEFAULT_TARGET_DIR;
use utils::{
deserialize_public_input_with_proof, package_name_from_manifest,
serialize_public_input_with_proof,
};
use zkvm_interface::{Compiler, Input, ProgramExecutionReport, ProgramProvingReport, zkVM};
mod jolt_methods;
mod utils;
pub struct JOLT_TARGET;
#[derive(Debug, thiserror::Error)]
pub enum JoltError {
#[error("Proof verification failed")]
ProofVerificationFailed,
}
impl Compiler for JOLT_TARGET {
type Error = JoltError;
type Program = Program;
fn compile(path_to_program: &std::path::Path) -> Result<Self::Program, Self::Error> {
let manifest_path = path_to_program.to_path_buf().join("Cargo.toml");
let package_name = package_name_from_manifest(&manifest_path).unwrap();
let mut program = Program::new(&package_name);
program.set_manifest_path(manifest_path);
program.set_memory_size(10485760u64);
program.set_stack_size(4096u64);
program.set_max_input_size(4096u64);
program.set_max_output_size(4096u64);
// TODO: Note that if this fails, it will panic
program.build(DEFAULT_TARGET_DIR);
// Read the ELF file and return its bytes
// let elf_path = program.elf.expect("expect elf path");
// println!("{:?}", elf_path);
// let elf_bytes = std::fs::read(elf_path).unwrap();
Ok(program)
}
}
pub struct EreJolt;
impl zkVM<JOLT_TARGET> for EreJolt {
type Error = JoltError;
fn execute(
program_bytes: &<JOLT_TARGET as Compiler>::Program,
inputs: &zkvm_interface::Input,
) -> Result<zkvm_interface::ProgramExecutionReport, Self::Error> {
// TODO: check ProgramSummary
let summary = program_bytes
.clone()
.trace_analyze::<jolt::F>(inputs.bytes());
let trace_len = summary.trace_len();
Ok(ProgramExecutionReport::new(trace_len as u64))
}
fn prove(
program: &<JOLT_TARGET as Compiler>::Program,
inputs: &zkvm_interface::Input,
) -> Result<(Vec<u8>, zkvm_interface::ProgramProvingReport), Self::Error> {
// TODO: make this stateful and do in setup since its expensive and should be done once per program;
let preprocessed_key = preprocess_prover(&program);
let now = std::time::Instant::now();
let (output_bytes, proof) = prove_generic(program, preprocessed_key, inputs);
let elapsed = now.elapsed();
let proof_with_public_inputs =
serialize_public_input_with_proof(&output_bytes, &proof).unwrap();
Ok((proof_with_public_inputs, ProgramProvingReport::new(elapsed)))
}
fn verify(
program: &<JOLT_TARGET as Compiler>::Program,
proof_with_public_inputs: &[u8],
) -> Result<(), Self::Error> {
let preprocessed_verifier = preprocess_verifier(program);
let (public_inputs, proof) =
deserialize_public_input_with_proof(proof_with_public_inputs).unwrap();
let mut outputs = Input::new();
assert!(public_inputs.is_empty());
outputs.write(&public_inputs).unwrap();
// TODO: I don't think we should require the inputs when verifying
let inputs = Input::new();
let valid = verify_generic(proof, inputs, outputs, preprocessed_verifier);
if valid {
Ok(())
} else {
Err(JoltError::ProofVerificationFailed)
}
}
}
#[cfg(test)]
mod tests {
use crate::{EreJolt, JOLT_TARGET};
use std::path::PathBuf;
use zkvm_interface::{Compiler, Input, zkVM};
// TODO: for now, we just get one test file
// TODO: but this should get the whole directory and compile each test
fn get_compile_test_guest_program_path() -> PathBuf {
let workspace_dir = env!("CARGO_WORKSPACE_DIR");
PathBuf::from(workspace_dir)
.join("tests")
.join("jolt")
.join("compile")
.join("basic")
.join("guest")
.canonicalize()
.expect("Failed to find or canonicalize test guest program at <CARGO_WORKSPACE_DIR>/tests/compile/jolt")
}
#[test]
fn test_compile_trait() {
let test_guest_path = get_compile_test_guest_program_path();
let program = JOLT_TARGET::compile(&test_guest_path).unwrap();
assert!(program.elf.is_some(), "elf has not been compiled");
}
#[test]
fn test_execute() {
let test_guest_path = get_compile_test_guest_program_path();
let program = JOLT_TARGET::compile(&test_guest_path).unwrap();
let mut inputs = Input::new();
inputs.write(&(1 as u32)).unwrap();
let _execution = EreJolt::execute(&program, &inputs).unwrap();
}
#[test]
fn test_prove_verify() {
let test_guest_path = get_compile_test_guest_program_path();
let program = JOLT_TARGET::compile(&test_guest_path).unwrap();
// TODO: I don't think we should require the inputs when verifying
let inputs = Input::new();
let (proof, _) = EreJolt::prove(&program, &inputs).unwrap();
EreJolt::verify(&program, &proof).unwrap();
}
}

View File

@@ -0,0 +1,75 @@
use jolt::JoltHyperKZGProof;
use std::{fs, path::Path};
use toml::Value;
use crate::JoltError;
/// Reads the `[package] name` out of a Cargo.toml.
///
/// * `manifest_path` absolute or relative path to a Cargo.toml.
/// * Returns → `String` with the package name (`fib`, `my_guest`, …).
pub(crate) fn package_name_from_manifest(manifest_path: &Path) -> Result<String, JoltError> {
let manifest = fs::read_to_string(manifest_path).unwrap();
let value: Value = manifest.parse::<Value>().unwrap();
value
.get("package")
.and_then(|pkg| pkg.get("name"))
.and_then(Value::as_str)
.map(|s| s.to_owned())
.ok_or_else(|| panic!("no [package] name found in {}", manifest_path.display()))
}
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
use std::io::Cursor;
/// Serializes the public input (as raw bytes) and proof into a single byte vector
pub fn serialize_public_input_with_proof(
public_input: &Vec<u8>,
proof: &JoltHyperKZGProof,
) -> Result<Vec<u8>, SerializationError> {
let mut buffer = Vec::new();
// First, serialize the length of the public input as u64
let public_input_size = public_input.len() as u64;
public_input_size.serialize_compressed(&mut buffer)?;
// Append the public input directly (it's already bytes)
buffer.extend_from_slice(public_input);
// Now serialize the proof
let mut proof_bytes = Vec::new();
proof.serialize_compressed(&mut proof_bytes)?;
// Append the serialized proof to the buffer
buffer.extend_from_slice(&proof_bytes);
Ok(buffer)
}
/// Deserializes a byte vector into a public input (Vec<u8>) and proof
pub fn deserialize_public_input_with_proof(
bytes: &[u8],
) -> Result<(Vec<u8>, JoltHyperKZGProof), SerializationError> {
let mut cursor = Cursor::new(bytes);
// Read the size of the public input
let public_input_size: u64 = CanonicalDeserialize::deserialize_compressed(&mut cursor)?;
// Get the current position after reading the size
let current_position = cursor.position() as usize;
let public_input_end = current_position + public_input_size as usize;
if public_input_end > bytes.len() {
return Err(SerializationError::InvalidData);
}
// Extract the public input bytes directly
let public_input = bytes[current_position..public_input_end].to_vec();
// The rest is the proof
let proof_bytes = &bytes[public_input_end..];
let proof = JoltHyperKZGProof::deserialize_compressed(&mut Cursor::new(proof_bytes))?;
Ok((public_input, proof))
}