mirror of
https://github.com/eth-act/ere.git
synced 2026-02-19 11:54:42 -05:00
feat: read env CUDA_ARCHS instead of CUDA_ARCH to accept comma-separated compute cap (e.g. CUDA_ARCHS=89,90,120)
This commit is contained in:
51
.github/scripts/build-image.sh
vendored
51
.github/scripts/build-image.sh
vendored
@@ -11,11 +11,11 @@ BUILD_COMPILER=false
|
||||
BUILD_SERVER=false
|
||||
BUILD_CLUSTER=false
|
||||
CUDA=false
|
||||
CUDA_ARCH=""
|
||||
CUDA_ARCHS=""
|
||||
RUSTFLAGS=""
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 --zkvm <zkvm> --tag <tag> [--base] [--compiler] [--server] [--cluster] [--registry <registry>] [--cuda] [--cuda-arch <arch>] [--rustflags <flags>]"
|
||||
echo "Usage: $0 --zkvm <zkvm> --tag <tag> [--base] [--compiler] [--server] [--cluster] [--registry <registry>] [--cuda] [--cuda-archs <archs>] [--rustflags <flags>]"
|
||||
echo ""
|
||||
echo "Required:"
|
||||
echo " --zkvm <zkvm> zkVM to build for (e.g., zisk, sp1, risc0)"
|
||||
@@ -30,7 +30,7 @@ usage() {
|
||||
echo "Optional:"
|
||||
echo " --registry <reg> Registry prefix (e.g., ghcr.io/eth-act/ere)"
|
||||
echo " --cuda Enable CUDA support (appends -cuda to tag)"
|
||||
echo " --cuda-arch <arch> Set CUDA architecture (e.g., sm_120)"
|
||||
echo " --cuda-archs <archs> Set CUDA architectures (comma-separated, e.g., 89,120). Implies --cuda."
|
||||
echo " --rustflags <flags> Pass RUSTFLAGS to build"
|
||||
exit 1
|
||||
}
|
||||
@@ -70,8 +70,9 @@ while [[ $# -gt 0 ]]; do
|
||||
CUDA=true
|
||||
shift
|
||||
;;
|
||||
--cuda-arch)
|
||||
CUDA_ARCH="$2"
|
||||
--cuda-archs)
|
||||
CUDA_ARCHS="$2"
|
||||
CUDA=true
|
||||
shift 2
|
||||
;;
|
||||
--rustflags)
|
||||
@@ -141,10 +142,42 @@ if [ "$CUDA" = true ]; then
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA=1")
|
||||
fi
|
||||
|
||||
if [ -n "$CUDA_ARCH" ]; then
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCH")
|
||||
# Default CUDA_ARCHS when --cuda is set but --cuda-archs not specified
|
||||
if [ "$CUDA" = true ] && [ -z "$CUDA_ARCHS" ]; then
|
||||
CUDA_ARCHS="89,120"
|
||||
fi
|
||||
|
||||
# Per-zkVM CUDA architecture translation
|
||||
if [ "$CUDA" = true ] && [ -n "$CUDA_ARCHS" ]; then
|
||||
case "$ZKVM" in
|
||||
airbender)
|
||||
CUDAARCHS=$(echo "$CUDA_ARCHS" | tr ',' ';')
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDAARCHS=$CUDAARCHS")
|
||||
;;
|
||||
openvm)
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=$CUDA_ARCHS")
|
||||
;;
|
||||
risc0)
|
||||
NVCC_APPEND_FLAGS=""
|
||||
IFS=',' read -ra ARCH_ARRAY <<< "$CUDA_ARCHS"
|
||||
for arch in "${ARCH_ARRAY[@]}"; do
|
||||
NVCC_APPEND_FLAGS+=" --generate-code arch=compute_${arch},code=sm_${arch}"
|
||||
done
|
||||
NVCC_APPEND_FLAGS="${NVCC_APPEND_FLAGS# }"
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "NVCC_APPEND_FLAGS=$NVCC_APPEND_FLAGS")
|
||||
;;
|
||||
zisk)
|
||||
MAX_CUDA_ARCH=$(echo "$CUDA_ARCHS" | tr ',' '\n' | sort -n | tail -1)
|
||||
BASE_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
SERVER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
CLUSTER_ZKVM_BUILD_ARGS+=(--build-arg "CUDA_ARCH=sm_${MAX_CUDA_ARCH}")
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
if [ -n "$RUSTFLAGS" ]; then
|
||||
|
||||
Reference in New Issue
Block a user