[BUILD] make cuda tools vendoring optional (#2546)

This commit is contained in:
Someone
2023-10-27 06:16:41 +00:00
committed by GitHub
parent 0469d5fccd
commit cde42e6221
2 changed files with 28 additions and 8 deletions

View File

@@ -124,7 +124,9 @@ def get_thirdparty_packages(triton_cache_path):
# ---- package data ---
def download_and_copy(src_path, version, url_func):
def download_and_copy(src_path, variable, version, url_func):
if variable in os.environ:
return
base_dir = os.path.dirname(__file__)
arch = platform.machine()
if arch == "x86_64":
@@ -150,7 +152,6 @@ def download_and_copy(src_path, version, url_func):
src_path = os.path.join(temp_dir, src_path)
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
shutil.copy(src_path, dst_path)
return dst_suffix
# ---- cmake extension ----
@@ -298,9 +299,24 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.1.105",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.1.111",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.1.105",
url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
setup(
name="triton",

View File

@@ -108,7 +108,7 @@ def get_backend(device_type: str):
def _path_to_binary(binary: str):
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
]
@@ -174,6 +174,10 @@ def get_cuda_version_key():
global _cached_cuda_version_key
if _cached_cuda_version_key is None:
key = compute_core_version_key()
ptxas = path_to_ptxas()[0]
_cached_cuda_version_key = key + '-' + hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
try:
ptxas = path_to_ptxas()[0]
ptxas_version = subprocess.check_output([ptxas, "--version"])
except RuntimeError:
ptxas_version = b"NO_PTXAS"
_cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest()
return _cached_cuda_version_key