Files
ROCm/python/test/unit/language/test_line_info.py
2023-09-13 21:21:01 +00:00

121 lines
3.7 KiB
Python

import subprocess
import tempfile
import pytest
import torch
import triton
import triton.language as tl
from triton.common.backend import path_to_nvdisasm
@triton.jit
def kernel_single(X,  # fixture for "single": one load and one store, no calls
                  Y,
                  BLOCK: tl.constexpr):  # test_line_info asserts line info of the two statements
    x = tl.load(X + tl.arange(0, BLOCK))  # reads BLOCK consecutive elements starting at X
    tl.store(Y + tl.arange(0, BLOCK), x)  # writes them unchanged starting at Y
@triton.jit
def device_inline(x):  # fixture: helper called by kernel_call, jitted without noinline
    return x + x  # test_line_info expects this line in kernel_call's line info
@triton.jit
def kernel_call(X,  # fixture for "call": load, call device_inline, store
                Y,
                BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    y = device_inline(x)  # y = x + x computed by the jitted helper
    tl.store(Y + tl.arange(0, BLOCK), y)
@triton.jit(noinline=True)
def device_noinline(X, Y, BLOCK: tl.constexpr):  # fixture: helper jitted with noinline=True
    x = tl.load(X + tl.arange(0, BLOCK))
    y = x + x
    tl.store(Y + tl.arange(0, BLOCK), y)  # Y receives 2 * X for the first BLOCK elements
@triton.jit
def kernel_call_noinline(X, Y, BLOCK: tl.constexpr):  # fixture for "call_noinline"
    device_noinline(X, Y, BLOCK)  # all work happens in the noinline helper
@triton.jit
def kernel_multi_files(X, Y, BLOCK: tl.constexpr):  # fixture for "multi_files"
    x = tl.load(X + tl.arange(0, BLOCK))
    y = tl.softmax(x)  # test_line_info expects standard.py lines from this call as well
    tl.store(Y + tl.arange(0, BLOCK), y)
def extract_file_lines(asm):
    """Disassemble a cubin with ``nvdisasm -g`` and collect its line-info records.

    Args:
        asm: Raw cubin bytes (the caller passes ``kernel.asm["cubin"]``).

    Returns:
        A list of ``(file, line)`` string pairs, one per disassembly line that
        contains ``## File``. ``file`` is the text from ``## File`` up to the
        first comma and ``line`` is the text between the first and second
        commas, both stripped.
    """
    import os

    nvdisasm, _ = path_to_nvdisasm()
    # nvdisasm reads its input from a path, so stage the cubin in a temporary
    # directory that is deleted afterwards (mkstemp alone leaked one file per call).
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "kernel.cubin")
        with open(path, "wb") as cubin:
            cubin.write(asm)
        disasm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8")

    file_lines = []
    for line in disasm.splitlines():
        if "## File" in line:
            entries = line[line.index("## File"):].split(",")
            file_lines.append((entries[0].strip(), entries[1].strip()))
    return file_lines
def check_file_lines(file_lines, file_name, lineno):
    """Return True if some line-info record matches ``file_name`` and ``lineno``.

    Args:
        file_lines: ``(file, line)`` string pairs, as produced by
            ``extract_file_lines``.
        file_name: Substring that must occur in the record's ``file`` entry.
        lineno: Exact line number to match, or -1 to accept any line.

    Returns:
        True on the first matching record, False if none matches.
    """
    import re

    for file_entry, line_entry in file_lines:
        if file_name not in file_entry:
            continue
        # -1 means do not check line number
        if lineno == -1:
            return True
        # Compare the parsed number exactly: a plain substring test would let
        # e.g. 3 match "line 33" or 16 match "line 116".
        match = re.search(r"\d+", line_entry)
        if match is not None and int(match.group()) == lineno:
            return True
    return False
# Fixture variants exercised by test_line_info, one per entry kernel in this file.
func_types = [
    "single",
    "call",
    "call_noinline",
    "multi_files",
]
def _source_line(jit_fn, snippet):
    """Return the file line number of the first line in ``jit_fn`` starting with ``snippet``.

    Locating statements by content instead of hard-coding numbers keeps
    test_line_info valid when lines are added or removed above the fixtures.

    Args:
        jit_fn: A ``@triton.jit`` function. Its wrapped Python function is read
            from the ``fn`` attribute when present.
        snippet: Leading text of the statement, ignoring indentation and any
            trailing ``#`` comment.

    Raises:
        LookupError: If no source line of the function starts with ``snippet``.
    """
    import inspect

    py_fn = getattr(jit_fn, "fn", jit_fn)
    source_lines, first_line = inspect.getsourcelines(py_fn)
    for offset, text in enumerate(source_lines):
        if text.split("#", 1)[0].strip().startswith(snippet):
            return first_line + offset
    raise LookupError(f"{snippet!r} not found in {py_fn.__name__}")


@pytest.mark.parametrize("func", func_types)
def test_line_info(func: str):
    """Check that a compiled fixture kernel's cubin carries source line info.

    The kernel selected by ``func`` is launched once on a 128-element tensor.
    Its cubin is disassembled with ``nvdisasm -g``, and specific statements of
    this file (plus, for ``multi_files``, lines of triton's standard.py) must
    appear among the line-info records. Skipped when nvdisasm cannot be found.
    """
    try:
        path_to_nvdisasm()
    except Exception:  # narrower than BaseException: never swallow Ctrl-C / SystemExit
        pytest.skip("nvdisasm is not available")

    kernels = {
        "single": kernel_single,
        "call": kernel_call,
        "call_noinline": kernel_call_noinline,
        "multi_files": kernel_multi_files,
    }
    # Statements whose source lines must show up in the debug info, given as
    # (jit function, leading text of the statement).
    expected_statements = {
        "single": [
            (kernel_single, "x = tl.load"),
            (kernel_single, "tl.store"),
        ],
        "call": [
            (kernel_call, "x = tl.load"),
            (device_inline, "return x + x"),
            (kernel_call, "tl.store"),
        ],
        "call_noinline": [
            (kernel_call_noinline, "device_noinline("),
            (device_noinline, "x = tl.load"),
            (device_noinline, "y = x + x"),
            (device_noinline, "tl.store"),
        ],
        "multi_files": [
            (kernel_multi_files, "x = tl.load"),
            (kernel_multi_files, "tl.store"),
        ],
    }

    shape = (128, )
    x = torch.arange(0, shape[0], dtype=torch.float32, device='cuda')
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel_info = kernels[func][(1, )](x, y, BLOCK=shape[0])
    file_lines = extract_file_lines(kernel_info.asm["cubin"])

    for jit_fn, snippet in expected_statements[func]:
        lineno = _source_line(jit_fn, snippet)
        assert check_file_lines(file_lines, "test_line_info.py", lineno), (
            f"no line info for {snippet!r} at test_line_info.py:{lineno}")
    if func == "multi_files":
        # Lines in triton's standard.py reached through tl.softmax. These are
        # pinned to that file's layout and must be updated together with it.
        for lineno in (33, 34, 36):
            assert check_file_lines(file_lines, "standard.py", lineno), (
                f"no line info for standard.py:{lineno}")