mirror of
https://github.com/eth-act/ere.git
synced 2026-02-19 11:54:42 -05:00
feat: read env CUDA_ARCHS instead of CUDA_ARCH to accept comma-separated compute cap (e.g. CUDA_ARCHS=89,90,120)
This commit is contained in:
51
.github/scripts/build-image.sh
vendored
51
.github/scripts/build-image.sh
vendored
@@ -11,11 +11,11 @@ BUILD_COMPILER=false
|
||||
BUILD_SERVER=false
|
||||
BUILD_CLUSTER=false
|
||||
CUDA=false
|
||||
CUDA_ARCH=""
|
||||
CUDA_ARCHS=""
|
||||
RUSTFLAGS=""
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 --zkvm <zkvm> --tag <tag> [--base] [--compiler] [--server] [--cluster] [--registry <registry>] [--cuda] [--cuda-arch <arch>] [--rustflags <flags>]"
|
||||
echo "Usage: $0 --zkvm <zkvm> --tag <tag> [--base] [--compiler] [--server] [--cluster] [--registry <registry>] [--cuda] [--cuda-archs <archs>] [--rustflags <flags>]"
|
||||
echo ""
|
||||
echo "Required:"
|
||||
echo " --zkvm <zkvm> zkVM to build for (e.g., zisk, sp1, risc0)"
|
||||
@@ -30,7 +30,7 @@ usage() {
|
||||
echo "Optional:"
|
||||
echo " --registry <reg> Registry prefix (e.g., ghcr.io/eth-act/ere)"
|
||||
echo " --cuda Enable CUDA support (appends -cuda to tag)"
|
||||
echo " --cuda-arch <arch> Set CUDA architecture (e.g., sm_120)"
|
||||
echo " --cuda-archs <archs> Set CUDA architectures (comma-separated, e.g., 89,120). Implies --cuda."
|
||||
echo " --rustflags <flags> Pass RUSTFLAGS to build"
|
||||
exit 1
|
||||
}
|
||||
@@ -70,8 +70,9 @@ while [[ $# -gt 0 ]]; do
|
||||
CUDA=true
|
||||
shift
|
||||
;;
|
||||
--cuda-arch)
|
||||
CUDA_ARCH="$2"
|
||||
--cuda-archs)
|
||||
CUDA_ARCHS="$2"
|
||||
CUDA=true
|
||||
shift 2
|
||||
;;
|
||||
--rustflags)
|
||||
@@ -141,10 +142,42 @@ if [ "$CUDA" = true ]; then
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA=1")
|
||||
fi
|
||||
|
||||
if [ -n "$CUDA_ARCH" ]; then
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
# Default CUDA_ARCHS when --cuda is set but --cuda-archs not specified
|
||||
if [ "$CUDA" = true ] && [ -z "$CUDA_ARCHS" ]; then
|
||||
CUDA_ARCHS="89,120"
|
||||
fi
|
||||
|
||||
# Per-zkVM CUDA architecture translation
|
||||
if [ "$CUDA" = true ] && [ -n "$CUDA_ARCHS" ]; then
|
||||
case "$ZKVM" in
|
||||
airbender)
|
||||
CUDAARCHS=$(echo "$CUDA_ARCHS" | tr ',' ';')
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS")
|
||||
;;
|
||||
openvm)
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS")
|
||||
;;
|
||||
risc0)
|
||||
NVCC_APPEND_FLAGS=""
|
||||
IFS=',' read -ra ARCH_ARRAY <<< "$CUDA_ARCHS"
|
||||
for arch in "${ARCH_ARRAY[@]}"; do
|
||||
NVCC_APPEND_FLAGS+=" --generate-code arch=compute_${arch},code=sm_${arch}"
|
||||
done
|
||||
NVCC_APPEND_FLAGS="${NVCC_APPEND_FLAGS# }"
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS")
|
||||
;;
|
||||
zisk)
|
||||
MAX_CUDA_ARCH=$(echo "$CUDA_ARCHS" | tr ',' '\n' | sort -n | tail -1)
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
if [ -n "$RUSTFLAGS" ]; then
|
||||
|
||||
@@ -5,9 +5,14 @@ pub fn image_tag(zkvm_kind: zkVMKind, gpu: bool) -> String {
|
||||
let suffix = match (zkvm_kind, gpu) {
|
||||
// Only the following zkVMs require CUDA setup in the base image
|
||||
// when GPU support is required.
|
||||
(zkVMKind::Airbender | zkVMKind::OpenVM | zkVMKind::Risc0 | zkVMKind::Zisk, true) => {
|
||||
"-cuda"
|
||||
}
|
||||
(
|
||||
zkVMKind::Airbender
|
||||
| zkVMKind::OpenVM
|
||||
| zkVMKind::Risc0
|
||||
| zkVMKind::SP1
|
||||
| zkVMKind::Zisk,
|
||||
true,
|
||||
) => "-cuda",
|
||||
_ => "",
|
||||
};
|
||||
format!("{DOCKER_IMAGE_TAG}{suffix}")
|
||||
|
||||
@@ -25,28 +25,41 @@ pub fn cuda_compute_cap() -> Option<String> {
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the GPU code in format `sm_{numeric_compute_cap}` (e.g. `sm_120`).
|
||||
/// Returns CUDA architecture(s) as comma-separated numeric strings
|
||||
/// (e.g. "120", "89,120").
|
||||
///
|
||||
/// It does the following checks and returns the first valid value:
|
||||
/// 1. Read env variable `CUDA_ARCH` and check if it is in valid format.
|
||||
/// 2. Detect compute capability of the first visible GPU and format to GPU code.
|
||||
/// 1. Read env variable `CUDA_ARCHS` and validate format (comma-separated numbers).
|
||||
/// 2. Detect compute capability of the first visible GPU and convert to numeric format.
|
||||
///
|
||||
/// Otherwise it returns `None`.
|
||||
pub fn cuda_arch() -> Option<String> {
|
||||
if let Ok(cuda_arch) = env::var("CUDA_ARCH") {
|
||||
if cuda_arch.starts_with("sm_") && cuda_arch[3..].parse::<usize>().is_ok() {
|
||||
info!("Using CUDA_ARCH {cuda_arch} from env variable");
|
||||
Some(cuda_arch)
|
||||
} else {
|
||||
warn!(
|
||||
"Skipping CUDA_ARCH {cuda_arch} from env variable (expected to be in format `sm_XX`)"
|
||||
);
|
||||
None
|
||||
pub fn cuda_archs() -> Option<String> {
|
||||
if let Ok(val) = env::var("CUDA_ARCHS") {
|
||||
let valid = !val.is_empty()
|
||||
&& val
|
||||
.split(',')
|
||||
.all(|s| !s.is_empty() && s.parse::<u32>().is_ok());
|
||||
if valid {
|
||||
info!("Using CUDA_ARCHS {val} from env variable");
|
||||
return Some(val);
|
||||
}
|
||||
} else if let Some(cap) = cuda_compute_cap() {
|
||||
info!("Using CUDA compute capability {} detected", cap);
|
||||
Some(format!("sm_{}", cap.replace(".", "")))
|
||||
} else {
|
||||
None
|
||||
warn!(
|
||||
"Skipping CUDA_ARCHS {val} from env variable \
|
||||
(expected comma-separated numbers, e.g. \"89,120\")"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(cap) = cuda_compute_cap() {
|
||||
let numeric = cap.replace('.', "");
|
||||
if numeric.parse::<u32>().is_ok() {
|
||||
info!("Using CUDA compute capability {cap} detected (CUDA_ARCHS={numeric})");
|
||||
return Some(numeric);
|
||||
}
|
||||
warn!(
|
||||
"Skipping CUDA compute capability {cap} detected \
|
||||
(expected a version number, e.g. 12.0)"
|
||||
);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{
|
||||
compiler::SerializedProgram,
|
||||
image::{base_image, base_zkvm_image, server_zkvm_image},
|
||||
util::{
|
||||
cuda::cuda_arch,
|
||||
cuda::cuda_archs,
|
||||
docker::{
|
||||
DockerBuildCmd, DockerRunCmd, docker_container_exists, docker_image_exists,
|
||||
docker_pull_image, stop_docker_container,
|
||||
@@ -41,6 +41,40 @@ mod error;
|
||||
|
||||
pub use error::Error;
|
||||
|
||||
/// Applies per-zkVM CUDA architecture build args to a Docker build command.
|
||||
///
|
||||
/// Each zkVM expects a different format for specifying CUDA architectures:
|
||||
/// - Airbender: `CUDAARCHS` (semicolon-separated, e.g. "89;120")
|
||||
/// - OpenVM: `CUDA_ARCH` (comma-separated, e.g. "89,120")
|
||||
/// - Risc0: `NVCC_APPEND_FLAGS` (nvcc --generate-code flags)
|
||||
/// - Zisk: `CUDA_ARCH` (single largest arch, e.g. "sm_120")
|
||||
fn apply_cuda_build_args(
|
||||
cmd: DockerBuildCmd,
|
||||
zkvm_kind: zkVMKind,
|
||||
cuda_archs: &str,
|
||||
) -> DockerBuildCmd {
|
||||
match zkvm_kind {
|
||||
zkVMKind::Airbender => cmd.build_arg("CUDAARCHS", cuda_archs.replace(',', ";")),
|
||||
zkVMKind::OpenVM => cmd.build_arg("CUDA_ARCH", cuda_archs),
|
||||
zkVMKind::Risc0 => {
|
||||
let flags = cuda_archs
|
||||
.split(',')
|
||||
.map(|arch| format!("--generate-code arch=compute_{arch},code=sm_{arch} "))
|
||||
.collect::<String>();
|
||||
cmd.build_arg("NVCC_APPEND_FLAGS", flags.trim_end())
|
||||
}
|
||||
zkVMKind::Zisk => {
|
||||
let max_cuda_arch = cuda_archs
|
||||
.split(',')
|
||||
.filter_map(|s| s.parse::<u32>().ok())
|
||||
.max()
|
||||
.unwrap_or(120);
|
||||
cmd.build_arg("CUDA_ARCH", format!("sm_{max_cuda_arch}"))
|
||||
}
|
||||
_ => cmd,
|
||||
}
|
||||
}
|
||||
|
||||
/// This method builds 3 Docker images in sequence:
|
||||
/// 1. `ere-base:{version}` - Base image with common dependencies
|
||||
/// 2. `ere-base-{zkvm}:{version}` - zkVM-specific base image with the zkVM SDK
|
||||
@@ -77,6 +111,9 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> {
|
||||
let docker_dir = workspace_dir.join("docker");
|
||||
let docker_zkvm_dir = docker_dir.join(zkvm_kind.as_str());
|
||||
|
||||
// Resolve CUDA architectures once for both base-zkvm and server builds.
|
||||
let cuda_archs = if gpu { cuda_archs() } else { None };
|
||||
|
||||
// Build `ere-base`
|
||||
if force_rebuild || !docker_image_exists(&base_image)? {
|
||||
info!("Building image {base_image}...");
|
||||
@@ -105,13 +142,8 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> {
|
||||
if gpu {
|
||||
cmd = cmd.build_arg("CUDA", "1");
|
||||
|
||||
match zkvm_kind {
|
||||
zkVMKind::Airbender | zkVMKind::OpenVM | zkVMKind::Risc0 | zkVMKind::Zisk => {
|
||||
if let Some(cuda_arch) = cuda_arch() {
|
||||
cmd = cmd.build_arg("CUDA_ARCH", cuda_arch)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
if let Some(ref cuda_archs) = cuda_archs {
|
||||
cmd = apply_cuda_build_args(cmd, zkvm_kind, cuda_archs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,6 +161,10 @@ fn build_server_image(zkvm_kind: zkVMKind, gpu: bool) -> Result<(), Error> {
|
||||
|
||||
if gpu {
|
||||
cmd = cmd.build_arg("CUDA", "1");
|
||||
|
||||
if let Some(ref cuda_archs) = cuda_archs {
|
||||
cmd = apply_cuda_build_args(cmd, zkvm_kind, cuda_archs);
|
||||
}
|
||||
}
|
||||
|
||||
cmd.exec(&workspace_dir)?;
|
||||
|
||||
Reference in New Issue
Block a user