Linux arm64 support (#2003)

We are interested in having Python wheels for Triton built for Linux
arm64 platforms, such as NVIDIA's Grace CPU.

The change itself is fairly simple, with a few caveats:
- It requires a Linux arm64 build of LLVM to be available (see the PR here:
https://github.com/ptillet/triton-llvm-releases/pull/15)
- For now my changes use the LLVM build hosted here:
https://github.com/acollins3/triton-llvm-releases/releases/tag/llvm-17.0.0-c5dede880d17
- The Triton release process will need to be updated to include arm64
wheels. Is this something you have time to work on @ptillet? It would be
difficult for me to update this part without more access permissions.

With these changes, I managed to build a set of Python wheels and have
hosted them here for us to use in the meantime:
https://github.com/acollins3/triton/releases/tag/triton-2.1.0-arm64
Alex Collins
2023-08-08 05:39:41 +01:00
committed by GitHub
parent 3cec89ebb3
commit 4ed8381fdb

@@ -68,7 +68,9 @@ def get_pybind11_package_info():
 def get_llvm_package_info():
     # added statement for Apple Silicon
     system = platform.system()
-    arch = 'x86_64'
+    arch = platform.machine()
+    if arch == 'aarch64':
+        arch = 'arm64'
     if system == "Darwin":
         system_suffix = "apple-darwin"
         arch = platform.machine()
@@ -84,6 +86,9 @@ def get_llvm_package_info():
     name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
     version = "llvm-17.0.0-c5dede880d17"
     url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
+    # FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
+    if arch == 'arm64' and 'linux' in system_suffix:
+        url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
     return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
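Note on the LLVM hunks: the effect of the two changes can be sketched in isolation as below. This is only an illustration of the logic in the diff, not the code in setup.py. The helper names `normalize_llvm_arch` and `llvm_download_url` are hypothetical, and the `system_suffix` / `release_suffix` values in the usage lines are placeholder assumptions (the diff does not show how they are computed).

```python
import platform


def normalize_llvm_arch(machine: str) -> str:
    """Map a platform.machine() value to the arch token used in the LLVM archive name.

    Linux reports 64-bit Arm as 'aarch64'; the archive names in the diff use 'arm64'.
    """
    return "arm64" if machine == "aarch64" else machine


def llvm_download_url(arch: str, system_suffix: str, release_suffix: str,
                      version: str = "llvm-17.0.0-c5dede880d17") -> str:
    """Build the archive URL, using the temporary fork only for Linux arm64."""
    name = f"llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}"
    # Temporary override from the diff: Linux arm64 archives live in a fork for now.
    owner = "acollins3" if arch == "arm64" and "linux" in system_suffix else "ptillet"
    return (f"https://github.com/{owner}/triton-llvm-releases/"
            f"releases/download/{version}/{name}.tar.xz")


if __name__ == "__main__":
    arch = normalize_llvm_arch(platform.machine())
    # Placeholder suffixes, assumed for illustration only.
    print(llvm_download_url(arch, "linux-gnu-ubuntu-18.04", "release"))
```

Keeping the fork override as a single conditional makes it easy to delete once the upstream releases include arm64 archives, which is what the FIXME asks for.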
@@ -124,7 +129,10 @@ def download_and_copy_ptxas():
     base_dir = os.path.dirname(__file__)
     src_path = "bin/ptxas"
     version = "12.1.105"
-    url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-64/cuda-nvcc-{version}-0.tar.bz2"
+    arch = platform.machine()
+    if arch == "x86_64":
+        arch = "64"
+    url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
     dst_prefix = os.path.join(base_dir, "triton")
     dst_suffix = os.path.join("third_party", "cuda", src_path)
     dst_path = os.path.join(dst_prefix, dst_suffix)
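
Note on the ptxas hunk: the mapping here runs in the opposite direction from the LLVM one. The conda channel directory is `linux-64` for x86_64, while other machine names are passed through unchanged, so an aarch64 host resolves to `linux-aarch64`. A minimal sketch of that URL construction follows. The function name `ptxas_conda_url` is hypothetical, and the sketch does not check that a package actually exists for a given architecture.

```python
import platform


def ptxas_conda_url(machine: str, version: str = "12.1.105") -> str:
    """Build the cuda-nvcc package URL for a given platform.machine() value."""
    # x86_64 is published under 'linux-64'; any other name is used verbatim.
    arch = "64" if machine == "x86_64" else machine
    return ("https://conda.anaconda.org/nvidia/label/cuda-12.1.1/"
            f"linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")


if __name__ == "__main__":
    print(ptxas_conda_url(platform.machine()))
```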