mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Added SASS to asm dict (#2280)
This commit is contained in:
@@ -125,15 +125,15 @@ def get_thirdparty_packages(triton_cache_path):
|
||||
# ---- package data ---
|
||||
|
||||
|
||||
def download_and_copy_ptxas():
|
||||
|
||||
def download_and_copy(src_path, version, url_func):
|
||||
base_dir = os.path.dirname(__file__)
|
||||
src_path = "bin/ptxas"
|
||||
version = "12.1.105"
|
||||
# src_path = "bin/ptxas"
|
||||
# version = "12.1.105"
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
arch = "64"
|
||||
url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
|
||||
url = url_func(arch, version)
|
||||
# url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2"
|
||||
dst_prefix = os.path.join(base_dir, "triton")
|
||||
dst_suffix = os.path.join("third_party", "cuda", src_path)
|
||||
dst_path = os.path.join(dst_prefix, dst_suffix)
|
||||
@@ -156,9 +156,9 @@ def download_and_copy_ptxas():
|
||||
shutil.copy(src_path, dst_path)
|
||||
return dst_suffix
|
||||
|
||||
|
||||
# ---- cmake extension ----
|
||||
|
||||
|
||||
def get_base_dir():
    """Return the absolute path of the directory one level above this file's directory."""
    here = os.path.dirname(__file__)
    parent = os.path.join(here, os.pardir)
    return os.path.abspath(parent)
|
||||
|
||||
@@ -280,8 +280,9 @@ class CMakeBuild(build_ext):
|
||||
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
|
||||
|
||||
|
||||
download_and_copy_ptxas()
|
||||
|
||||
download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")
|
||||
download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2")
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import path_to_nvdisasm
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -50,10 +51,11 @@ def kernel_multi_files(X, Y, BLOCK: tl.constexpr):
|
||||
|
||||
|
||||
def extract_file_lines(asm):
|
||||
nvdisasm, _ = path_to_nvdisasm()
|
||||
fd, path = tempfile.mkstemp()
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(asm)
|
||||
asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8")
|
||||
asm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8")
|
||||
file_lines = []
|
||||
lines = asm.splitlines()
|
||||
for line in lines:
|
||||
@@ -80,7 +82,7 @@ func_types = ["single", "call", "call_noinline", "multi_files"]
|
||||
@pytest.mark.parametrize("func", func_types)
|
||||
def test_line_info(func: str):
|
||||
try:
|
||||
subprocess.check_output(["nvdisasm", "-h"])
|
||||
_, _ = path_to_nvdisasm()
|
||||
except BaseException:
|
||||
pytest.skip("nvdisasm is not available")
|
||||
|
||||
@@ -99,20 +101,20 @@ def test_line_info(func: str):
|
||||
|
||||
file_lines = extract_file_lines(kernel_info.asm["cubin"])
|
||||
if func == "single":
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 15))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 16))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 17))
|
||||
elif func == "call":
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 28))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 21))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 30))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 29))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 22))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 31))
|
||||
elif func == "call_noinline":
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 42))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 35))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 43))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 36))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 37))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 38))
|
||||
elif func == "multi_files":
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 47))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 49))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 48))
|
||||
assert (check_file_lines(file_lines, "test_line_info.py", 50))
|
||||
assert (check_file_lines(file_lines, "standard.py", 33))
|
||||
assert (check_file_lines(file_lines, "standard.py", 34))
|
||||
assert (check_file_lines(file_lines, "standard.py", 36))
|
||||
|
||||
@@ -101,20 +101,34 @@ def get_backend(device_type: str):
|
||||
return _backends[device_type] if device_type in _backends else None
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def path_to_ptxas():
|
||||
def _path_to_binary(binary: str):
|
||||
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
|
||||
paths = [
|
||||
os.environ.get("TRITON_PTXAS_PATH", ""),
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas")
|
||||
os.path.join(base_dir, "third_party", "cuda", "bin", binary)
|
||||
]
|
||||
|
||||
for ptxas in paths:
|
||||
ptxas_bin = ptxas.split(" ")[0]
|
||||
if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin):
|
||||
result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT)
|
||||
for p in paths:
|
||||
bin = p.split(" ")[0]
|
||||
if os.path.exists(bin) and os.path.isfile(bin):
|
||||
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
|
||||
if result is not None:
|
||||
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
||||
if version is not None:
|
||||
return ptxas, version.group(1)
|
||||
raise RuntimeError("Cannot find ptxas")
|
||||
return p, version.group(1)
|
||||
raise RuntimeError(f"Cannot find {binary}")
|
||||
|
||||
|
||||
@functools.lru_cache()
def path_to_ptxas():
    """Return the ``(path, version)`` pair found for the ``ptxas`` binary.

    The search is delegated to ``_path_to_binary``; a successful result is
    memoized by ``functools.lru_cache``.
    """
    ptxas_info = _path_to_binary("ptxas")
    return ptxas_info
|
||||
|
||||
|
||||
@functools.lru_cache()
def path_to_cuobjdump():
    """Return the ``(path, version)`` pair found for the ``cuobjdump`` binary.

    The search is delegated to ``_path_to_binary``; a successful result is
    memoized by ``functools.lru_cache``.
    """
    cuobjdump_info = _path_to_binary("cuobjdump")
    return cuobjdump_info
|
||||
|
||||
|
||||
@functools.lru_cache()
def path_to_nvdisasm():
    """Return the ``(path, version)`` pair found for the ``nvdisasm`` binary.

    The search is delegated to ``_path_to_binary``; a successful result is
    memoized by ``functools.lru_cache``.
    """
    nvdisasm_info = _path_to_binary("nvdisasm")
    return nvdisasm_info
|
||||
|
||||
@@ -5,7 +5,6 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -24,7 +23,7 @@ from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_ma
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..tools.disasm import extract
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
|
||||
@@ -500,7 +499,6 @@ def compile(fn, **kwargs):
|
||||
metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
|
||||
else:
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_dump_manager.put(next_module, ir_filename)
|
||||
if (enable_override and fn_override_manager.has_file(ir_filename)):
|
||||
print(f"\nOverriding kernel with file {ir_filename}")
|
||||
@@ -517,6 +515,11 @@ def compile(fn, **kwargs):
|
||||
|
||||
if ir_name == "cubin":
|
||||
asm[ir_name] = next_module
|
||||
sass_ir = "sass"
|
||||
sass_fname = f"{name}.{sass_ir}"
|
||||
asm[sass_ir] = get_sass(next_module)
|
||||
metadata_group[sass_fname] = fn_cache_manager.put(asm[sass_ir], sass_fname)
|
||||
|
||||
elif ir_name == "amdgcn":
|
||||
asm[ir_name] = str(next_module[0])
|
||||
else:
|
||||
@@ -669,16 +672,3 @@ class CompiledKernel:
|
||||
self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
|
||||
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
|
||||
return runner
|
||||
|
||||
def get_sass(self, fun=None):
|
||||
if 'sass' in self.asm:
|
||||
return self.asm['sass']
|
||||
fd, path = tempfile.mkstemp()
|
||||
try:
|
||||
with open(fd, 'wb') as cubin:
|
||||
cubin.write(self.asm['cubin'])
|
||||
self.sass = extract(path, fun)
|
||||
finally:
|
||||
os.remove(path)
|
||||
self.asm['sass'] = self.sass
|
||||
return self.sass
|
||||
|
||||
@@ -20,8 +20,12 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from ..common.backend import path_to_cuobjdump, path_to_nvdisasm
|
||||
|
||||
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
|
||||
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
|
||||
@@ -60,11 +64,25 @@ def processSassLines(fline, sline, labels):
|
||||
return (f'{ctrl}', f'{asm}')
|
||||
|
||||
|
||||
def get_sass(cubin_asm, fun=None):
    """Disassemble an in-memory cubin by passing it through ``extract``.

    The bytes in ``cubin_asm`` are written to a temporary file, because
    ``extract`` takes a file path. The file is removed afterwards, whether or
    not extraction succeeds.

    Args:
        cubin_asm: Binary contents of the cubin.
        fun: Forwarded unchanged to ``extract``; ``None`` by default.

    Returns:
        Whatever ``extract`` returns for the temporary file.
    """
    handle, cubin_path = tempfile.mkstemp()
    try:
        # Opening by descriptor means the handle from mkstemp is closed
        # when the with-block exits.
        with open(handle, 'wb') as cubin_file:
            cubin_file.write(cubin_asm)
        return extract(cubin_path, fun)
    finally:
        os.remove(cubin_path)
|
||||
|
||||
|
||||
def extract(file_path, fun):
|
||||
cuobjdump, _ = path_to_cuobjdump()
|
||||
nvdisasm, _ = path_to_nvdisasm()
|
||||
os.environ["NVDISASM_PATH"] = nvdisasm
|
||||
if fun is None:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path])
|
||||
sass_str = subprocess.check_output([cuobjdump, "-sass", file_path])
|
||||
else:
|
||||
sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path])
|
||||
sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path])
|
||||
sass_lines = sass_str.splitlines()
|
||||
line_idx = 0
|
||||
while line_idx < len(sass_lines):
|
||||
|
||||
Reference in New Issue
Block a user