[BACKEND] Improve printf. (#2532)

[BACKEND] Improve printf.

Previously, we printed all of a GPU thread's values in a single printf()
call, and this, plus the user-specified prefix, was all we printed.

This caused a few problems.

 - nvptx printf can only handle 32 arguments; if you pass more than
   that, it prints garbage.  So if a thread had more than 32 values, you
   couldn't print them, issue #2486.

 - The order of the values within the Triton program (GPU thread block)
   is an implementation detail -- it depends on the layout the compiler
   assigns to a tensor.  So this also prevented you from interpreting
   the printed output.

To address this, we now print the Triton pid and multi-dimensional
Tensor index for each value.  And each value gets its own line to avoid
passing too many args to printf.

Example output:

    ```
    pid (0, 1, 2) idx (36, 127) x: 42
    ```

If you want to observe all the values in a tensor in order, you can grep
and then sort the output.

We also make a UX enhancement to print: The printed label always ends
with ": "; you don't have to add it yourself.

Fixes #2486.
This commit is contained in:
Justin Lebar
2023-10-25 01:47:55 -07:00
committed by GitHub
parent 2217bd2f5c
commit e70e11e834
7 changed files with 298 additions and 71 deletions

View File

@@ -11,21 +11,47 @@ import triton.language as tl
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_print("", x)
tl.device_print("x: ", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
print("", x)
# Triton should add a space after this prefix.
print("x:", x)
tl.store(Y + tl.arange(0, BLOCK), x)
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
# the static print is run every time.
@triton.jit
def kernel_device_print_large(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
# Triton should change this prefix to "x: ".
tl.device_print("x ", x)
@triton.jit
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.full((BLOCK,), 1, tl.int32)
print("", x, y)
@triton.jit
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.full((BLOCK,), 1, tl.int32)
tl.device_print("", x, y)
tl.store(Y + tl.arange(0, BLOCK), y)
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
# This function takes an extra value as a tl.constexpr so this kernel is not
# cached. This way the static print is run every time.
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@@ -38,19 +64,27 @@ def kernel_no_arg_print():
def test_print(func: str, data_type: str):
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_print":
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
elif func == "print":
kernel_print[(1,)](x, y, BLOCK=shape[0])
elif func == "device_print_large":
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
elif func == "print_multiple_args":
kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
elif func == "device_print_multiple_args":
kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)
else:
assert f"Unknown kernel: {func}"
if func != "no_arg_print":
if func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args":
assert_close(y, x)

View File

@@ -1,6 +1,8 @@
import itertools
import os
import subprocess
import sys
from collections import Counter
import pytest
@@ -14,26 +16,53 @@ nested_types = [(caller, callee) for caller in ["true", "false", "none"] for cal
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
# TODO: Print with multiple operands
@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
[("device_print", data_type) for data_type in torch_types] + [
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = line
if func_type != "static_print":
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print" and func_type != "no_arg_print":
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
# Format is
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
expected_lines = Counter()
if func_type == "print" or func_type == "device_print":
for i in range(128):
assert i in new_lines
else:
assert len(new_lines) == 1
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = 1
elif func_type == "static_print":
expected_lines[" int32[constexpr[128]]"] = 1
elif func_type == "no_arg_print":
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
elif func_type == "device_print_large":
for i, j, k in itertools.product(range(2), range(64), range(128)):
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
for i in range(128):
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
actual_lines = Counter()
for line in outs:
actual_lines[line] += 1
diff = Counter(actual_lines)
diff.subtract(expected_lines)
for line, delta in diff.items():
if delta == 0:
continue
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
assert all(delta == 0 for delta in diff.values())
@pytest.mark.parametrize("func_type", assert_types)

View File

@@ -1577,6 +1577,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
# It makes sense visually for prefix to end in ": "; make it so. Also,
# non-empty prefixes should start with " ".
if not prefix.endswith(" "):
prefix += " "
if not prefix.endswith(": "):
prefix = prefix[:-1] + ": "
if len(prefix) > 2 and not prefix.startswith(" "):
prefix = " " + prefix
new_args = []
for arg in args:
new_args.append(arg.handle)