mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again
Conflicts: .gitignore .gitmodules README.md bin/triton-translate.cpp include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Target/AMDGCN/AMDGCNTranslation.h include/triton/Target/HSACO/HSACOTranslation.h lib/Analysis/Allocation.cpp lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/CMakeLists.txt lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/Utility.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/HSACO/CMakeLists.txt lib/Target/HSACO/HSACOTranslation.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/language/test_core.py python/test/unit/operators/test_flash_attention.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -8,6 +8,7 @@ import sysconfig
|
||||
import tarfile
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from distutils.command.clean import clean
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
@@ -58,8 +59,8 @@ class Package(NamedTuple):
|
||||
|
||||
|
||||
def get_pybind11_package_info():
|
||||
name = "pybind11-2.10.0"
|
||||
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
|
||||
name = "pybind11-2.11.1"
|
||||
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
|
||||
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
|
||||
|
||||
# llvm
|
||||
@@ -124,15 +125,12 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
# ---- package data ---
|
||||
|
||||
|
||||
def download_and_copy_ptxas():
|
||||
|
||||
def download_and_copy(src_path, version, url_func):
|
||||
base_dir = os.path.dirname(__file__)
|
||||
src_path = "bin/ptxas"
|
||||
version = "12.1.105"
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
arch = "64"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
|
||||
url = url_func(arch, version)
|
||||
dst_prefix = os.path.join(base_dir, "triton")
|
||||
dst_suffix = os.path.join("third_party", "cuda", src_path)
|
||||
dst_path = os.path.join(dst_prefix, dst_suffix)
|
||||
@@ -155,9 +153,28 @@ def download_and_copy_ptxas():
|
||||
shutil.copy(src_path, dst_path)
|
||||
return dst_suffix
|
||||
|
||||
|
||||
# ---- cmake extension ----
|
||||
|
||||
|
||||
def get_base_dir():
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
||||
|
||||
|
||||
def get_cmake_dir():
|
||||
plat_name = sysconfig.get_platform()
|
||||
python_version = sysconfig.get_python_version()
|
||||
dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
|
||||
cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name
|
||||
cmake_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cmake_dir
|
||||
|
||||
|
||||
class CMakeClean(clean):
|
||||
def initialize_options(self):
|
||||
clean.initialize_options(self)
|
||||
self.build_temp = get_cmake_dir()
|
||||
|
||||
|
||||
class CMakeBuildPy(build_py):
|
||||
def run(self) -> None:
|
||||
self.run_command('build_ext')
|
||||
@@ -178,10 +195,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def initialize_options(self):
|
||||
build_ext.initialize_options(self)
|
||||
self.base_dir = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
os.pardir))
|
||||
self.base_dir = get_base_dir()
|
||||
|
||||
def finalize_options(self):
|
||||
build_ext.finalize_options(self)
|
||||
@@ -200,14 +214,6 @@ class CMakeBuild(build_ext):
|
||||
for ext in self.extensions:
|
||||
self.build_extension(ext)
|
||||
|
||||
def get_cmake_dir(self):
|
||||
plat_name = sysconfig.get_platform()
|
||||
python_version = sysconfig.get_python_version()
|
||||
dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}"
|
||||
cmake_dir = Path(self.base_dir) / "python" / "build" / dir_name
|
||||
cmake_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cmake_dir
|
||||
|
||||
def build_extension(self, ext):
|
||||
lit_dir = shutil.which('lit')
|
||||
ninja_dir = shutil.which('ninja')
|
||||
@@ -267,14 +273,21 @@ class CMakeBuild(build_ext):
|
||||
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
|
||||
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"]
|
||||
|
||||
if check_env_flag("TRITON_BUILD_WITH_CCACHE"):
|
||||
cmake_args += [
|
||||
"-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
cmake_dir = self.get_cmake_dir()
|
||||
cmake_dir = get_cmake_dir()
|
||||
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
|
||||
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
|
||||
|
||||
|
||||
download_and_copy_ptxas()
|
||||
|
||||
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
@@ -288,7 +301,6 @@ setup(
|
||||
"triton/_C",
|
||||
"triton/common",
|
||||
"triton/compiler",
|
||||
"triton/interpreter",
|
||||
"triton/language",
|
||||
"triton/language/extra",
|
||||
"triton/ops",
|
||||
@@ -304,7 +316,7 @@ setup(
|
||||
],
|
||||
include_package_data=True,
|
||||
ext_modules=[CMakeExtension("triton", "triton/_C/")],
|
||||
cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy},
|
||||
cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean},
|
||||
zip_safe=False,
|
||||
# for PyPI
|
||||
keywords=["Compiler", "Deep Learning"],
|
||||
|
||||
Reference in New Issue
Block a user