Support --program-path flag in ere-server (#270)

This commit is contained in:
Han
2026-01-16 23:15:21 +09:00
committed by GitHub
parent 0e41aa7e95
commit a9fb04dea8
10 changed files with 264 additions and 67 deletions

35
Cargo.lock generated
View File

@@ -3871,7 +3871,6 @@ dependencies = [
"ere-server",
"ere-test-utils",
"ere-zkvm-interface",
"parking_lot",
"paste",
"serde",
"tempfile",
@@ -4193,9 +4192,12 @@ dependencies = [
"clap",
"indexmap 2.10.0",
"serde",
"serde-untagged",
"serde_json",
"serde_yaml",
"strum 0.27.2",
"thiserror 2.0.12",
"toml 0.8.23",
]
[[package]]
@@ -12717,6 +12719,18 @@ dependencies = [
"serde",
]
[[package]]
name = "serde-untagged"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9faf48a4a2d2693be24c6289dbe26552776eb7737074e6722891fadbe6c5058"
dependencies = [
"erased-serde",
"serde",
"serde_core",
"typeid",
]
[[package]]
name = "serde_arrays"
version = "0.1.0"
@@ -12847,6 +12861,19 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap 2.10.0",
"itoa",
"ryu",
"serde",
"unsafe-libyaml",
]
[[package]]
name = "serdect"
version = "0.2.0"
@@ -15136,6 +15163,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "untrusted"
version = "0.7.1"

View File

@@ -58,7 +58,6 @@ dashmap = "6.1.0"
digest = { version = "0.10.7", default-features = false }
eyre = "0.6.12"
indexmap = "2.10.0"
parking_lot = "0.12.5"
paste = "1.0.15"
postcard = { version = "1.0.8", default-features = false }
prost = "0.13"
@@ -68,6 +67,7 @@ rkyv = { version = "0.8.12", default-features = false }
serde = { version = "1.0.219", default-features = false }
serde_bytes = { version = "0.11.19", default-features = false }
serde_json = "1.0.142"
serde-untagged = "0.1"
serde_yaml = "0.9.34"
sha2 = { version = "0.10.9", default-features = false }
strum = "0.27.2"

View File

@@ -7,7 +7,6 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
parking_lot.workspace = true
serde = { workspace = true, features = ["derive"] }
tempfile.workspace = true
thiserror.workspace = true

View File

@@ -6,61 +6,50 @@ use crate::api::{
use ere_zkvm_interface::zkvm::{
Input, ProgramExecutionReport, ProgramProvingReport, Proof, ProofKind, PublicValues,
};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::time::sleep;
use twirp::{Client, Request, reqwest};
pub use twirp::{TwirpErrorResponse, url::Url};
pub use twirp::{
TwirpErrorResponse,
url::{ParseError, Url},
};
#[derive(Debug, Error)]
#[allow(non_camel_case_types)]
pub enum Error {
#[error("Invalid URL: {0}")]
ParseUrl(#[from] ParseError),
#[error("zkVM method error: {0}")]
zkVM(String),
#[error("Connection to zkVM server timeout after 5 minutes")]
ConnectionTimeout,
#[error("RPC error: {0}")]
Rpc(#[from] TwirpErrorResponse),
}
/// zkVM client of the `zkVMServer`.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug)]
pub struct zkVMClient {
client: Client,
}
impl zkVMClient {
pub async fn new(url: Url) -> Result<Self, Error> {
const TIMEOUT: Duration = Duration::from_secs(300); // 5mins
const INTERVAL: Duration = Duration::from_millis(500);
pub fn new(endpoint: Url, http_client: reqwest::Client) -> Result<Self, Error> {
Ok(Self {
client: Client::new(endpoint.join("twirp")?, http_client, Vec::new(), None),
})
}
let http_client = reqwest::Client::new();
let start = Instant::now();
loop {
if start.elapsed() > TIMEOUT {
return Err(Error::ConnectionTimeout);
}
match http_client.get(url.join("health").unwrap()).send().await {
Ok(response) if response.status().is_success() => break,
_ => sleep(INTERVAL).await,
}
}
let client = Client::new(url.join("twirp").unwrap(), http_client, Vec::new(), None);
Ok(Self { client })
pub fn from_endpoint(endpoint: Url) -> Result<Self, Error> {
Self::new(endpoint, reqwest::Client::new())
}
pub async fn execute(
&self,
input: &Input,
input: Input,
) -> Result<(PublicValues, ProgramExecutionReport), Error> {
let request = Request::new(ExecuteRequest {
input_stdin: input.stdin.clone(),
input_proofs: input.proofs.clone(),
input_stdin: input.stdin,
input_proofs: input.proofs,
});
let response = self.client.execute(request).await?;
@@ -78,12 +67,12 @@ impl zkVMClient {
pub async fn prove(
&self,
input: &Input,
input: Input,
proof_kind: ProofKind,
) -> Result<(PublicValues, Proof, ProgramProvingReport), Error> {
let request = Request::new(ProveRequest {
input_stdin: input.stdin.clone(),
input_proofs: input.proofs.clone(),
input_stdin: input.stdin,
input_proofs: input.proofs,
proof_kind: proof_kind as i32,
});
@@ -101,10 +90,11 @@ impl zkVMClient {
}
}
pub async fn verify(&self, proof: &Proof) -> Result<PublicValues, Error> {
pub async fn verify(&self, proof: Proof) -> Result<PublicValues, Error> {
let proof_kind = proof.kind() as i32;
let request = Request::new(VerifyRequest {
proof: proof.as_bytes().to_vec(),
proof_kind: proof.kind() as i32,
proof: proof.into_bytes(),
proof_kind,
});
let response = self.client.verify(request).await?;

View File

@@ -40,8 +40,13 @@ const _: () = {
#[derive(Parser)]
#[command(author, version)]
struct Args {
/// Port number for the server to listen on.
#[arg(long, default_value = "3000")]
port: u16,
/// Optional path to read the program from. If not specified, reads from stdin.
#[arg(long)]
program_path: Option<String>,
/// Prover resource type.
#[command(subcommand)]
resource: ProverResourceType,
}
@@ -54,9 +59,16 @@ async fn main() -> Result<(), Error> {
let args = Args::parse();
// Read serialized program from stdin.
let mut program = Vec::new();
io::stdin().read_to_end(&mut program)?;
// Read serialized program from file or stdin.
let program = if let Some(path) = args.program_path {
std::fs::read(&path).with_context(|| format!("Failed to read program from {path}"))?
} else {
let mut program = Vec::new();
io::stdin()
.read_to_end(&mut program)
.context("Failed to read program from stdin")?;
program
};
let zkvm = construct_zkvm(program, args.resource)?;
let server = Arc::new(zkVMServer::new(zkvm));

View File

@@ -15,7 +15,10 @@ use crate::{
},
zkVMKind,
};
use ere_server::client::{Url, zkVMClient};
use ere_server::{
api::twirp::reqwest::Client,
client::{self, Url, zkVMClient},
};
use ere_zkvm_interface::{
CommonError,
zkvm::{
@@ -23,9 +26,15 @@ use ere_zkvm_interface::{
PublicValues, zkVM,
},
};
use parking_lot::RwLock;
use std::{future::Future, iter};
use std::{
future::Future,
iter,
pin::Pin,
sync::OnceLock,
time::{Duration, Instant},
};
use tempfile::TempDir;
use tokio::{sync::RwLock, time::sleep};
use tracing::{error, info};
mod error;
@@ -248,13 +257,14 @@ impl ServerContainer {
&program.0,
)?;
let endpoint = Url::parse(&format!("http://{host}:{port}")).unwrap();
let client = block_on(zkVMClient::new(endpoint))?;
let endpoint = Url::parse(&format!("http://{host}:{port}"))?;
let http_client = Client::new();
block_on(wait_until_healthy(&endpoint, http_client.clone()))?;
Ok(ServerContainer {
name,
tempdir,
client,
client: zkVMClient::new(endpoint, http_client)?,
})
}
}
@@ -296,15 +306,46 @@ impl DockerizedzkVM {
&self.resource
}
fn with_retry<T, F>(&self, mut f: F) -> anyhow::Result<T>
pub async fn execute_async(
&self,
input: Input,
) -> anyhow::Result<(PublicValues, ProgramExecutionReport)> {
self.with_retry(|client| {
let input = input.clone();
Box::pin(async move { client.execute(input).await })
})
.await
}
pub async fn prove_async(
&self,
input: Input,
proof_kind: ProofKind,
) -> anyhow::Result<(PublicValues, Proof, ProgramProvingReport)> {
self.with_retry(|client| {
let input = input.clone();
Box::pin(async move { client.prove(input, proof_kind).await })
})
.await
}
pub async fn verify_async(&self, proof: Proof) -> anyhow::Result<PublicValues> {
self.with_retry(|client| {
let proof = proof.clone();
Box::pin(async move { client.verify(proof).await })
})
.await
}
async fn with_retry<T, F>(&self, f: F) -> anyhow::Result<T>
where
F: FnMut(&zkVMClient) -> Result<T, ere_server::client::Error>,
F: Fn(zkVMClient) -> Pin<Box<dyn Future<Output = Result<T, client::Error>> + Send>>,
{
const MAX_RETRY: usize = 3;
let mut attempt = 1;
loop {
let err = match f(&self.container.read().as_ref().unwrap().client) {
let err = match f(self.container.read().await.as_ref().unwrap().client.clone()).await {
Ok(ok) => return Ok(ok),
Err(err) => Error::from(err),
};
@@ -319,7 +360,7 @@ impl DockerizedzkVM {
error!("Rpc failed (attempt {attempt}/{MAX_RETRY}): {err}, checking container...");
let mut container = self.container.write();
let mut container = self.container.write().await;
if docker_container_exists(&container.as_ref().unwrap().name).is_ok_and(|exists| exists)
{
info!("Container is still running, retrying...");
@@ -340,7 +381,7 @@ impl DockerizedzkVM {
impl zkVM for DockerizedzkVM {
fn execute(&self, input: &Input) -> anyhow::Result<(PublicValues, ProgramExecutionReport)> {
self.with_retry(|client| block_on(client.execute(input)))
block_on(self.execute_async(input.clone()))
}
fn prove(
@@ -348,11 +389,11 @@ impl zkVM for DockerizedzkVM {
input: &Input,
proof_kind: ProofKind,
) -> anyhow::Result<(PublicValues, Proof, ProgramProvingReport)> {
self.with_retry(|client| block_on(client.prove(input, proof_kind)))
block_on(self.prove_async(input.clone(), proof_kind))
}
fn verify(&self, proof: &Proof) -> anyhow::Result<PublicValues> {
self.with_retry(|client| block_on(client.verify(proof)))
block_on(self.verify_async(proof.clone()))
}
fn name(&self) -> &'static str {
@@ -364,13 +405,35 @@ impl zkVM for DockerizedzkVM {
}
}
fn block_on<T>(future: impl Future<Output = T>) -> T {
match tokio::runtime::Handle::try_current() {
Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
Err(_) => tokio::runtime::Runtime::new().unwrap().block_on(future),
async fn wait_until_healthy(endpoint: &Url, http_client: Client) -> Result<(), Error> {
const TIMEOUT: Duration = Duration::from_secs(300); // 5mins
const INTERVAL: Duration = Duration::from_millis(500);
let http_client = http_client.clone();
let start = Instant::now();
loop {
if start.elapsed() > TIMEOUT {
return Err(Error::ConnectionTimeout);
}
match http_client.get(endpoint.join("health")?).send().await {
Ok(response) if response.status().is_success() => break Ok(()),
_ => sleep(INTERVAL).await,
}
}
}
fn block_on<T>(future: impl Future<Output = T>) -> T {
static FALLBACK_RT: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
FALLBACK_RT
.get_or_init(|| tokio::runtime::Runtime::new().expect("Failed to create runtime"))
.handle()
.clone()
});
tokio::task::block_in_place(|| handle.block_on(future))
}
#[cfg(test)]
mod test {
use crate::{

View File

@@ -1,12 +1,12 @@
use ere_server::client::{self, TwirpErrorResponse};
use ere_server::client::{self, ParseError, TwirpErrorResponse};
use ere_zkvm_interface::CommonError;
use thiserror::Error;
impl From<client::Error> for Error {
fn from(value: client::Error) -> Self {
match value {
client::Error::ParseUrl(err) => Self::ParseUrl(err),
client::Error::zkVM(err) => Self::zkVM(err),
client::Error::ConnectionTimeout => Self::ConnectionTimeout,
client::Error::Rpc(err) => Self::Rpc(err),
}
}
@@ -17,6 +17,8 @@ impl From<client::Error> for Error {
pub enum Error {
#[error(transparent)]
CommonError(#[from] CommonError),
#[error(transparent)]
ParseUrl(#[from] ParseError),
#[error("zkVM method error: {0}")]
zkVM(String),
#[error("Connection to zkVM server timeout after 5 minutes")]

View File

@@ -11,6 +11,7 @@ auto_impl.workspace = true
bincode = { workspace = true, features = ["alloc", "serde"] }
indexmap = { workspace = true, features = ["serde"] }
serde = { workspace = true, features = ["derive"] }
serde-untagged.workspace = true
strum = { workspace = true, features = ["derive"] }
thiserror.workspace = true
@@ -20,6 +21,8 @@ clap = { workspace = true, features = ["derive"], optional = true }
[dev-dependencies]
bincode = { workspace = true, features = ["alloc", "serde"] }
serde_json.workspace = true
serde_yaml.workspace = true
toml.workspace = true
[lints]
workspace = true

View File

@@ -30,11 +30,19 @@ impl Proof {
ProofKind::from(self)
}
/// Returns inner proof as bytes.
/// Returns inner proof as bytes slice.
pub fn as_bytes(&self) -> &[u8] {
match self {
Self::Compressed(bytes) => bytes,
Self::Groth16(bytes) => bytes,
}
}
/// Returns inner proof as bytes vec.
pub fn into_bytes(self) -> Vec<u8> {
match self {
Self::Compressed(bytes) => bytes,
Self::Groth16(bytes) => bytes,
}
}
}

View File

@@ -1,16 +1,15 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, de::Unexpected};
use serde_untagged::UntaggedEnumVisitor;
/// Configuration for network-based proving
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct NetworkProverConfig {
#[cfg_attr(feature = "clap", arg(long))]
/// The endpoint URL of the prover network service
pub endpoint: String,
#[cfg_attr(feature = "clap", arg(long))]
pub endpoint: String,
/// Optional API key for authentication
#[cfg_attr(feature = "clap", arg(long))]
pub api_key: Option<String>,
}
@@ -25,8 +24,7 @@ impl NetworkProverConfig {
}
/// ResourceType specifies what resource will be used to create the proofs.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "clap", derive(clap::Subcommand))]
pub enum ProverResourceType {
#[default]
@@ -48,3 +46,92 @@ impl ProverResourceType {
}
}
}
impl Serialize for ProverResourceType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Self::Cpu => "cpu".serialize(serializer),
Self::Gpu => "gpu".serialize(serializer),
Self::Network(config) => config.serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for ProverResourceType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
UntaggedEnumVisitor::new()
.string(|resource| match resource {
"cpu" => Ok(Self::Cpu),
"gpu" => Ok(Self::Gpu),
_ => Err(serde::de::Error::invalid_value(
Unexpected::Str(resource),
&r#""cpu" or "gpu""#,
)),
})
.map(|map| map.deserialize().map(Self::Network))
.deserialize(deserializer)
}
}
#[cfg(test)]
mod test {
use crate::zkvm::resource::ProverResourceType;
use core::fmt::Debug;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct Config {
resources: Vec<ProverResourceType>,
}
fn test_round_trip<'de, SE: Debug, DE: Debug>(
config: &'de str,
ser: impl Fn(&Config) -> Result<String, SE>,
de: impl Fn(&'de str) -> Result<Config, DE>,
) {
assert_eq!(config.trim(), ser(&de(config).unwrap()).unwrap().trim())
}
#[test]
fn test_round_trip_toml() {
const TOML: &str = r#"
resources = ["cpu", "gpu", { endpoint = "http://localhost:3000" }]
"#;
test_round_trip(TOML, toml::to_string, toml::from_str);
}
#[test]
fn test_round_trip_yaml() {
const YAML: &str = r#"
resources:
- cpu
- gpu
- endpoint: http://localhost:3000
api_key: null
"#;
test_round_trip(YAML, serde_yaml::to_string, serde_yaml::from_str);
}
#[test]
fn test_round_trip_json() {
const JSON: &str = r#"
{
"resources": [
"cpu",
"gpu",
{
"endpoint": "",
"api_key": null
}
]
}
"#;
test_round_trip(JSON, serde_json::to_string_pretty, serde_json::from_str);
}
}