mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Improve printf. (#2532)
[BACKEND] Improve printf. Previously, we printed all of a GPU thread's values in a single printf() call, and this, plus the user-specified prefix, was all we printed. This caused a few problems. - nvptx printf can only handle 32 arguments; if you pass more than that, it prints garbage. So if a thread had more than 32 values, you couldn't print them, issue #2486. - The order of the values within the Triton program (GPU thread block) is an implementation detail -- it depends on the layout the compiler assigns to a tensor. So this also prevented you from interpreting the printed output. To address this, we now print the Triton pid and multi-dimensional Tensor index for each value. And each value gets its own line to avoid passing too many args to printf. Example output: ``` pid (0, 1, 2) idx (36, 127) x: 42 ``` If you want to observe all the values in a tensor in order, you can grep and then sort the output. We also make a UX enhancement to print: The printed label always ends with ": "; you don't have to add it yourself. Fixes #2486.
This commit is contained in:
@@ -1042,7 +1042,7 @@ private:
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}; // namespace triton::gpu::ConvertLayoutOp>
|
||||
};
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -10,6 +12,23 @@ using ::mlir::LLVM::getSRegValue;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(axis >= 0);
|
||||
assert(axis < 3);
|
||||
assert(moduleOp);
|
||||
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
return getSRegValue(rewriter, loc, sreg);
|
||||
}
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
@@ -91,6 +110,12 @@ struct BroadcastOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
// The input print op contains:
|
||||
// - a "prefix" (string) specified by the user, and
|
||||
// - one or more "operands" (tensors).
|
||||
//
|
||||
// For each operand, we print all of the values contained in this GPU thread,
|
||||
// one per line, along with the index of the value in its tensor.
|
||||
struct PrintOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -100,45 +125,169 @@ struct PrintOpConversion
|
||||
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
SmallVector<Value, 16> operands;
|
||||
for (size_t i = 0; i < op.getNumOperands(); i++) {
|
||||
auto sub_operands = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
|
||||
for (auto elem : sub_operands) {
|
||||
operands.push_back(elem);
|
||||
}
|
||||
}
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << op.getPrefix();
|
||||
if (!operands.empty()) {
|
||||
os << getFormatSubstr(operands[0]);
|
||||
Value prefixStr =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix());
|
||||
|
||||
auto getPid = [&](int axis) {
|
||||
return llGetPid(axis, loc, op->getParentOfType<ModuleOp>(), rewriter);
|
||||
};
|
||||
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
|
||||
|
||||
// Simple printf of a string without any tensors.
|
||||
if (op.getNumOperands() == 0) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
os << "pid (" << getFormatSubstr(pid[0]) << ", "
|
||||
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
|
||||
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < operands.size(); ++i) {
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
for (size_t i = 0; i < op.getNumOperands(); i++) {
|
||||
// Elements of the tensor that are resident in this GPU thread.
|
||||
auto elems = getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
|
||||
|
||||
// Get the indices of `elems` within the tensor. Note that if `elems` has
|
||||
// an "interesting" layout, then these will not be in any particularly
|
||||
// nice order.
|
||||
|
||||
// Extract the shape of the tensor being printed and use it to figure out
|
||||
// how many digits we need for each of the dimensions.
|
||||
SmallVector<int, 8> dimWidths;
|
||||
SmallVector<SmallVector<Value>> indices;
|
||||
if (auto rankedTy =
|
||||
op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
|
||||
indices = emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy);
|
||||
for (int64_t dim : rankedTy.getShape()) {
|
||||
if (dim > 0) {
|
||||
dimWidths.push_back(static_cast<int>(std::ceil(std::log10(dim))));
|
||||
} else {
|
||||
dimWidths.push_back(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We're printing a scalar.
|
||||
assert(elems.size() == 1);
|
||||
indices.push_back({});
|
||||
}
|
||||
|
||||
if (!elems.empty()) {
|
||||
printTensor(prefixStr, /*operand=*/i,
|
||||
/*numOperands=*/op.getNumOperands(), elems, pid, indices,
|
||||
dimWidths, rewriter);
|
||||
}
|
||||
}
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
std::string getFormatSubstr(Value value) const {
|
||||
void printTensor(Value prefixStr, size_t operand, size_t numOperands,
|
||||
ArrayRef<Value> elems, std::array<Value, 3> pid,
|
||||
ArrayRef<SmallVector<Value>> indices,
|
||||
ArrayRef<int> dimWidths,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
assert(!elems.empty());
|
||||
assert(elems.size() == indices.size());
|
||||
assert(dimWidths.size() == indices.front().size());
|
||||
|
||||
size_t rank = dimWidths.size();
|
||||
|
||||
// Format is:
|
||||
// pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...)<prefix> (operand <n>) <elem>
|
||||
// where we leave off "(operand <n>)" if there's only one operand.
|
||||
//
|
||||
// The Python wrapper munges `prefix` so that it prints nicely (e.g. starts
|
||||
// with " " and ends with ": ").
|
||||
|
||||
Value formatStrValue;
|
||||
for (int i = 0; i < elems.size(); i++) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
|
||||
// nvptx printf can only accept 32 args; if we pass more than that, it
|
||||
// will print garbage for the trailing args.
|
||||
constexpr int kMaxPrintfOperands = 32;
|
||||
SmallVector<Value, kMaxPrintfOperands> printfOperands;
|
||||
|
||||
// TODO(jlebar): We really should pad the pid, but because the max pid is
|
||||
// not known at compile-time, this would require nontrivial device-side
|
||||
// work.
|
||||
os << "pid (";
|
||||
for (int j = 0; j < pid.size(); j++) {
|
||||
if (j != 0) {
|
||||
os << ", ";
|
||||
}
|
||||
os << getFormatSubstr(pid[j]);
|
||||
printfOperands.push_back(pid[j]);
|
||||
}
|
||||
os << ") ";
|
||||
|
||||
// If `rank` is large enough, we could end up exceeding
|
||||
// kMaxPrintfOperands. In that case, just truncate the index.
|
||||
// (Subtract 2 because we're going to add two operands after the index.)
|
||||
int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2;
|
||||
|
||||
os << "idx (";
|
||||
const auto &index = indices[i];
|
||||
for (size_t dim = 0; dim < index.size(); dim++) {
|
||||
if (dim != 0) {
|
||||
os << ", ";
|
||||
}
|
||||
if (dim == maxAllowedRank) {
|
||||
os << "... (truncated)";
|
||||
break;
|
||||
}
|
||||
os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]);
|
||||
printfOperands.push_back(index[dim]);
|
||||
}
|
||||
os << ")";
|
||||
|
||||
os << "%s";
|
||||
printfOperands.push_back(prefixStr);
|
||||
|
||||
if (numOperands > 1) {
|
||||
os << "(operand " << operand << ") ";
|
||||
}
|
||||
|
||||
auto elem = elems[i];
|
||||
os << getFormatSubstr(elem);
|
||||
printfOperands.push_back(elem);
|
||||
|
||||
// It's the same format string each iteration, but it's a lot easier if we
|
||||
// construct the format string at the same time as we populate
|
||||
// printfOperands. But we don't want to create BLOCK_SIZE duplicate
|
||||
// strings, so we cache the Value.
|
||||
if (i == 0) {
|
||||
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
|
||||
} else {
|
||||
llPrintf(formatStrValue, printfOperands, rewriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string getFormatSubstr(Value value,
|
||||
std::optional<int> width = std::nullopt) const {
|
||||
std::string prefix = "%";
|
||||
if (width.has_value()) {
|
||||
prefix += std::to_string(*width);
|
||||
}
|
||||
|
||||
Type type = value.getType();
|
||||
if (type.isa<LLVM::LLVMPointerType>()) {
|
||||
return "%p";
|
||||
return prefix + "p";
|
||||
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
|
||||
return "%f";
|
||||
return prefix + "f";
|
||||
} else if (type.isSignedInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%lli";
|
||||
return prefix + "lli";
|
||||
else
|
||||
return "%i";
|
||||
return prefix + "i";
|
||||
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
|
||||
if (type.getIntOrFloatBitWidth() == 64)
|
||||
return "%llu";
|
||||
return prefix + "llu";
|
||||
else
|
||||
return "%u";
|
||||
return prefix + "u";
|
||||
}
|
||||
assert(false && "not supported type");
|
||||
return "";
|
||||
@@ -194,9 +343,22 @@ struct PrintOpConversion
|
||||
return {newType, newOp};
|
||||
}
|
||||
|
||||
static void llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// Returns a Value for the format string, which you can reuse.
|
||||
static Value llPrintf(StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
assert(!msg.empty() && "printf with empty string not supported");
|
||||
llvm::SmallString<64> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
Value msgValue =
|
||||
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
|
||||
rewriter, "printfFormat_", msgNewline);
|
||||
llPrintf(msgValue, args, rewriter);
|
||||
return msgValue;
|
||||
}
|
||||
|
||||
static void llPrintf(Value msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Type int8Ptr = ptr_ty(i8_ty);
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
@@ -208,11 +370,6 @@ struct PrintOpConversion
|
||||
Value one = i32_val(1);
|
||||
Value zero = i32_val(0);
|
||||
|
||||
llvm::SmallString<64> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
Value prefixString =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
Value bufferPtr = null(int8Ptr);
|
||||
|
||||
SmallVector<Value, 16> newArgs;
|
||||
@@ -240,7 +397,7 @@ struct PrintOpConversion
|
||||
bufferPtr = bitcast(allocated, int8Ptr);
|
||||
}
|
||||
|
||||
SmallVector<Value> operands{prefixString, bufferPtr};
|
||||
SmallVector<Value> operands{msg, bufferPtr};
|
||||
call(funcOp, operands);
|
||||
}
|
||||
};
|
||||
@@ -390,20 +547,8 @@ struct GetProgramIdOpConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
// "%clusterid".
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
assert(moduleOp && "Parent ModuleOp not found for GetProgramIdOp");
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
|
||||
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
|
||||
Value programId = getSRegValue(rewriter, loc, sreg);
|
||||
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
|
||||
op->getParentOfType<ModuleOp>(), rewriter);
|
||||
rewriter.replaceOp(op, programId);
|
||||
return success();
|
||||
}
|
||||
@@ -685,6 +830,10 @@ struct AsyncBulkCommitGroupOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir::triton {
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -710,3 +859,5 @@ void populateTritonGPUToLLVMPatterns(
|
||||
patterns.add<PrintOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AssertOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
} // namespace mlir::triton
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace mlir::triton {
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -13,4 +15,6 @@ void populateTritonGPUToLLVMPatterns(
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
} // namespace mlir::triton
|
||||
|
||||
#endif
|
||||
|
||||
@@ -179,10 +179,10 @@ public:
|
||||
// Key: {layout, shape, withCTAOffset}
|
||||
struct IndexCacheInfo {
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
*baseIndexCache;
|
||||
*baseIndexCache = nullptr;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
|
||||
CacheKeyDenseMapInfo> *indexCache;
|
||||
OpBuilder::InsertPoint *indexInsertPoint;
|
||||
CacheKeyDenseMapInfo> *indexCache = nullptr;
|
||||
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
|
||||
};
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
@@ -778,7 +778,7 @@ public:
|
||||
emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"emitIndices for layouts other than blocked & slice not "
|
||||
"emitIndices for layouts other than blocked, mma, and slice not "
|
||||
"implemented yet");
|
||||
}
|
||||
if (cache) {
|
||||
|
||||
@@ -11,21 +11,47 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.device_print("", x)
|
||||
tl.device_print("x: ", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
print("", x)
|
||||
# Triton should add a space after this prefix.
|
||||
print("x:", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
|
||||
# the static print is run every time.
|
||||
@triton.jit
|
||||
def kernel_device_print_large(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32)
|
||||
# Triton should change this prefix to "x: ".
|
||||
tl.device_print("x ", x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.full((BLOCK,), 1, tl.int32)
|
||||
print("", x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
y = tl.full((BLOCK,), 1, tl.int32)
|
||||
tl.device_print("", x, y)
|
||||
tl.store(Y + tl.arange(0, BLOCK), y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
|
||||
# This function takes an extra value as a tl.constexpr so this kernel is not
|
||||
# cached. This way the static print is run every time.
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
@@ -38,19 +64,27 @@ def kernel_no_arg_print():
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_print":
|
||||
kernel_device_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "print":
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_large":
|
||||
kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128)
|
||||
elif func == "print_multiple_args":
|
||||
kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "device_print_multiple_args":
|
||||
kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
else:
|
||||
assert f"Unknown kernel: {func}"
|
||||
|
||||
if func != "no_arg_print":
|
||||
if func != "no_arg_print" and func != "device_print_large" and \
|
||||
func != "print_multiple_args" and func != "device_print_multiple_args":
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,26 +16,53 @@ nested_types = [(caller, callee) for caller in ["true", "false", "none"] for cal
|
||||
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
|
||||
|
||||
|
||||
# TODO: Print with multiple operands
|
||||
@pytest.mark.parametrize("func_type, data_type",
|
||||
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")])
|
||||
[("device_print", data_type) for data_type in torch_types] + [
|
||||
("print", "int32"),
|
||||
("static_print", "int32"),
|
||||
("no_arg_print", "int32"),
|
||||
("device_print_large", "int32"),
|
||||
("print_multiple_args", "int32"),
|
||||
("device_print_multiple_args", "int32"),
|
||||
])
|
||||
def test_print(func_type: str, data_type: str):
|
||||
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
|
||||
outs, _ = proc.communicate()
|
||||
outs = outs.split()
|
||||
new_lines = set()
|
||||
for line in outs:
|
||||
try:
|
||||
value = line
|
||||
if func_type != "static_print":
|
||||
value = int(float(line))
|
||||
new_lines.add(value)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if func_type != "static_print" and func_type != "no_arg_print":
|
||||
outs = [line for line in outs.decode("UTF-8").split("\n") if line]
|
||||
|
||||
# Format is
|
||||
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
|
||||
expected_lines = Counter()
|
||||
if func_type == "print" or func_type == "device_print":
|
||||
for i in range(128):
|
||||
assert i in new_lines
|
||||
else:
|
||||
assert len(new_lines) == 1
|
||||
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
|
||||
if data_type.startswith("float"):
|
||||
line += ".000000"
|
||||
expected_lines[line] = 1
|
||||
elif func_type == "static_print":
|
||||
expected_lines[" int32[constexpr[128]]"] = 1
|
||||
elif func_type == "no_arg_print":
|
||||
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
|
||||
elif func_type == "device_print_large":
|
||||
for i, j, k in itertools.product(range(2), range(64), range(128)):
|
||||
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
|
||||
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
|
||||
for i in range(128):
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
|
||||
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1
|
||||
|
||||
actual_lines = Counter()
|
||||
for line in outs:
|
||||
actual_lines[line] += 1
|
||||
|
||||
diff = Counter(actual_lines)
|
||||
diff.subtract(expected_lines)
|
||||
for line, delta in diff.items():
|
||||
if delta == 0:
|
||||
continue
|
||||
print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
|
||||
assert all(delta == 0 for delta in diff.values())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type", assert_types)
|
||||
|
||||
@@ -1577,6 +1577,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
|
||||
|
||||
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
||||
# non-empty prefixes should start with " ".
|
||||
if not prefix.endswith(" "):
|
||||
prefix += " "
|
||||
if not prefix.endswith(": "):
|
||||
prefix = prefix[:-1] + ": "
|
||||
if len(prefix) > 2 and not prefix.startswith(" "):
|
||||
prefix = " " + prefix
|
||||
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_args.append(arg.handle)
|
||||
|
||||
Reference in New Issue
Block a user