mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Linux arm64 support (#2003)
We are interested in having python wheels for triton built for Linux arm64 platforms, such as NVIDIA's Grace CPU. This change is fairly simple, however: - It requires a linux arm64 build of LLVM to be available (see MR here: https://github.com/ptillet/triton-llvm-releases/pull/15) - For now my changes use the LLVM build hosted here: https://github.com/acollins3/triton-llvm-releases/releases/tag/llvm-17.0.0-c5dede880d17 - The Triton release process will need to be updated to include arm64 wheels. Is this something you have time to work on @ptillet? It would be difficult for me to update this part without more access permissions. With these changes, I managed to build a set of python wheels and have hosted them here for us to use in the meantime: https://github.com/acollins3/triton/releases/tag/triton-2.1.0-arm64
This commit is contained in:
@@ -68,7 +68,9 @@ def get_pybind11_package_info():
|
||||
def get_llvm_package_info():
|
||||
# added statement for Apple Silicon
|
||||
system = platform.system()
|
||||
arch = 'x86_64'
|
||||
arch = platform.machine()
|
||||
if arch == 'aarch64':
|
||||
arch = 'arm64'
|
||||
if system == "Darwin":
|
||||
system_suffix = "apple-darwin"
|
||||
arch = platform.machine()
|
||||
@@ -84,6 +86,9 @@ def get_llvm_package_info():
|
||||
name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
|
||||
version = "llvm-17.0.0-c5dede880d17"
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
# FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
|
||||
if arch == 'arm64' and 'linux' in system_suffix:
|
||||
url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
|
||||
@@ -124,7 +129,10 @@ def download_and_copy_ptxas():
|
||||
base_dir = os.path.dirname(__file__)
|
||||
src_path = "bin/ptxas"
|
||||
version = "12.1.105"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-64/cuda-nvcc-{version}-0.tar.bz2"
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
arch = "64"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
|
||||
dst_prefix = os.path.join(base_dir, "triton")
|
||||
dst_suffix = os.path.join("third_party", "cuda", src_path)
|
||||
dst_path = os.path.join(dst_prefix, dst_suffix)
|
||||
|
||||
Reference in New Issue
Block a user