[BACKEND] Improve printf. (#2532)

[BACKEND] Improve printf.

Previously, we printed all of a GPU thread's values in a single printf()
call, and this, plus the user-specified prefix, was all we printed.

This caused a few problems.

 - nvptx printf can only handle 32 arguments; if you pass more than
   that, it prints garbage.  So if a thread had more than 32 values, you
   couldn't print them, issue #2486.

 - The order of the values within the Triton program (GPU thread block)
   is an implementation detail -- it depends on the layout the compiler
   assigns to a tensor.  So this also prevented you from interpreting
   the printed output.

To address this, we now print the Triton pid and multi-dimensional
tensor index for each value.  And each value gets its own line to avoid
passing too many args to printf.

Example output:

    ```
    pid (0, 1, 2) idx (36, 127) x: 42
    ```

If you want to observe all the values in a tensor in order, you can grep
and then sort the output.

We also make a UX enhancement to print: The printed label always ends
with ": "; you don't have to add it yourself.

Fixes #2486.
This commit is contained in:
Justin Lebar
2023-10-25 01:47:55 -07:00
committed by GitHub
parent 2217bd2f5c
commit e70e11e834
7 changed files with 298 additions and 71 deletions

View File

@@ -1042,7 +1042,7 @@ private:
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp>
};
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,

View File

@@ -2,6 +2,8 @@
#include "Utility.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
namespace {
using namespace mlir;
using namespace mlir::triton;
@@ -10,6 +12,23 @@ using ::mlir::LLVM::getSRegValue;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Returns the program id along `axis` (0 = x, 1 = y, 2 = z) by reading the
// corresponding NVPTX special register.
Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
               ConversionPatternRewriter &rewriter) {
  assert(0 <= axis && axis < 3);
  assert(moduleOp);
  // It is not easy to get the compute capability here, so we use numCTAs to
  // decide the semantic of GetProgramIdOp. If numCTAs = 1, then
  // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
  // "%clusterid".
  const bool singleCTA =
      triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp) == 1;
  std::string sreg(singleCTA ? "%ctaid." : "%clusterid.");
  sreg.push_back(static_cast<char>('x' + axis)); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
  return getSRegValue(rewriter, loc, sreg);
}
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
@@ -91,6 +110,12 @@ struct BroadcastOpConversion
}
};
// The input print op contains:
// - a "prefix" (string) specified by the user, and
// - one or more "operands" (tensors).
//
// For each operand, we print all of the values contained in this GPU thread,
// one per line, along with the index of the value in its tensor.
struct PrintOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -100,45 +125,169 @@ struct PrintOpConversion
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
SmallVector<Value, 16> operands;
for (size_t i = 0; i < op.getNumOperands(); i++) {
auto sub_operands = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
for (auto elem : sub_operands) {
operands.push_back(elem);
}
}
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << op.getPrefix();
if (!operands.empty()) {
os << getFormatSubstr(operands[0]);
Value prefixStr =
LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix());
auto getPid = [&](int axis) {
return llGetPid(axis, loc, op->getParentOfType<ModuleOp>(), rewriter);
};
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
// Simple printf of a string without any tensors.
if (op.getNumOperands() == 0) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << "pid (" << getFormatSubstr(pid[0]) << ", "
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
return success();
}
for (size_t i = 1; i < operands.size(); ++i) {
os << ", " << getFormatSubstr(operands[i]);
for (size_t i = 0; i < op.getNumOperands(); i++) {
// Elements of the tensor that are resident in this GPU thread.
auto elems = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
// Get the indices of `elems` within the tensor. Note that if `elems` has
// an "interesting" layout, then these will not be in any particularly
// nice order.
// Extract the shape of the tensor being printed and use it to figure out
// how many digits we need for each of the dimensions.
SmallVector<int, 8> dimWidths;
SmallVector<SmallVector<Value>> indices;
if (auto rankedTy =
op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
indices = emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy);
for (int64_t dim : rankedTy.getShape()) {
if (dim > 0) {
dimWidths.push_back(static_cast<int>(std::ceil(std::log10(dim))));
} else {
dimWidths.push_back(0);
}
}
} else {
// We're printing a scalar.
assert(elems.size() == 1);
indices.push_back({});
}
if (!elems.empty()) {
printTensor(prefixStr, /*operand=*/i,
/*numOperands=*/op.getNumOperands(), elems, pid, indices,
dimWidths, rewriter);
}
}
llPrintf(formatStr, operands, rewriter);
rewriter.eraseOp(op);
return success();
}
std::string getFormatSubstr(Value value) const {
// Emits one device-side printf call per element of `elems` (the values of
// operand number `operand` that are resident in this GPU thread).
//  - prefixStr: LLVM Value holding the user-supplied prefix string, spliced
//    into the format via "%s".
//  - elems[i] pairs with indices[i], the multi-dimensional index of that
//    element within its tensor (empty vector for a scalar).
//  - dimWidths[d] is the printf field width used for dimension d of the
//    index, so index columns line up across output lines.
//  - pid: the x/y/z program ids, printed at the start of every line.
void printTensor(Value prefixStr, size_t operand, size_t numOperands,
ArrayRef<Value> elems, std::array<Value, 3> pid,
ArrayRef<SmallVector<Value>> indices,
ArrayRef<int> dimWidths,
ConversionPatternRewriter &rewriter) const {
assert(!elems.empty());
assert(elems.size() == indices.size());
assert(dimWidths.size() == indices.front().size());
size_t rank = dimWidths.size();
// Format is:
// pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...)<prefix> (operand <n>) <elem>
// where we leave off "(operand <n>)" if there's only one operand.
//
// The Python wrapper munges `prefix` so that it prints nicely (e.g. starts
// with " " and ends with ": ").
Value formatStrValue;
for (int i = 0; i < elems.size(); i++) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
// nvptx printf can only accept 32 args; if we pass more than that, it
// will print garbage for the trailing args.
constexpr int kMaxPrintfOperands = 32;
SmallVector<Value, kMaxPrintfOperands> printfOperands;
// TODO(jlebar): We really should pad the pid, but because the max pid is
// not known at compile-time, this would require nontrivial device-side
// work.
os << "pid (";
for (int j = 0; j < pid.size(); j++) {
if (j != 0) {
os << ", ";
}
os << getFormatSubstr(pid[j]);
printfOperands.push_back(pid[j]);
}
os << ") ";
// If `rank` is large enough, we could end up exceeding
// kMaxPrintfOperands.  In that case, just truncate the index.
// (Subtract 2 because we're going to add two operands after the index.)
int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2;
os << "idx (";
const auto &index = indices[i];
for (size_t dim = 0; dim < index.size(); dim++) {
if (dim != 0) {
os << ", ";
}
if (dim == maxAllowedRank) {
os << "... (truncated)";
break;
}
// Pad each index to its dimension's width so output lines sort nicely.
os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]);
printfOperands.push_back(index[dim]);
}
os << ")";
// The user-supplied prefix goes right after the index, via "%s".
os << "%s";
printfOperands.push_back(prefixStr);
if (numOperands > 1) {
os << "(operand " << operand << ") ";
}
auto elem = elems[i];
os << getFormatSubstr(elem);
printfOperands.push_back(elem);
// It's the same format string each iteration, but it's a lot easier if we
// construct the format string at the same time as we populate
// printfOperands.  But we don't want to create BLOCK_SIZE duplicate
// strings, so we cache the Value.
if (i == 0) {
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
} else {
llPrintf(formatStrValue, printfOperands, rewriter);
}
}
}
std::string getFormatSubstr(Value value,
std::optional<int> width = std::nullopt) const {
std::string prefix = "%";
if (width.has_value()) {
prefix += std::to_string(*width);
}
Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
return "%p";
return prefix + "p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return "%f";
return prefix + "f";
} else if (type.isSignedInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return "%lli";
return prefix + "lli";
else
return "%i";
return prefix + "i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return "%llu";
return prefix + "llu";
else
return "%u";
return prefix + "u";
}
assert(false && "not supported type");
return "";
@@ -194,9 +343,22 @@ struct PrintOpConversion
return {newType, newOp};
}
static void llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
// Returns a Value for the format string, which you can reuse.
static Value llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
assert(!msg.empty() && "printf with empty string not supported");
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value msgValue =
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
rewriter, "printfFormat_", msgNewline);
llPrintf(msgValue, args, rewriter);
return msgValue;
}
static void llPrintf(Value msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
Type int8Ptr = ptr_ty(i8_ty);
auto *ctx = rewriter.getContext();
@@ -208,11 +370,6 @@ struct PrintOpConversion
Value one = i32_val(1);
Value zero = i32_val(0);
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value prefixString =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
Value bufferPtr = null(int8Ptr);
SmallVector<Value, 16> newArgs;
@@ -240,7 +397,7 @@ struct PrintOpConversion
bufferPtr = bitcast(allocated, int8Ptr);
}
SmallVector<Value> operands{prefixString, bufferPtr};
SmallVector<Value> operands{msg, bufferPtr};
call(funcOp, operands);
}
};
@@ -390,20 +547,8 @@ struct GetProgramIdOpConversion
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
// "%clusterid".
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
Location loc = op->getLoc();
assert(op.getAxisAsInt() < 3);
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
Value programId = getSRegValue(rewriter, loc, sreg);
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
op->getParentOfType<ModuleOp>(), rewriter);
rewriter.replaceOp(op, programId);
return success();
}
@@ -685,6 +830,10 @@ struct AsyncBulkCommitGroupOpConversion
}
};
} // namespace
namespace mlir::triton {
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
@@ -710,3 +859,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<PrintOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}
} // namespace mlir::triton

View File

@@ -6,6 +6,8 @@
using namespace mlir;
using namespace mlir::triton;
namespace mlir::triton {
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
@@ -13,4 +15,6 @@ void populateTritonGPUToLLVMPatterns(
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
} // namespace mlir::triton
#endif

View File

@@ -179,10 +179,10 @@ public:
// Key: {layout, shape, withCTAOffset}
struct IndexCacheInfo {
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
*baseIndexCache;
*baseIndexCache = nullptr;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo> *indexCache;
OpBuilder::InsertPoint *indexInsertPoint;
CacheKeyDenseMapInfo> *indexCache = nullptr;
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
};
explicit ConvertTritonGPUOpToLLVMPatternBase(
@@ -778,7 +778,7 @@ public:
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"emitIndices for layouts other than blocked, mma, and slice not "
"implemented yet");
}
if (cache) {

View File

@@ -11,21 +11,47 @@ import triton.language as tl
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.device_print("", x)
tl.device_print("x: ", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
print("", x)
# Triton should add a space after this prefix.
print("x:", x)
tl.store(Y + tl.arange(0, BLOCK), x)
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
# the static print is run every time.
# Prints a 2-D tensor with more elements per thread than a single printf
# could accept as arguments; each element gets its own printf line.
@triton.jit
def kernel_device_print_large(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
# Triton should change this prefix to "x: ".
tl.device_print("x ", x)
# Exercises print() with two tensor operands; each printed line is tagged
# with "(operand 0)" or "(operand 1)".
@triton.jit
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.full((BLOCK,), 1, tl.int32)
print("", x, y)
# Same as kernel_print_multiple_args but via tl.device_print; also stores y
# back so the host side has something to check.
@triton.jit
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.full((BLOCK,), 1, tl.int32)
tl.device_print("", x, y)
tl.store(Y + tl.arange(0, BLOCK), y)
# tl.static_print emits its output once per compilation, not per GPU thread,
# which is why the kernel must not be cached (see PLACEHOLDER below).
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
# This function takes an extra value as a tl.constexpr so this kernel is not
# cached. This way the static print is run every time.
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@@ -38,19 +64,27 @@ def kernel_no_arg_print():
def test_print(func: str, data_type: str):
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_print":
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
elif func == "print":
kernel_print[(1,)](x, y, BLOCK=shape[0])
elif func == "device_print_large":
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
elif func == "print_multiple_args":
kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
elif func == "device_print_multiple_args":
kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)
else:
assert f"Unknown kernel: {func}"
if func != "no_arg_print":
if func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args":
assert_close(y, x)

View File

@@ -1,6 +1,8 @@
import itertools
import os
import subprocess
import sys
from collections import Counter
import pytest
@@ -14,26 +16,53 @@ nested_types = [(caller, callee) for caller in ["true", "false", "none"] for cal
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
# TODO: Print with multiple operands
@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
[("device_print", data_type) for data_type in torch_types] + [
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = line
if func_type != "static_print":
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print" and func_type != "no_arg_print":
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
# Format is
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
expected_lines = Counter()
if func_type == "print" or func_type == "device_print":
for i in range(128):
assert i in new_lines
else:
assert len(new_lines) == 1
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = 1
elif func_type == "static_print":
expected_lines[" int32[constexpr[128]]"] = 1
elif func_type == "no_arg_print":
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
elif func_type == "device_print_large":
for i, j, k in itertools.product(range(2), range(64), range(128)):
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
for i in range(128):
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
actual_lines = Counter()
for line in outs:
actual_lines[line] += 1
diff = Counter(actual_lines)
diff.subtract(expected_lines)
for line, delta in diff.items():
if delta == 0:
continue
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
assert all(delta == 0 for delta in diff.values())
@pytest.mark.parametrize("func_type", assert_types)

View File

@@ -1577,6 +1577,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
# It makes sense visually for prefix to end in ": "; make it so. Also,
# non-empty prefixes should start with " ".
if not prefix.endswith(" "):
prefix += " "
if not prefix.endswith(": "):
prefix = prefix[:-1] + ": "
if len(prefix) > 2 and not prefix.startswith(" "):
prefix = " " + prefix
new_args = []
for arg in args:
new_args.append(arg.handle)