mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Improve printf. (#2532)
[BACKEND] Improve printf. Previously, we printed all of a GPU thread's values in a single printf() call, and this, plus the user-specified prefix, was all we printed. This caused a few problems. - nvptx printf can only handle 32 arguments; if you pass more than that, it prints garbage. So if a thread had more than 32 values, you couldn't print them, issue #2486. - The order of the values within the Triton program (GPU thread block) is an implementation detail -- it depends on the layout the compiler assigns to a tensor. So this also prevented you from interpreting the printed output. To address this, we now print the Triton pid and multi-dimensional Tensor index for each value. And each value gets its own line to avoid passing too many args to printf. Example output: ``` pid (0, 1, 2) idx (36, 127) x: 42 ``` If you want to observe all the values in a tensor in order, you can grep and then sort the output. We also make a UX enhancement to print: The printed label always ends with ": "; you don't have to add it yourself. Fixes #2486.
This commit is contained in:
@@ -11,21 +11,47 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_print("", x)
|
||||
tl.device_print("x: ", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
print("", x)
|
||||
# Triton should add a space after this prefix.
|
||||
print("x:", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
|
||||
# the static print is run every time.
|
||||
@triton.jit
|
||||
def kernel_device_print_large(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
|
||||
# Triton should change this prefix to "x: ".
|
||||
tl.device_print("x ", x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.full((BLOCK,), 1, tl.int32)
|
||||
print("", x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.full((BLOCK,), 1, tl.int32)
|
||||
tl.device_print("", x, y)
|
||||
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
|
||||
# This function takes an extra value as a tl.constexpr so this kernel is not
|
||||
# cached. This way the static print is run every time.
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
@@ -38,19 +64,27 @@ def kernel_no_arg_print():
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_print":
|
||||
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "print":
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_large":
|
||||
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
|
||||
elif func == "print_multiple_args":
|
||||
kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_multiple_args":
|
||||
kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
else:
|
||||
assert f"Unknown kernel: {func}"
|
||||
|
||||
if func != "no_arg_print":
|
||||
if func != "no_arg_print" and func != "device_print_large" and \
|
||||
func != "print_multiple_args" and func != "device_print_multiple_args":
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,26 +16,53 @@ nested_types = [(caller, callee) for caller in ["true", "false", "none"] for cal
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
# TODO: Print with multiple operands
|
||||
@pytest.mark.parametrize("func_type, data_type",
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
|
||||
[("device_print", data_type) for data_type in torch_types] + [
|
||||
("print", "int32"),
|
||||
("static_print", "int32"),
|
||||
("no_arg_print", "int32"),
|
||||
("device_print_large", "int32"),
|
||||
("print_multiple_args", "int32"),
|
||||
("device_print_multiple_args", "int32"),
|
||||
])
|
||||
def test_print(func_type: str, data_type: str):
|
||||
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
|
||||
outs, _ = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = line
|
||||
if func_type != "static_print":
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if func_type != "static_print" and func_type != "no_arg_print":
|
||||
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
|
||||
|
||||
# Format is
|
||||
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
|
||||
expected_lines = Counter()
|
||||
if func_type == "print" or func_type == "device_print":
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
else:
|
||||
assert len(new_lines) == 1
|
||||
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
|
||||
if data_type.startswith("float"):
|
||||
line += ".000000"
|
||||
expected_lines[line] = 1
|
||||
elif func_type == "static_print":
|
||||
expected_lines[" int32[constexpr[128]]"] = 1
|
||||
elif func_type == "no_arg_print":
|
||||
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
|
||||
elif func_type == "device_print_large":
|
||||
for i, j, k in itertools.product(range(2), range(64), range(128)):
|
||||
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
|
||||
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
|
||||
for i in range(128):
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
|
||||
|
||||
actual_lines = Counter()
|
||||
for line in outs:
|
||||
actual_lines[line] += 1
|
||||
|
||||
diff = Counter(actual_lines)
|
||||
diff.subtract(expected_lines)
|
||||
for line, delta in diff.items():
|
||||
if delta == 0:
|
||||
continue
|
||||
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
|
||||
assert all(delta == 0 for delta in diff.values())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type", assert_types)
|
||||
|
||||
@@ -1577,6 +1577,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
|
||||
|
||||
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
||||
# non-empty prefixes should start with " ".
|
||||
if not prefix.endswith(" "):
|
||||
prefix += " "
|
||||
if not prefix.endswith(": "):
|
||||
prefix = prefix[:-1] + ": "
|
||||
if len(prefix) > 2 and not prefix.startswith(" "):
|
||||
prefix = " " + prefix
|
||||
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_args.append(arg.handle)
|
||||
|
||||
Reference in New Issue
Block a user