diff --git a/Cargo.lock b/Cargo.lock index 98f98c9..fdc0eaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3871,7 +3871,6 @@ dependencies = [ "ere-server", "ere-test-utils", "ere-zkvm-interface", - "parking_lot", "paste", "serde", "tempfile", @@ -4193,9 +4192,12 @@ dependencies = [ "clap", "indexmap 2.10.0", "serde", + "serde-untagged", "serde_json", + "serde_yaml", "strum 0.27.2", "thiserror 2.0.12", + "toml 0.8.23", ] [[package]] @@ -12717,6 +12719,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde-untagged" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9faf48a4a2d2693be24c6289dbe26552776eb7737074e6722891fadbe6c5058" +dependencies = [ + "erased-serde", + "serde", + "serde_core", + "typeid", +] + [[package]] name = "serde_arrays" version = "0.1.0" @@ -12847,6 +12861,19 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap 2.10.0", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "serdect" version = "0.2.0" @@ -15136,6 +15163,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index 7e83cda..fa77c8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,6 @@ dashmap = "6.1.0" digest = { version = "0.10.7", default-features = false } eyre = "0.6.12" indexmap = "2.10.0" -parking_lot = "0.12.5" paste = "1.0.15" postcard = { version = "1.0.8", default-features = false } prost = "0.13" @@ -68,6 +67,7 @@ rkyv = { version = "0.8.12", default-features = false } serde = { version = "1.0.219", default-features = false } serde_bytes = { version = "0.11.19", default-features = false } serde_json = "1.0.142" +serde-untagged = "0.1" serde_yaml = "0.9.34" sha2 = { version = "0.10.9", default-features = false } strum = "0.27.2" diff --git a/crates/dockerized/Cargo.toml b/crates/dockerized/Cargo.toml index 5eab617..4a72962 100644 --- a/crates/dockerized/Cargo.toml +++ b/crates/dockerized/Cargo.toml @@ -7,7 +7,6 @@ license.workspace = true [dependencies] anyhow.workspace = true -parking_lot.workspace = true serde = { workspace = true, features = ["derive"] } tempfile.workspace = true thiserror.workspace = true diff --git a/crates/dockerized/server/src/client.rs b/crates/dockerized/server/src/client.rs index 2e84037..1280ae8 100644 --- a/crates/dockerized/server/src/client.rs +++ b/crates/dockerized/server/src/client.rs @@ -6,61 +6,50 @@ use crate::api::{ use ere_zkvm_interface::zkvm::{ Input, ProgramExecutionReport, ProgramProvingReport, Proof, ProofKind, PublicValues, }; -use std::time::{Duration, Instant}; use thiserror::Error; -use tokio::time::sleep; use twirp::{Client, Request, reqwest}; -pub use twirp::{TwirpErrorResponse, url::Url}; +pub use twirp::{ + TwirpErrorResponse, + url::{ParseError, Url}, +}; #[derive(Debug, Error)] #[allow(non_camel_case_types)] pub enum Error { + #[error("Invalid URL: {0}")] + ParseUrl(#[from] ParseError), #[error("zkVM method error: {0}")] zkVM(String), - #[error("Connection to zkVM server timeout after 5 minutes")] - ConnectionTimeout, #[error("RPC error: {0}")] Rpc(#[from] TwirpErrorResponse), } /// zkVM client of the `zkVMServer`. #[allow(non_camel_case_types)] +#[derive(Clone, Debug)] pub struct zkVMClient { client: Client, } impl zkVMClient { - pub async fn new(url: Url) -> Result { - const TIMEOUT: Duration = Duration::from_secs(300); // 5mins - const INTERVAL: Duration = Duration::from_millis(500); + pub fn new(endpoint: Url, http_client: reqwest::Client) -> Result { + Ok(Self { + client: Client::new(endpoint.join("twirp")?, http_client, Vec::new(), None), + }) + } - let http_client = reqwest::Client::new(); - - let start = Instant::now(); - loop { - if start.elapsed() > TIMEOUT { - return Err(Error::ConnectionTimeout); - } - - match http_client.get(url.join("health").unwrap()).send().await { - Ok(response) if response.status().is_success() => break, - _ => sleep(INTERVAL).await, - } - } - - let client = Client::new(url.join("twirp").unwrap(), http_client, Vec::new(), None); - - Ok(Self { client }) + pub fn from_endpoint(endpoint: Url) -> Result { + Self::new(endpoint, reqwest::Client::new()) } pub async fn execute( &self, - input: &Input, + input: Input, ) -> Result<(PublicValues, ProgramExecutionReport), Error> { let request = Request::new(ExecuteRequest { - input_stdin: input.stdin.clone(), - input_proofs: input.proofs.clone(), + input_stdin: input.stdin, + input_proofs: input.proofs, }); let response = self.client.execute(request).await?; @@ -78,12 +67,12 @@ impl zkVMClient { pub async fn prove( &self, - input: &Input, + input: Input, proof_kind: ProofKind, ) -> Result<(PublicValues, Proof, ProgramProvingReport), Error> { let request = Request::new(ProveRequest { - input_stdin: input.stdin.clone(), - input_proofs: input.proofs.clone(), + input_stdin: input.stdin, + input_proofs: input.proofs, proof_kind: proof_kind as i32, }); @@ -101,10 +90,11 @@ impl zkVMClient { } } - pub async fn verify(&self, proof: &Proof) -> Result { + pub async fn verify(&self, proof: Proof) -> Result { + let proof_kind = proof.kind() as i32; let request = Request::new(VerifyRequest { - proof: proof.as_bytes().to_vec(), - proof_kind: proof.kind() as i32, + proof: proof.into_bytes(), + proof_kind, }); let response = self.client.verify(request).await?; diff --git a/crates/dockerized/server/src/main.rs b/crates/dockerized/server/src/main.rs index b1a7e4f..96a9899 100644 --- a/crates/dockerized/server/src/main.rs +++ b/crates/dockerized/server/src/main.rs @@ -40,8 +40,13 @@ const _: () = { #[derive(Parser)] #[command(author, version)] struct Args { + /// Port number for the server to listen on. #[arg(long, default_value = "3000")] port: u16, + /// Optional path to read the program from. If not specified, reads from stdin. + #[arg(long)] + program_path: Option, + /// Prover resource type. #[command(subcommand)] resource: ProverResourceType, } @@ -54,9 +59,16 @@ async fn main() -> Result<(), Error> { let args = Args::parse(); - // Read serialized program from stdin. - let mut program = Vec::new(); - io::stdin().read_to_end(&mut program)?; + // Read serialized program from file or stdin. + let program = if let Some(path) = args.program_path { + std::fs::read(&path).with_context(|| format!("Failed to read program from {path}"))? + } else { + let mut program = Vec::new(); + io::stdin() + .read_to_end(&mut program) + .context("Failed to read program from stdin")?; + program + }; let zkvm = construct_zkvm(program, args.resource)?; let server = Arc::new(zkVMServer::new(zkvm)); diff --git a/crates/dockerized/src/zkvm.rs b/crates/dockerized/src/zkvm.rs index 3ee308a..a27b25b 100644 --- a/crates/dockerized/src/zkvm.rs +++ b/crates/dockerized/src/zkvm.rs @@ -15,7 +15,10 @@ use crate::{ }, zkVMKind, }; -use ere_server::client::{Url, zkVMClient}; +use ere_server::{ + api::twirp::reqwest::Client, + client::{self, Url, zkVMClient}, +}; use ere_zkvm_interface::{ CommonError, zkvm::{ @@ -23,9 +26,15 @@ use ere_zkvm_interface::{ PublicValues, zkVM, }, }; -use parking_lot::RwLock; -use std::{future::Future, iter}; +use std::{ + future::Future, + iter, + pin::Pin, + sync::OnceLock, + time::{Duration, Instant}, +}; use tempfile::TempDir; +use tokio::{sync::RwLock, time::sleep}; use tracing::{error, info}; mod error; @@ -248,13 +257,14 @@ impl ServerContainer { &program.0, )?; - let endpoint = Url::parse(&format!("http://{host}:{port}")).unwrap(); - let client = block_on(zkVMClient::new(endpoint))?; + let endpoint = Url::parse(&format!("http://{host}:{port}"))?; + let http_client = Client::new(); + block_on(wait_until_healthy(&endpoint, http_client.clone()))?; Ok(ServerContainer { name, tempdir, - client, + client: zkVMClient::new(endpoint, http_client)?, }) } } @@ -296,15 +306,46 @@ impl DockerizedzkVM { &self.resource } - fn with_retry(&self, mut f: F) -> anyhow::Result + pub async fn execute_async( + &self, + input: Input, + ) -> anyhow::Result<(PublicValues, ProgramExecutionReport)> { + self.with_retry(|client| { + let input = input.clone(); + Box::pin(async move { client.execute(input).await }) + }) + .await + } + + pub async fn prove_async( + &self, + input: Input, + proof_kind: ProofKind, + ) -> anyhow::Result<(PublicValues, Proof, ProgramProvingReport)> { + self.with_retry(|client| { + let input = input.clone(); + Box::pin(async move { client.prove(input, proof_kind).await }) + }) + .await + } + + pub async fn verify_async(&self, proof: Proof) -> anyhow::Result { + self.with_retry(|client| { + let proof = proof.clone(); + Box::pin(async move { client.verify(proof).await }) + }) + .await + } + + async fn with_retry(&self, f: F) -> anyhow::Result where - F: FnMut(&zkVMClient) -> Result, + F: Fn(zkVMClient) -> Pin> + Send>>, { const MAX_RETRY: usize = 3; let mut attempt = 1; loop { - let err = match f(&self.container.read().as_ref().unwrap().client) { + let err = match f(self.container.read().await.as_ref().unwrap().client.clone()).await { Ok(ok) => return Ok(ok), Err(err) => Error::from(err), }; @@ -319,7 +360,7 @@ impl DockerizedzkVM { error!("Rpc failed (attempt {attempt}/{MAX_RETRY}): {err}, checking container..."); - let mut container = self.container.write(); + let mut container = self.container.write().await; if docker_container_exists(&container.as_ref().unwrap().name).is_ok_and(|exists| exists) { info!("Container is still running, retrying..."); @@ -340,7 +381,7 @@ impl DockerizedzkVM { impl zkVM for DockerizedzkVM { fn execute(&self, input: &Input) -> anyhow::Result<(PublicValues, ProgramExecutionReport)> { - self.with_retry(|client| block_on(client.execute(input))) + block_on(self.execute_async(input.clone())) } fn prove( @@ -348,11 +389,11 @@ impl zkVM for DockerizedzkVM { input: &Input, proof_kind: ProofKind, ) -> anyhow::Result<(PublicValues, Proof, ProgramProvingReport)> { - self.with_retry(|client| block_on(client.prove(input, proof_kind))) + block_on(self.prove_async(input.clone(), proof_kind)) } fn verify(&self, proof: &Proof) -> anyhow::Result { - self.with_retry(|client| block_on(client.verify(proof))) + block_on(self.verify_async(proof.clone())) } fn name(&self) -> &'static str { @@ -364,13 +405,35 @@ impl zkVM for DockerizedzkVM { } } -fn block_on(future: impl Future) -> T { - match tokio::runtime::Handle::try_current() { - Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)), - Err(_) => tokio::runtime::Runtime::new().unwrap().block_on(future), +async fn wait_until_healthy(endpoint: &Url, http_client: Client) -> Result<(), Error> { + const TIMEOUT: Duration = Duration::from_secs(300); // 5mins + const INTERVAL: Duration = Duration::from_millis(500); + + let http_client = http_client.clone(); + let start = Instant::now(); + loop { + if start.elapsed() > TIMEOUT { + return Err(Error::ConnectionTimeout); + } + + match http_client.get(endpoint.join("health")?).send().await { + Ok(response) if response.status().is_success() => break Ok(()), + _ => sleep(INTERVAL).await, + } } } +fn block_on(future: impl Future) -> T { + static FALLBACK_RT: OnceLock = OnceLock::new(); + let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| { + FALLBACK_RT + .get_or_init(|| tokio::runtime::Runtime::new().expect("Failed to create runtime")) + .handle() + .clone() + }); + tokio::task::block_in_place(|| handle.block_on(future)) +} + #[cfg(test)] mod test { use crate::{ diff --git a/crates/dockerized/src/zkvm/error.rs b/crates/dockerized/src/zkvm/error.rs index 58aa117..c61aa2d 100644 --- a/crates/dockerized/src/zkvm/error.rs +++ b/crates/dockerized/src/zkvm/error.rs @@ -1,12 +1,12 @@ -use ere_server::client::{self, TwirpErrorResponse}; +use ere_server::client::{self, ParseError, TwirpErrorResponse}; use ere_zkvm_interface::CommonError; use thiserror::Error; impl From for Error { fn from(value: client::Error) -> Self { match value { + client::Error::ParseUrl(err) => Self::ParseUrl(err), client::Error::zkVM(err) => Self::zkVM(err), - client::Error::ConnectionTimeout => Self::ConnectionTimeout, client::Error::Rpc(err) => Self::Rpc(err), } } @@ -17,6 +17,8 @@ impl From for Error { pub enum Error { #[error(transparent)] CommonError(#[from] CommonError), + #[error(transparent)] + ParseUrl(#[from] ParseError), #[error("zkVM method error: {0}")] zkVM(String), #[error("Connection to zkVM server timeout after 5 minutes")] diff --git a/crates/zkvm-interface/Cargo.toml b/crates/zkvm-interface/Cargo.toml index 884d6b4..8b3e024 100644 --- a/crates/zkvm-interface/Cargo.toml +++ b/crates/zkvm-interface/Cargo.toml @@ -11,6 +11,7 @@ auto_impl.workspace = true bincode = { workspace = true, features = ["alloc", "serde"] } indexmap = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } +serde-untagged.workspace = true strum = { workspace = true, features = ["derive"] } thiserror.workspace = true @@ -20,6 +21,8 @@ clap = { workspace = true, features = ["derive"], optional = true } [dev-dependencies] bincode = { workspace = true, features = ["alloc", "serde"] } serde_json.workspace = true +serde_yaml.workspace = true +toml.workspace = true [lints] workspace = true diff --git a/crates/zkvm-interface/src/zkvm/proof.rs b/crates/zkvm-interface/src/zkvm/proof.rs index 757be3a..e173dbb 100644 --- a/crates/zkvm-interface/src/zkvm/proof.rs +++ b/crates/zkvm-interface/src/zkvm/proof.rs @@ -30,11 +30,19 @@ impl Proof { ProofKind::from(self) } - /// Returns inner proof as bytes. + /// Returns inner proof as bytes slice. pub fn as_bytes(&self) -> &[u8] { match self { Self::Compressed(bytes) => bytes, Self::Groth16(bytes) => bytes, } } + + /// Returns inner proof as bytes vec. + pub fn into_bytes(self) -> Vec { + match self { + Self::Compressed(bytes) => bytes, + Self::Groth16(bytes) => bytes, + } + } } diff --git a/crates/zkvm-interface/src/zkvm/resource.rs b/crates/zkvm-interface/src/zkvm/resource.rs index 672eaff..e804bc8 100644 --- a/crates/zkvm-interface/src/zkvm/resource.rs +++ b/crates/zkvm-interface/src/zkvm/resource.rs @@ -1,16 +1,15 @@ -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, de::Unexpected}; +use serde_untagged::UntaggedEnumVisitor; /// Configuration for network-based proving #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] #[cfg_attr(feature = "clap", derive(clap::Args))] pub struct NetworkProverConfig { - #[cfg_attr(feature = "clap", arg(long))] /// The endpoint URL of the prover network service - pub endpoint: String, - #[cfg_attr(feature = "clap", arg(long))] + pub endpoint: String, /// Optional API key for authentication + #[cfg_attr(feature = "clap", arg(long))] pub api_key: Option, } @@ -25,8 +24,7 @@ impl NetworkProverConfig { } /// ResourceType specifies what resource will be used to create the proofs. -#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] +#[derive(Debug, Clone, Default, PartialEq, Eq)] #[cfg_attr(feature = "clap", derive(clap::Subcommand))] pub enum ProverResourceType { #[default] @@ -48,3 +46,92 @@ impl ProverResourceType { } } } + +impl Serialize for ProverResourceType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Cpu => "cpu".serialize(serializer), + Self::Gpu => "gpu".serialize(serializer), + Self::Network(config) => config.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for ProverResourceType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + UntaggedEnumVisitor::new() + .string(|resource| match resource { + "cpu" => Ok(Self::Cpu), + "gpu" => Ok(Self::Gpu), + _ => Err(serde::de::Error::invalid_value( + Unexpected::Str(resource), + &r#""cpu" or "gpu""#, + )), + }) + .map(|map| map.deserialize().map(Self::Network)) + .deserialize(deserializer) + } +} + +#[cfg(test)] +mod test { + use crate::zkvm::resource::ProverResourceType; + use core::fmt::Debug; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize)] + struct Config { + resources: Vec, + } + + fn test_round_trip<'de, SE: Debug, DE: Debug>( + config: &'de str, + ser: impl Fn(&Config) -> Result, + de: impl Fn(&'de str) -> Result, + ) { + assert_eq!(config.trim(), ser(&de(config).unwrap()).unwrap().trim()) + } + + #[test] + fn test_round_trip_toml() { + const TOML: &str = r#" +resources = ["cpu", "gpu", { endpoint = "http://localhost:3000" }] + "#; + test_round_trip(TOML, toml::to_string, toml::from_str); + } + + #[test] + fn test_round_trip_yaml() { + const YAML: &str = r#" +resources: +- cpu +- gpu +- endpoint: http://localhost:3000 + api_key: null +"#; + test_round_trip(YAML, serde_yaml::to_string, serde_yaml::from_str); + } + + #[test] + fn test_round_trip_json() { + const JSON: &str = r#" +{ + "resources": [ + "cpu", + "gpu", + { + "endpoint": "", + "api_key": null + } + ] +} +"#; + test_round_trip(JSON, serde_json::to_string_pretty, serde_json::from_str); + } +}