From 8357931eda508fa2fa2af475858780e3885807d4 Mon Sep 17 00:00:00 2001 From: han0110 Date: Tue, 10 Feb 2026 12:26:04 +0000 Subject: [PATCH] feat: read env `CUDA_ARCHS` instead of `CUDA_ARCH` to accept comma-separated compute cap (e.g. `CUDA_ARCHS=89,90,120`) --- .github/scripts/build-image.sh | 51 +++++++++++++++++++++++------ crates/dockerized/src/image.rs | 11 +++++-- crates/dockerized/src/util/cuda.rs | 49 +++++++++++++++++----------- crates/dockerized/src/zkvm.rs | 52 +++++++++++++++++++++++++----- 4 files changed, 125 insertions(+), 38 deletions(-) diff --git a/.github/scripts/build-image.sh b/.github/scripts/build-image.sh index ca6a069..badd329 100755 --- a/.github/scripts/build-image.sh +++ b/.github/scripts/build-image.sh @@ -11,11 +11,11 @@ BUILD_COMPILER=false BUILD_SERVER=false BUILD_CLUSTER=false CUDA=false -CUDA_ARCH="" +CUDA_ARCHS="" RUSTFLAGS="" usage() { - echo "Usage: $0 --zkvm --tag [--base] [--compiler] [--server] [--cluster] [--registry ] [--cuda] [--cuda-arch ] [--rustflags ]" + echo "Usage: $0 --zkvm --tag [--base] [--compiler] [--server] [--cluster] [--registry ] [--cuda] [--cuda-archs ] [--rustflags ]" echo "" echo "Required:" echo " --zkvm zkVM to build for (e.g., zisk, sp1, risc0)" @@ -30,7 +30,7 @@ usage() { echo "Optional:" echo " --registry Registry prefix (e.g., ghcr.io/eth-act/ere)" echo " --cuda Enable CUDA support (appends -cuda to tag)" - echo " --cuda-arch Set CUDA architecture (e.g., sm_120)" + echo " --cuda-archs Set CUDA architectures (comma-separated, e.g., 89,120). Implies --cuda." echo " --rustflags Pass RUSTFLAGS to build" exit 1 } @@ -70,8 +70,9 @@ while [[ $# -gt 0 ]]; do CUDA=true shift ;; - --cuda-arch) - CUDA_ARCH="$2" + --cuda-archs) + CUDA_ARCHS="$2" + CUDA=true shift 2 ;; --rustflags) @@ -141,10 +142,42 @@ if [ "$CUDA" = true ]; then CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA=1") fi -if [ -n "$CUDA_ARCH" ]; then - BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH") - SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH") - CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH") +# Default CUDA_ARCHS when --cuda is set but --cuda-archs not specified +if [ "$CUDA" = true ] && [ -z "$CUDA_ARCHS" ]; then + CUDA_ARCHS="89,120" +fi + +# Per-zkVM CUDA architecture translation +if [ "$CUDA" = true ] && [ -n "$CUDA_ARCHS" ]; then + case "$ZKVM" in + airbender) + CUDAARCHS=$(echo "$CUDA_ARCHS" | tr ',' ';') + BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS") + SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS") + ;; + openvm) + BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS") + SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS") + ;; + risc0) + NVCC_APPEND_FLAGS="" + IFS=',' read -ra ARCH_ARRAY <<< "$CUDA_ARCHS" + for arch in "${ARCH_ARRAY[@]}"; do + NVCC_APPEND_FLAGS+=" --generate-code arch=compute_${arch},code=sm_${arch}" + done + NVCC_APPEND_FLAGS="${NVCC_APPEND_FLAGS# }" + BASE_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS") + SERVER_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS") + ;; + zisk) + MAX_CUDA_ARCH=$(echo "$CUDA_ARCHS" | tr ',' '\n' | sort -n | tail -1) + BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}") + SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}") + CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}") + ;; + *) + ;; + esac fi if [ -n "$RUSTFLAGS" ]; then diff --git a/crates/dockerized/src/image.rs b/crates/dockerized/src/image.rs index d4f1011..6ebc4bc 100644 --- a/crates/dockerized/src/image.rs +++ b/crates/dockerized/src/image.rs @@ -5,9 +5,14 @@ pub fn image_tag(zkvm_kind: zkVMKind, gpu: bool) -> String { let suffix = match (zkvm_kind, gpu) { // Only the following zkVMs requires CUDA setup in the base image // when GPU support is required. - (zkVMKind::Airbender | zkVMKind::OpenVM | zkVMKind::Risc0 | zkVMKind::Zisk, true) => { - "-cuda" - } + ( + zkVMKind::Airbender + | zkVMKind::OpenVM + | zkVMKind::Risc0 + | zkVMKind::SP1 + | zkVMKind::Zisk, + true, + ) => "-cuda", _ => "", }; format!("{DOCKER_IMAGE_TAG}{suffix}") diff --git a/crates/dockerized/src/util/cuda.rs b/crates/dockerized/src/util/cuda.rs index 4347587..7abe7dd 100644 --- a/crates/dockerized/src/util/cuda.rs +++ b/crates/dockerized/src/util/cuda.rs @@ -25,28 +25,41 @@ pub fn cuda_compute_cap() -> Option { ) } -/// Returns the GPU code in format `sm_{numeric_compute_cap}` (e.g. `sm_120`). +/// Returns CUDA architecture(s) as comma-separated numeric strings +/// (e.g. "120", "89,120"). /// /// It does the following checks and returns the first valid value: -/// 1. Read env variable `CUDA_ARCH` and check if it is in valid format. -/// 2. Detect compute capability of the first visible GPU and format to GPU code. +/// 1. Read env variable `CUDA_ARCHS` and validate format (comma-separated numbers). +/// 2. Detect compute capability of the first visible GPU and convert to numeric format. /// /// Otherwise it returns `None`. -pub fn cuda_arch() -> Option { - if let Ok(cuda_arch) = env::var("CUDA_ARCH") { - if cuda_arch.starts_with("sm_") && cuda_arch[3..].parse::().is_ok() { - info!("Using CUDA_ARCH {cuda_arch} from env variable"); - Some(cuda_arch) - } else { - warn!( - "Skipping CUDA_ARCH {cuda_arch} from env variable (expected to be in format `sm_XX`)" - ); - None +pub fn cuda_archs() -> Option { + if let Ok(val) = env::var("CUDA_ARCHS") { + let valid = !val.is_empty() + && val + .split(',') + .all(|s| !s.is_empty() && s.parse::().is_ok()); + if valid { + info!("Using CUDA_ARCHS {val} from env variable"); + return Some(val); } - } else if let Some(cap) = cuda_compute_cap() { - info!("Using CUDA compute capability {} detected", cap); - Some(format!("sm_{}", cap.replace(".", ""))) - } else { - None + warn!( + "Skipping CUDA_ARCHS {val} from env variable \ + (expected comma-separated numbers, e.g. \"89,120\")" + ); } + + if let Some(cap) = cuda_compute_cap() { + let numeric = cap.replace('.', ""); + if numeric.parse::().is_ok() { + info!("Using CUDA compute capability {cap} detected (CUDA_ARCHS={numeric})"); + return Some(numeric); + } + warn!( + "Skipping CUDA compute capability {cap} detected \ + (expected a version number, e.g. 12.0)" + ); + } + + None } diff --git a/crates/dockerized/src/zkvm.rs b/crates/dockerized/src/zkvm.rs index f3e9780..d1e0ac1 100644 --- a/crates/dockerized/src/zkvm.rs +++ b/crates/dockerized/src/zkvm.rs @@ -2,7 +2,7 @@ use crate::{ compiler::SerializedProgram, image::{base_image, base_zkvm_image, server_zkvm_image}, util::{ - cuda::cuda_arch, + cuda::cuda_archs, docker::{ DockerBuildCmd, DockerRunCmd, docker_container_exists, docker_image_exists, docker_pull_image, stop_docker_container, @@ -41,6 +41,40 @@ mod error; pub use error::Error; +/// Applies per-zkVM CUDA architecture build args to a Docker build command. +/// +/// Each zkVM expects a different format for specifying CUDA architectures: +/// - Airbender: `CUDAARCHS` (semicolon-separated, e.g. "89;120") +/// - OpenVM: `CUDA_ARCH` (comma-separated, e.g. "89,120") +/// - Risc0: `NVCC_APPEND_FLAGS` (nvcc --generate-code flags) +/// - Zisk: `CUDA_ARCH` (single largest arch, e.g. "sm_120") +fn apply_cuda_build_args( + cmd: DockerBuildCmd, + zkvm_kind: zkVMKind, + cuda_archs: &str, +) -> DockerBuildCmd { + match zkvm_kind { + zkVMKind::Airbender => cmd.build_arg("CUDAARCHS", cuda_archs.replace(',', ";")), + zkVMKind::OpenVM => cmd.build_arg("CUDA_ARCH", cuda_archs), + zkVMKind::Risc0 => { + let flags = cuda_archs + .split(',') + .map(|arch| format!("--generate-code arch=compute_{arch},code=sm_{arch} ")) + .collect::(); + cmd.build_arg("NVCC_APPEND_FLAGS", flags.trim_end()) + } + zkVMKind::Zisk => { + let max_cuda_arch = cuda_archs + .split(',') + .filter_map(|s| s.parse::().ok()) + .max() + .unwrap_or(120); + cmd.build_arg("CUDA_ARCH", format!("sm_{max_cuda_arch}")) + } + _ => cmd, + } +} + /// This method builds 3 Docker images in sequence: /// 1. `ere-base:{version}` - Base image with common dependencies /// 2. `ere-base-{zkvm}:{version}` - zkVM-specific base image with the zkVM SDK @@ -77,6 +111,9 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> { let docker_dir = workspace_dir.join("docker"); let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str()); + // Resolve CUDA architectures once for both base-zkvm and server builds. + let cuda_archs = if gpu { cuda_archs() } else { None }; + // Build `ere-base` if force_rebuild || !docker_image_exists(&base_image)? { info!("Building image {base_image}..."); @@ -105,13 +142,8 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> { if gpu { cmd = cmd.build_arg("CUDA", "1"); - match zkvm_kind { - zkVMKind::Airbender | zkVMKind::OpenVM | zkVMKind::Risc0 | zkVMKind::Zisk => { - if let Some(cuda_arch) = cuda_arch() { - cmd = cmd.build_arg("CUDA_ARCH", cuda_arch) - } - } - _ => {} + if let Some(ref cuda_archs) = cuda_archs { + cmd = apply_cuda_build_args(cmd, zkvm_kind, cuda_archs); } } @@ -129,6 +161,10 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> { if gpu { cmd = cmd.build_arg("CUDA", "1"); + + if let Some(ref cuda_archs) = cuda_archs { + cmd = apply_cuda_build_args(cmd, zkvm_kind, cuda_archs); + } } cmd.exec(&workspace_dir)?;