Refactor/zkvm 3 (#1684)

This commit is contained in:
Ho
2025-07-01 07:39:27 +09:00
committed by GitHub
parent 9dc57c6126
commit ea38ae7e96
60 changed files with 1495 additions and 592 deletions

View File

@@ -0,0 +1,32 @@
[patch."https://github.com/openvm-org/stark-backend.git"]
openvm-stark-backend = { git = "ssh://git@github.com/scroll-tech/openvm-stark-gpu.git", branch = "main", features = ["gpu"] }
openvm-stark-sdk = { git = "ssh://git@github.com/scroll-tech/openvm-stark-gpu.git", branch = "main", features = ["gpu"] }
[patch."https://github.com/Plonky3/Plonky3.git"]
p3-air = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-field = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-commit = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-matrix = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-baby-bear = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", features = [
"nightly-features",
], tag = "v0.2.0" }
p3-koala-bear = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-util = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-challenger = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-dft = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-fri = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-goldilocks = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-keccak = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-keccak-air = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-blake3 = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-mds = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-merkle-tree = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-monty-31 = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-poseidon = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-poseidon2 = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-poseidon2-air = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-symmetric = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-uni-stark = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }
p3-maybe-rayon = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" } # the "parallel" feature is NOT on by default to allow single-threaded benchmarking
p3-bn254-fr = { git = "ssh://git@github.com/scroll-tech/plonky3-gpu.git", tag = "v0.2.0" }

View File

@@ -1,7 +1,7 @@
[package]
name = "prover"
version = "0.1.0"
edition = "2021"
version.workspace = true
edition.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -2,12 +2,13 @@ mod prover;
mod types;
mod zk_circuits_handler;
use clap::{ArgAction, Parser};
use clap::{ArgAction, Parser, Subcommand};
use prover::{LocalProver, LocalProverConfig};
use scroll_proving_sdk::{
prover::ProverBuilder,
prover::{types::ProofType, ProverBuilder},
utils::{get_version, init_tracing},
};
use std::{fs::File, path::Path};
#[derive(Parser, Debug)]
#[command(disable_version_flag = true)]
@@ -16,6 +17,9 @@ struct Args {
#[arg(long = "config", default_value = "conf/config.json")]
config_file: String,
#[arg(long = "forkname")]
fork_name: Option<String>,
/// Version of this prover
#[arg(short, long, action = ArgAction::SetTrue)]
version: bool,
@@ -23,6 +27,40 @@ struct Args {
/// Path of log file
#[arg(long = "log.file")]
log_file: Option<String>,
#[command(subcommand)]
command: Option<Commands>,
}
#[derive(Subcommand, Debug)]
enum Commands {
/// Dump vk of this prover
Dump {
/// File to save the vks
file_name: String,
},
}
fn dump_vk(file: &Path, prover: &LocalProver, fork_name: &str) -> eyre::Result<()> {
let f = File::create(file)?;
#[derive(Debug, serde::Serialize)]
struct VKDump {
pub chunk_vk: String,
pub batch_vk: String,
pub bundle_vk: String,
}
let handler = prover.new_handler(fork_name);
let dump = VKDump {
chunk_vk: handler.get_vk(ProofType::Chunk),
batch_vk: handler.get_vk(ProofType::Batch),
bundle_vk: handler.get_vk(ProofType::Bundle),
};
serde_json::to_writer(f, &dump)?;
Ok(())
}
#[tokio::main]
@@ -37,14 +75,25 @@ async fn main() -> eyre::Result<()> {
}
let cfg = LocalProverConfig::from_file(args.config_file)?;
let default_fork_name = cfg.circuits.keys().next().unwrap().clone();
let sdk_config = cfg.sdk_config.clone();
let local_prover = LocalProver::new(cfg);
let prover = ProverBuilder::new(sdk_config, local_prover)
.build()
.await
.map_err(|e| eyre::eyre!("build prover fail: {e}"))?;
let local_prover = LocalProver::new(cfg.clone());
prover.run().await;
match args.command {
Some(Commands::Dump { file_name }) => {
let fork_name = args.fork_name.unwrap_or(default_fork_name);
println!("dump vk for {fork_name}");
dump_vk(Path::new(&file_name), &local_prover, &fork_name)?;
}
None => {
let prover = ProverBuilder::new(sdk_config, local_prover)
.build()
.await
.map_err(|e| eyre::eyre!("build prover fail: {e}"))?;
prover.run().await;
}
}
Ok(())
}

View File

@@ -1,6 +1,5 @@
use crate::zk_circuits_handler::{euclidV2::EuclidV2Handler, CircuitsHandler};
use async_trait::async_trait;
use base64::{prelude::BASE64_STANDARD, Engine};
use eyre::Result;
use scroll_proving_sdk::{
config::Config as SdkConfig,
@@ -9,6 +8,7 @@ use scroll_proving_sdk::{
GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest,
QueryTaskResponse, TaskStatus,
},
types::ProofType,
ProvingService,
},
};
@@ -16,7 +16,7 @@ use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
fs::File,
sync::Arc,
sync::{Arc, OnceLock},
time::{SystemTime, UNIX_EPOCH},
};
use tokio::{runtime::Handle, sync::Mutex, task::JoinHandle};
@@ -45,6 +45,9 @@ impl LocalProverConfig {
pub struct CircuitConfig {
pub hard_fork_name: String,
pub workspace_path: String,
/// cached vk value to save some initial cost, for debugging only
#[serde(default)]
pub vks: HashMap<ProofType, String>,
}
pub struct LocalProver {
@@ -52,7 +55,7 @@ pub struct LocalProver {
next_task_id: u64,
current_task: Option<JoinHandle<Result<String>>>,
active_handler: Option<(String, Arc<dyn CircuitsHandler>)>,
handlers: HashMap<String, OnceLock<Arc<dyn CircuitsHandler>>>,
}
#[async_trait]
@@ -62,13 +65,13 @@ impl ProvingService for LocalProver {
}
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse {
let mut vks = vec![];
for hard_fork_name in self.config.circuits.keys() {
let handler = self.new_handler(hard_fork_name);
for (hard_fork_name, cfg) in self.config.circuits.iter() {
for proof_type in &req.proof_types {
let vk = handler.get_vk(*proof_type).await;
if let Some(vk) = vk {
vks.push(BASE64_STANDARD.encode(vk));
if let Some(vk) = cfg.vks.get(proof_type) {
vks.push(vk.clone())
} else {
let handler = self.get_or_init_handler(hard_fork_name);
vks.push(handler.get_vk(*proof_type));
}
}
}
@@ -76,11 +79,8 @@ impl ProvingService for LocalProver {
GetVkResponse { vks, error: None }
}
async fn prove(&mut self, req: ProveRequest) -> ProveResponse {
self.set_active_handler(&req.hard_fork_name);
match self
.do_prove(req, self.active_handler.as_ref().unwrap().1.clone())
.await
{
let handler = self.get_or_init_handler(&req.hard_fork_name);
match self.do_prove(req, handler).await {
Ok(resp) => resp,
Err(e) => ProveResponse {
status: TaskStatus::Failed,
@@ -133,11 +133,16 @@ impl ProvingService for LocalProver {
impl LocalProver {
pub fn new(config: LocalProverConfig) -> Self {
let handlers = config
.circuits
.keys()
.map(|k| (k.clone(), OnceLock::new()))
.collect();
Self {
config,
next_task_id: 0,
current_task: None,
active_handler: None,
handlers,
}
}
@@ -168,25 +173,25 @@ impl LocalProver {
})
}
fn set_active_handler(&mut self, hard_fork_name: &str) {
if let Some(handler) = &self.active_handler {
if handler.0 == hard_fork_name {
return;
}
}
self.active_handler = Some((hard_fork_name.to_string(), self.new_handler(hard_fork_name)));
fn get_or_init_handler(&self, hard_fork_name: &str) -> Arc<dyn CircuitsHandler> {
let lk = self
.handlers
.get(hard_fork_name)
.expect("coordinator should never sent unexpected forkname");
lk.get_or_init(|| self.new_handler(hard_fork_name)).clone()
}
fn new_handler(&self, hard_fork_name: &str) -> Arc<dyn CircuitsHandler> {
pub fn new_handler(&self, hard_fork_name: &str) -> Arc<dyn CircuitsHandler> {
// if we got assigned a task for an unknown hard fork, there is something wrong in the
// coordinator
let config = self.config.circuits.get(hard_fork_name).unwrap();
match hard_fork_name {
"euclidV2" => Arc::new(Arc::new(Mutex::new(EuclidV2Handler::new(
&config.workspace_path,
)))) as Arc<dyn CircuitsHandler>,
_ => unreachable!(),
// The new EuclidV2Handler is a universal handler
// We can add other handler implements if needed
"some future forkname" => unreachable!(),
_ => Arc::new(Arc::new(Mutex::new(EuclidV2Handler::new(config))))
as Arc<dyn CircuitsHandler>,
}
}
}

View File

@@ -11,7 +11,7 @@ use std::path::Path;
#[async_trait]
pub trait CircuitsHandler: Sync + Send {
async fn get_vk(&self, task_type: ProofType) -> Option<Vec<u8>>;
fn get_vk(&self, task_type: ProofType) -> String;
async fn get_proof_data(&self, prove_request: ProveRequest) -> Result<String>;
}
@@ -54,14 +54,12 @@ impl Phase {
let dir_cache = Some(workspace_path.join("cache"));
let path_app_config = workspace_path.join("bundle/openvm.toml");
let segment_len = Some((1 << 22) - 100);
match self {
Phase::EuclidV2 => ProverConfig {
dir_cache,
path_app_config,
segment_len,
path_app_exe: workspace_path.join("bundle/app.vmexe"),
..Default::default()
},
ProverConfig {
dir_cache,
path_app_config,
segment_len,
path_app_exe: workspace_path.join("bundle/app.vmexe"),
..Default::default()
}
}
}

View File

@@ -1,7 +1,13 @@
use std::{path::Path, sync::Arc};
use std::{
collections::HashMap,
path::Path,
sync::{Arc, OnceLock},
};
use super::{CircuitsHandler, Phase};
use crate::prover::CircuitConfig;
use async_trait::async_trait;
use base64::{prelude::BASE64_STANDARD, Engine};
use eyre::Result;
use scroll_proving_sdk::prover::{proving_service::ProveRequest, ProofType};
use scroll_zkvm_prover_euclid::{BatchProver, BundleProverEuclidV2, ChunkProver};
@@ -11,12 +17,14 @@ pub struct EuclidV2Handler {
chunk_prover: ChunkProver,
batch_prover: BatchProver,
bundle_prover: BundleProverEuclidV2,
cached_vks: HashMap<ProofType, OnceLock<String>>,
}
unsafe impl Send for EuclidV2Handler {}
impl EuclidV2Handler {
pub fn new(workspace_path: &str) -> Self {
pub fn new(cfg: &CircuitConfig) -> Self {
let workspace_path = &cfg.workspace_path;
let p = Phase::EuclidV2;
let workspace_path = Path::new(workspace_path);
let chunk_prover = ChunkProver::setup(p.phase_spec_chunk(workspace_path))
@@ -28,46 +36,80 @@ impl EuclidV2Handler {
let bundle_prover = BundleProverEuclidV2::setup(p.phase_spec_bundle(workspace_path))
.expect("Failed to setup bundle prover");
let build_vk_cache = |proof_type: ProofType| {
let vk = if let Some(vk) = cfg.vks.get(&proof_type) {
OnceLock::from(vk.clone())
} else {
OnceLock::new()
};
(proof_type, vk)
};
Self {
chunk_prover,
batch_prover,
bundle_prover,
cached_vks: HashMap::from([
build_vk_cache(ProofType::Chunk),
build_vk_cache(ProofType::Batch),
build_vk_cache(ProofType::Bundle),
]),
}
}
pub fn get_vk_and_cache(&self, task_type: ProofType) -> String {
match task_type {
ProofType::Chunk => self.cached_vks[&ProofType::Chunk]
.get_or_init(|| BASE64_STANDARD.encode(self.chunk_prover.get_app_vk())),
ProofType::Batch => self.cached_vks[&ProofType::Batch]
.get_or_init(|| BASE64_STANDARD.encode(self.batch_prover.get_app_vk())),
ProofType::Bundle => self.cached_vks[&ProofType::Bundle]
.get_or_init(|| BASE64_STANDARD.encode(self.bundle_prover.get_evm_vk())),
_ => unreachable!("Unsupported proof type {:?}", task_type),
}
.clone()
}
}
#[async_trait]
impl CircuitsHandler for Arc<Mutex<EuclidV2Handler>> {
async fn get_vk(&self, task_type: ProofType) -> Option<Vec<u8>> {
Some(match task_type {
ProofType::Chunk => self.try_lock().unwrap().chunk_prover.get_app_vk(),
ProofType::Batch => self.try_lock().unwrap().batch_prover.get_app_vk(),
ProofType::Bundle => self.try_lock().unwrap().bundle_prover.get_app_vk(),
_ => unreachable!("Unsupported proof type"),
})
fn get_vk(&self, task_type: ProofType) -> String {
self.try_lock()
.expect("get vk is on called before other entry is used")
.get_vk_and_cache(task_type)
}
async fn get_proof_data(&self, prove_request: ProveRequest) -> Result<String> {
let handler_self = self.lock().await;
let u_task: ProvingTask = serde_json::from_str(&prove_request.input)?;
let expected_vk = handler_self.get_vk_and_cache(prove_request.proof_type);
if BASE64_STANDARD.encode(&u_task.vk) != expected_vk {
eyre::bail!(
"vk is not match!, prove type {:?}, expected {}, get {}",
prove_request.proof_type,
expected_vk,
BASE64_STANDARD.encode(&u_task.vk),
);
}
let proof = match prove_request.proof_type {
ProofType::Chunk => self
.try_lock()
.unwrap()
ProofType::Chunk => handler_self
.chunk_prover
.gen_proof_universal(&u_task, false)?,
ProofType::Batch => self
.try_lock()
.unwrap()
ProofType::Batch => handler_self
.batch_prover
.gen_proof_universal(&u_task, false)?,
ProofType::Bundle => self
.try_lock()
.unwrap()
ProofType::Bundle => handler_self
.bundle_prover
.gen_proof_universal(&u_task, true)?,
_ => return Err(eyre::eyre!("Unsupported proof type")),
_ => {
return Err(eyre::eyre!(
"Unsupported proof type {:?}",
prove_request.proof_type
))
}
};
//TODO: check expected PI
Ok(serde_json::to_string(&proof)?)
}
}