mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BUILD] make cuda tools vendoring optional (#2546)
This commit is contained in:
@@ -124,7 +124,9 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
# ---- package data ---
|
||||
|
||||
|
||||
def download_and_copy(src_path, version, url_func):
|
||||
def download_and_copy(src_path, variable, version, url_func):
|
||||
if variable in os.environ:
|
||||
return
|
||||
base_dir = os.path.dirname(__file__)
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
@@ -150,7 +152,6 @@ def download_and_copy(src_path, version, url_func):
|
||||
src_path = os.path.join(temp_dir, src_path)
|
||||
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
|
||||
shutil.copy(src_path, dst_path)
|
||||
return dst_suffix
|
||||
|
||||
# ---- cmake extension ----
|
||||
|
||||
@@ -298,9 +299,24 @@ class CMakeBuild(build_ext):
|
||||
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
|
||||
|
||||
|
||||
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
|
||||
download_and_copy(
|
||||
src_path="bin/ptxas",
|
||||
variable="TRITON_PTXAS_PATH",
|
||||
version="12.1.105",
|
||||
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
|
||||
)
|
||||
download_and_copy(
|
||||
src_path="bin/cuobjdump",
|
||||
variable="TRITON_CUOBJDUMP_PATH",
|
||||
version="12.1.111",
|
||||
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
|
||||
)
|
||||
download_and_copy(
|
||||
src_path="bin/nvdisasm",
|
||||
variable="TRITON_NVDISASM_PATH",
|
||||
version="12.1.105",
|
||||
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
|
||||
)
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
|
||||
@@ -108,7 +108,7 @@ def get_backend(device_type: str):
|
||||
def _path_to_binary(binary: str):
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
|
||||
]
|
||||
|
||||
@@ -174,6 +174,10 @@ def get_cuda_version_key():
|
||||
global _cached_cuda_version_key
|
||||
if _cached_cuda_version_key is None:
|
||||
key = compute_core_version_key()
|
||||
ptxas = path_to_ptxas()[0]
|
||||
_cached_cuda_version_key = key + '-' + hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
try:
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = subprocess.check_output([ptxas, "--version"])
|
||||
except RuntimeError:
|
||||
ptxas_version = b"NO_PTXAS"
|
||||
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
|
||||
return _cached_cuda_version_key
|
||||
|
||||
Reference in New Issue
Block a user