[FRONTEND] Added SASS to asm dict (#2280)

This commit is contained in:
Zahi Moudallal
2023-09-13 14:21:01 -07:00
committed by GitHub
parent a301502d25
commit 36087a108f
5 changed files with 70 additions and 45 deletions

View File

@@ -125,15 +125,15 @@ def get_thirdparty_packages(triton_cache_path):
# ---- package data ---
def download_and_copy_ptxas():
def download_and_copy(src_path, version, url_func):
base_dir = os.path.dirname(__file__)
src_path = "bin/ptxas"
version = "12.1.105"
# src_path = "bin/ptxas"
# version = "12.1.105"
arch = platform.machine()
if arch == "x86_64":
arch = "64"
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
url = url_func(arch, version)
# url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
dst_prefix = os.path.join(base_dir, "triton")
dst_suffix = os.path.join("third_party", "cuda", src_path)
dst_path = os.path.join(dst_prefix, dst_suffix)
@@ -156,9 +156,9 @@ def download_and_copy_ptxas():
shutil.copy(src_path, dst_path)
return dst_suffix
# ---- cmake extension ----
def get_base_dir():
return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
@@ -280,8 +280,9 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
download_and_copy_ptxas()
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
setup(
name="triton",

View File

@@ -6,6 +6,7 @@ import torch
import triton
import triton.language as tl
from triton.common.backend import path_to_nvdisasm
@triton.jit
@@ -50,10 +51,11 @@ def kernel_multi_files(X, Y, BLOCK: tl.constexpr):
def extract_file_lines(asm):
nvdisasm, _ = path_to_nvdisasm()
fd, path = tempfile.mkstemp()
with open(fd, 'wb') as cubin:
cubin.write(asm)
asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8")
asm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8")
file_lines = []
lines = asm.splitlines()
for line in lines:
@@ -80,7 +82,7 @@ func_types = ["single", "call", "call_noinline", "multi_files"]
@pytest.mark.parametrize("func", func_types)
def test_line_info(func: str):
try:
subprocess.check_output(["nvdisasm", "-h"])
_, _ = path_to_nvdisasm()
except BaseException:
pytest.skip("nvdisasm is not available")
@@ -99,20 +101,20 @@ def test_line_info(func: str):
file_lines = extract_file_lines(kernel_info.asm["cubin"])
if func == "single":
assert (check_file_lines(file_lines, "test_line_info.py", 15))
assert (check_file_lines(file_lines, "test_line_info.py", 16))
assert (check_file_lines(file_lines, "test_line_info.py", 17))
elif func == "call":
assert (check_file_lines(file_lines, "test_line_info.py", 28))
assert (check_file_lines(file_lines, "test_line_info.py", 21))
assert (check_file_lines(file_lines, "test_line_info.py", 30))
assert (check_file_lines(file_lines, "test_line_info.py", 29))
assert (check_file_lines(file_lines, "test_line_info.py", 22))
assert (check_file_lines(file_lines, "test_line_info.py", 31))
elif func == "call_noinline":
assert (check_file_lines(file_lines, "test_line_info.py", 42))
assert (check_file_lines(file_lines, "test_line_info.py", 35))
assert (check_file_lines(file_lines, "test_line_info.py", 43))
assert (check_file_lines(file_lines, "test_line_info.py", 36))
assert (check_file_lines(file_lines, "test_line_info.py", 37))
assert (check_file_lines(file_lines, "test_line_info.py", 38))
elif func == "multi_files":
assert (check_file_lines(file_lines, "test_line_info.py", 47))
assert (check_file_lines(file_lines, "test_line_info.py", 49))
assert (check_file_lines(file_lines, "test_line_info.py", 48))
assert (check_file_lines(file_lines, "test_line_info.py", 50))
assert (check_file_lines(file_lines, "standard.py", 33))
assert (check_file_lines(file_lines, "standard.py", 34))
assert (check_file_lines(file_lines, "standard.py", 36))

View File

@@ -101,20 +101,34 @@ def get_backend(device_type: str):
return _backends[device_type] if device_type in _backends else None
@functools.lru_cache()
def path_to_ptxas():
def _path_to_binary(binary: str):
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get("TRITON_PTXAS_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas")
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
]
for ptxas in paths:
ptxas_bin = ptxas.split(" ")[0]
if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin):
result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT)
for p in paths:
bin = p.split(" ")[0]
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return ptxas, version.group(1)
raise RuntimeError("Cannot find ptxas")
return p, version.group(1)
raise RuntimeError(f"Cannot find {binary}")
@functools.lru_cache()
def path_to_ptxas():
return _path_to_binary("ptxas")
@functools.lru_cache()
def path_to_cuobjdump():
return _path_to_binary("cuobjdump")
@functools.lru_cache()
def path_to_nvdisasm():
return _path_to_binary("nvdisasm")

View File

@@ -5,7 +5,6 @@ import hashlib
import json
import os
import re
import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Any
@@ -24,7 +23,7 @@ from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_ma
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
from ..tools.disasm import extract
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
@@ -500,7 +499,6 @@ def compile(fn, **kwargs):
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
else:
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
fn_cache_manager.put(next_module, ir_filename)
fn_dump_manager.put(next_module, ir_filename)
if (enable_override and fn_override_manager.has_file(ir_filename)):
print(f"\nOverriding kernel with file {ir_filename}")
@@ -517,6 +515,11 @@ def compile(fn, **kwargs):
if ir_name == "cubin":
asm[ir_name] = next_module
sass_ir = "sass"
sass_fname = f"{name}.{sass_ir}"
asm[sass_ir] = get_sass(next_module)
metadata_group[sass_fname] = fn_cache_manager.put(asm[sass_ir], sass_fname)
elif ir_name == "amdgcn":
asm[ir_name] = str(next_module[0])
else:
@@ -669,16 +672,3 @@ class CompiledKernel:
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
return runner
def get_sass(self, fun=None):
if 'sass' in self.asm:
return self.asm['sass']
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(self.asm['cubin'])
self.sass = extract(path, fun)
finally:
os.remove(path)
self.asm['sass'] = self.sass
return self.sass

View File

@@ -20,8 +20,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import re
import subprocess
import tempfile
from ..common.backend import path_to_cuobjdump, path_to_nvdisasm
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
@@ -60,11 +64,25 @@ def processSassLines(fline, sline, labels):
return (f'{ctrl}', f'{asm}')
def get_sass(cubin_asm, fun=None):
fd, path = tempfile.mkstemp()
try:
with open(fd, 'wb') as cubin:
cubin.write(cubin_asm)
sass = extract(path, fun)
finally:
os.remove(path)
return sass
def extract(file_path, fun):
cuobjdump, _ = path_to_cuobjdump()
nvdisasm, _ = path_to_nvdisasm()
os.environ["NVDISASM_PATH"] = nvdisasm
if fun is None:
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
else:
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path])
sass_lines = sass_str.splitlines()
line_idx = 0
while line_idx < len(sass_lines):