Support --program-path flag in ere-server (#270)

This commit is contained in:
Han
2026-01-16 23:15:21 +09:00
committed by GitHub
parent 0e41aa7e95
commit a9fb04dea8
10 changed files with 264 additions and 67 deletions

View File

@@ -6,61 +6,50 @@ use crate::api::{
use ere_zkvm_interface::zkvm::{
Input, ProgramExecutionReport, ProgramProvingReport, Proof, ProofKind, PublicValues,
};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::time::sleep;
use twirp::{Client, Request, reqwest};
pub use twirp::{TwirpErrorResponse, url::Url};
pub use twirp::{
TwirpErrorResponse,
url::{ParseError, Url},
};
#[derive(Debug, Error)]
#[allow(non_camel_case_types)]
pub enum Error {
#[error("Invalid URL: {0}")]
ParseUrl(#[from] ParseError),
#[error("zkVM method error: {0}")]
zkVM(String),
#[error("Connection to zkVM server timeout after 5 minutes")]
ConnectionTimeout,
#[error("RPC error: {0}")]
Rpc(#[from] TwirpErrorResponse),
}
/// zkVM client of the `zkVMServer`.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug)]
pub struct zkVMClient {
client: Client,
}
impl zkVMClient {
pub async fn new(url: Url) -> Result<Self, Error> {
const TIMEOUT: Duration = Duration::from_secs(300); // 5mins
const INTERVAL: Duration = Duration::from_millis(500);
pub fn new(endpoint: Url, http_client: reqwest::Client) -> Result<Self, Error> {
Ok(Self {
client: Client::new(endpoint.join("twirp")?, http_client, Vec::new(), None),
})
}
let http_client = reqwest::Client::new();
let start = Instant::now();
loop {
if start.elapsed() > TIMEOUT {
return Err(Error::ConnectionTimeout);
}
match http_client.get(url.join("health").unwrap()).send().await {
Ok(response) if response.status().is_success() => break,
_ => sleep(INTERVAL).await,
}
}
let client = Client::new(url.join("twirp").unwrap(), http_client, Vec::new(), None);
Ok(Self { client })
pub fn from_endpoint(endpoint: Url) -> Result<Self, Error> {
Self::new(endpoint, reqwest::Client::new())
}
pub async fn execute(
&self,
input: &Input,
input: Input,
) -> Result<(PublicValues, ProgramExecutionReport), Error> {
let request = Request::new(ExecuteRequest {
input_stdin: input.stdin.clone(),
input_proofs: input.proofs.clone(),
input_stdin: input.stdin,
input_proofs: input.proofs,
});
let response = self.client.execute(request).await?;
@@ -78,12 +67,12 @@ impl zkVMClient {
pub async fn prove(
&self,
input: &Input,
input: Input,
proof_kind: ProofKind,
) -> Result<(PublicValues, Proof, ProgramProvingReport), Error> {
let request = Request::new(ProveRequest {
input_stdin: input.stdin.clone(),
input_proofs: input.proofs.clone(),
input_stdin: input.stdin,
input_proofs: input.proofs,
proof_kind: proof_kind as i32,
});
@@ -101,10 +90,11 @@ impl zkVMClient {
}
}
pub async fn verify(&self, proof: &Proof) -> Result<PublicValues, Error> {
pub async fn verify(&self, proof: Proof) -> Result<PublicValues, Error> {
let proof_kind = proof.kind() as i32;
let request = Request::new(VerifyRequest {
proof: proof.as_bytes().to_vec(),
proof_kind: proof.kind() as i32,
proof: proof.into_bytes(),
proof_kind,
});
let response = self.client.verify(request).await?;

View File

@@ -40,8 +40,13 @@ const _: () = {
#[derive(Parser)]
#[command(author, version)]
struct Args {
/// Port number for the server to listen on.
#[arg(long, default_value = "3000")]
port: u16,
/// Optional path to read the program from. If not specified, reads from stdin.
#[arg(long)]
program_path: Option<String>,
/// Prover resource type.
#[command(subcommand)]
resource: ProverResourceType,
}
@@ -54,9 +59,16 @@ async fn main() -> Result<(), Error> {
let args = Args::parse();
// Read serialized program from stdin.
let mut program = Vec::new();
io::stdin().read_to_end(&mut program)?;
// Read serialized program from file or stdin.
let program = if let Some(path) = args.program_path {
std::fs::read(&path).with_context(|| format!("Failed to read program from {path}"))?
} else {
let mut program = Vec::new();
io::stdin()
.read_to_end(&mut program)
.context("Failed to read program from stdin")?;
program
};
let zkvm = construct_zkvm(program, args.resource)?;
let server = Arc::new(zkVMServer::new(zkvm));