diff --git a/Cargo.lock b/Cargo.lock index 36e3eb6..13da385 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3825,6 +3825,7 @@ version = "0.0.15" dependencies = [ "ere-build-utils", "serde", + "strum 0.27.2", ] [[package]] diff --git a/crates/dockerized/common/Cargo.toml b/crates/dockerized/common/Cargo.toml index e5216f7..4a7384a 100644 --- a/crates/dockerized/common/Cargo.toml +++ b/crates/dockerized/common/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [dependencies] serde = { workspace = true, features = ["alloc", "derive"] } +strum = { workspace = true, features = ["derive"] } [build-dependencies] ere-build-utils.workspace = true diff --git a/crates/dockerized/common/build.rs b/crates/dockerized/common/build.rs index 071bae4..b39539b 100644 --- a/crates/dockerized/common/build.rs +++ b/crates/dockerized/common/build.rs @@ -9,7 +9,7 @@ fn main() { fn generate_crate_version() { let crate_version = format!( - "/// Crate version in format of `{{semantic_version}}{{git_sha:7}}`\npub const CRATE_VERSION: &str = \"{}\";", + "/// Crate version in format of `{{semantic_version}}-{{git_sha:7}}`\npub const CRATE_VERSION: &str = \"{}\";", detect_self_crate_version() ); diff --git a/crates/dockerized/common/src/compiler.rs b/crates/dockerized/common/src/compiler.rs index 5da3630..c92c174 100644 --- a/crates/dockerized/common/src/compiler.rs +++ b/crates/dockerized/common/src/compiler.rs @@ -1,31 +1,49 @@ use serde::{Deserialize, Serialize}; use std::{ + error::Error, fmt::{self, Display, Formatter}, - str::FromStr, }; +use strum::{Display, EnumIter, EnumString, IntoEnumIterator, IntoStaticStr}; /// Compiler kind to use to compile the guest. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + EnumIter, + EnumString, + IntoStaticStr, + Display, +)] #[serde(into = "String", try_from = "String")] +#[strum( + ascii_case_insensitive, + serialize_all = "kebab-case", + parse_err_fn = ParseError::from, + parse_err_ty = ParseError +)] pub enum CompilerKind { /// Stock Rust compiler Rust, /// Rust compiler with customized toolchain + #[strum(serialize = "rust-customized", serialize = "RustCustomized")] RustCustomized, /// Go compiler with customized toolchain + #[strum(serialize = "go-customized", serialize = "GoCustomized")] GoCustomized, /// Miden assembly compiler + #[strum(serialize = "miden-asm", serialize = "MidenAsm")] MidenAsm, } impl CompilerKind { pub fn as_str(&self) -> &'static str { - match self { - Self::Rust => "rust", - Self::RustCustomized => "rust-customized", - Self::GoCustomized => "go-customized", - Self::MidenAsm => "miden-asm", - } + self.into() } } @@ -35,30 +53,60 @@ impl From for String { } } -impl FromStr for CompilerKind { - type Err = String; - - fn from_str(s: &str) -> Result { - Ok(match s { - "rust" => Self::Rust, - "rust-customized" => Self::RustCustomized, - "go-customized" => Self::GoCustomized, - "miden-asm" => Self::MidenAsm, - _ => return Err(format!("Unsupported compiler kind {s}")), - }) - } -} - impl TryFrom for CompilerKind { - type Error = String; + type Error = ParseError; fn try_from(s: String) -> Result { s.parse() } } -impl Display for CompilerKind { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ParseError(String); + +impl From<&str> for ParseError { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let unsupported = &self.0; + let supported = Vec::from_iter(CompilerKind::iter().map(|k| k.as_str())).join(", "); + write!( + f, + "Unsupported compiler kind `{unsupported}`, expect one of [{supported}]", + ) + } +} + +impl Error for ParseError {} + +#[cfg(test)] +mod test { + use crate::compiler::{CompilerKind, CompilerKind::*, ParseError}; + + #[test] + fn parse_compiler_kind() { + // Valid + for (ss, kind) in [ + (["rust", "Rust"], Rust), + (["rust-customized", "RustCustomized"], RustCustomized), + (["go-customized", "GoCustomized"], GoCustomized), + (["miden-asm", "MidenAsm"], MidenAsm), + ] { + ss.iter().for_each(|s| assert_eq!(s.parse(), Ok(kind))); + assert_eq!(kind.as_str(), ss[0]); + } + + // Invalid + assert_eq!("xxx".parse::(), Err(ParseError::from("xxx"))); + assert_eq!( + ParseError::from("xxx").to_string(), + "Unsupported compiler kind `xxx`, expect one of \ + [rust, rust-customized, go-customized, miden-asm]" + .to_string() + ); } } diff --git a/crates/dockerized/common/src/zkvm.rs b/crates/dockerized/common/src/zkvm.rs index c42e722..804ef31 100644 --- a/crates/dockerized/common/src/zkvm.rs +++ b/crates/dockerized/common/src/zkvm.rs @@ -1,13 +1,33 @@ use serde::{Deserialize, Serialize}; use std::{ + error::Error, fmt::{self, Display, Formatter}, - str::FromStr, }; +use strum::{Display, EnumIter, EnumString, IntoEnumIterator, IntoStaticStr}; /// zkVM kind supported in Ere. #[allow(non_camel_case_types)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + EnumIter, + EnumString, + IntoStaticStr, + Display, +)] #[serde(into = "String", try_from = "String")] +#[strum( + ascii_case_insensitive, + serialize_all = "lowercase", + parse_err_fn = ParseError::from, + parse_err_ty = ParseError +)] pub enum zkVMKind { Airbender, Jolt, @@ -23,18 +43,7 @@ pub enum zkVMKind { impl zkVMKind { pub fn as_str(&self) -> &'static str { - match self { - Self::Airbender => "airbender", - Self::Jolt => "jolt", - Self::Miden => "miden", - Self::Nexus => "nexus", - Self::OpenVM => "openvm", - Self::Pico => "pico", - Self::Risc0 => "risc0", - Self::SP1 => "sp1", - Self::Ziren => "ziren", - Self::Zisk => "zisk", - } + self.into() } } @@ -44,36 +53,66 @@ impl From for String { } } -impl FromStr for zkVMKind { - type Err = String; - - fn from_str(s: &str) -> Result { - Ok(match s { - "airbender" => Self::Airbender, - "jolt" => Self::Jolt, - "miden" => Self::Miden, - "nexus" => Self::Nexus, - "openvm" => Self::OpenVM, - "pico" => Self::Pico, - "risc0" => Self::Risc0, - "sp1" => Self::SP1, - "ziren" => Self::Ziren, - "zisk" => Self::Zisk, - _ => return Err(s.to_string()), - }) - } -} - impl TryFrom for zkVMKind { - type Error = String; + type Error = ParseError; fn try_from(s: String) -> Result { s.parse() } } -impl Display for zkVMKind { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ParseError(String); + +impl From<&str> for ParseError { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let unsupported = &self.0; + let supported = Vec::from_iter(zkVMKind::iter().map(|k| k.as_str())).join(", "); + write!( + f, + "Unsupported zkVM kind `{unsupported}`, expect one of [{supported}]", + ) + } +} + +impl Error for ParseError {} + +#[cfg(test)] +mod test { + use crate::zkvm::{ParseError, zkVMKind}; + + #[test] + fn parse_zkvm_kind() { + // Valid + for (ss, kind) in [ + (["airbender", "Airbender"], zkVMKind::Airbender), + (["jolt", "Jolt"], zkVMKind::Jolt), + (["miden", "Miden"], zkVMKind::Miden), + (["nexus", "Nexus"], zkVMKind::Nexus), + (["openvm", "OpenVM"], zkVMKind::OpenVM), + (["pico", "Pico"], zkVMKind::Pico), + (["risc0", "Risc0"], zkVMKind::Risc0), + (["sp1", "SP1"], zkVMKind::SP1), + (["ziren", "Ziren"], zkVMKind::Ziren), + (["zisk", "Zisk"], zkVMKind::Zisk), + ] { + ss.iter().for_each(|s| assert_eq!(s.parse(), Ok(kind))); + assert_eq!(kind.as_str(), ss[0]); + } + + // Invalid + assert_eq!("xxx".parse::(), Err(ParseError::from("xxx"))); + assert_eq!( + ParseError::from("xxx").to_string(), + "Unsupported zkVM kind `xxx`, expect one of \ + [airbender, jolt, miden, nexus, openvm, pico, risc0, sp1, ziren, zisk]" + .to_string() + ); } } diff --git a/crates/dockerized/src/compiler.rs b/crates/dockerized/src/compiler.rs index cc29b7e..64a82a4 100644 --- a/crates/dockerized/src/compiler.rs +++ b/crates/dockerized/src/compiler.rs @@ -29,10 +29,6 @@ pub use error::Error; /// Images are cached and only rebuilt if they don't exist or if the /// `ERE_FORCE_REBUILD_DOCKER_IMAGE` environment variable is set. fn build_compiler_image(zkvm_kind: zkVMKind) -> Result<(), Error> { - let workspace_dir = workspace_dir(); - let docker_dir = workspace_dir.join("docker"); - let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str()); - let force_rebuild = force_rebuild(); let base_image = base_image(zkvm_kind, false); let base_zkvm_image = base_zkvm_image(zkvm_kind, false); @@ -43,6 +39,10 @@ fn build_compiler_image(zkvm_kind: zkVMKind) -> Result<(), Error> { return Ok(()); } + let workspace_dir = workspace_dir()?; + let docker_dir = workspace_dir.join("docker"); + let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str()); + // Build `ere-base` if force_rebuild || !docker_image_exists(&base_image)? { info!("Building image {base_image}..."); @@ -172,7 +172,7 @@ pub(crate) mod test { compiler_kind: CompilerKind, program: &'static str, ) -> SerializedProgram { - DockerizedCompiler::new(zkvm_kind, compiler_kind, workspace_dir()) + DockerizedCompiler::new(zkvm_kind, compiler_kind, workspace_dir().unwrap()) .unwrap() .compile(&testing_guest_directory(zkvm_kind.as_str(), program)) .unwrap() diff --git a/crates/dockerized/src/util.rs b/crates/dockerized/src/util.rs index 655a278..c060733 100644 --- a/crates/dockerized/src/util.rs +++ b/crates/dockerized/src/util.rs @@ -1,13 +1,16 @@ use std::path::PathBuf; +use ere_zkvm_interface::CommonError; + pub mod cuda; pub mod docker; -pub fn workspace_dir() -> PathBuf { +pub fn workspace_dir() -> Result { let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); dir.pop(); dir.pop(); - dir.canonicalize().unwrap() + dir.canonicalize() + .map_err(|err| CommonError::io("Source code of Ere not found", err)) } pub fn home_dir() -> PathBuf { diff --git a/crates/dockerized/src/zkvm.rs b/crates/dockerized/src/zkvm.rs index ce4b76b..4d11f3b 100644 --- a/crates/dockerized/src/zkvm.rs +++ b/crates/dockerized/src/zkvm.rs @@ -40,10 +40,6 @@ pub use error::Error; /// Images are cached and only rebuilt if they don't exist or if the /// `ERE_FORCE_REBUILD_DOCKER_IMAGE` environment variable is set. fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> { - let workspace_dir = workspace_dir(); - let docker_dir = workspace_dir.join("docker"); - let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str()); - let force_rebuild = force_rebuild(); let base_image = base_image(zkvm_kind, gpu); let base_zkvm_image = base_zkvm_image(zkvm_kind, gpu); @@ -54,6 +50,10 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> { return Ok(()); } + let workspace_dir = workspace_dir()?; + let docker_dir = workspace_dir.join("docker"); + let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str()); + // Build `ere-base` if force_rebuild || !docker_image_exists(&base_image)? { info!("Building image {base_image}..."); diff --git a/crates/zkvm-interface/src/zkvm/resource.rs b/crates/zkvm-interface/src/zkvm/resource.rs index 672eaff..388cf7e 100644 --- a/crates/zkvm-interface/src/zkvm/resource.rs +++ b/crates/zkvm-interface/src/zkvm/resource.rs @@ -26,7 +26,7 @@ impl NetworkProverConfig { /// ResourceType specifies what resource will be used to create the proofs. #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] +#[serde(untagged, rename_all = "kebab-case")] #[cfg_attr(feature = "clap", derive(clap::Subcommand))] pub enum ProverResourceType { #[default]