mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BUILD] re-download package if version has changed (#1294)
This commit is contained in:
@@ -9,6 +9,7 @@ import tarfile
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from distutils.version import LooseVersion
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
from setuptools import Extension, setup
|
||||
@@ -38,7 +39,6 @@ class Package(NamedTuple):
|
||||
package: str
|
||||
name: str
|
||||
url: str
|
||||
test_file: str
|
||||
include_flag: str
|
||||
lib_flag: str
|
||||
syspath_var_name: str
|
||||
@@ -49,7 +49,7 @@ class Package(NamedTuple):
|
||||
def get_pybind11_package_info():
|
||||
name = "pybind11-2.10.0"
|
||||
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
|
||||
return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
|
||||
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
|
||||
|
||||
# llvm
|
||||
|
||||
@@ -65,12 +65,13 @@ def get_llvm_package_info():
|
||||
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
|
||||
system_suffix = f"linux-gnu-{linux_suffix}"
|
||||
else:
|
||||
return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-17.0.0-8e5a41e8271f/{name}.tar.xz"
|
||||
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
version = "llvm-17.0.0-8e5a41e8271f"
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
|
||||
def get_thirdparty_packages(triton_cache_path):
|
||||
@@ -81,8 +82,9 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
package_dir = os.path.join(package_root_dir, p.name)
|
||||
if p.syspath_var_name in os.environ:
|
||||
package_dir = os.environ[p.syspath_var_name]
|
||||
test_file_path = os.path.join(package_dir, p.test_file)
|
||||
if not os.path.exists(test_file_path):
|
||||
version_file_path = os.path.join(package_dir, "version.txt")
|
||||
if not os.path.exists(version_file_path) or\
|
||||
Path(version_file_path).read_text() != p.url:
|
||||
try:
|
||||
shutil.rmtree(package_root_dir)
|
||||
except Exception:
|
||||
@@ -92,6 +94,9 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
ftpstream = urllib.request.urlopen(p.url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|*")
|
||||
file.extractall(path=package_root_dir)
|
||||
# write version url to package_root_dir
|
||||
with open(os.path.join(package_root_dir, "version.txt"), "w") as f:
|
||||
f.write(p.url)
|
||||
if p.include_flag:
|
||||
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
|
||||
if p.lib_flag:
|
||||
|
||||
Reference in New Issue
Block a user