ROCM IFU: Fix PrintfHIP

This commit is contained in:
Jason Furmanek
2023-11-21 23:06:14 +00:00
parent a08dafe7fe
commit 4e86b25f1c
2 changed files with 30 additions and 14 deletions

View File

@@ -207,6 +207,7 @@ struct PrintOpConversion
// with " " and ends with ": ").
Value formatStrValue;
ConvertTritonGPUOpToLLVMPatternBase::PrintFormatting formatting;
for (int i = 0; i < elems.size(); i++) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
@@ -266,17 +267,18 @@ struct PrintOpConversion
// strings, so we cache the Value.
if (i == 0) {
#ifdef USE_ROCM
formatStrVAlue = llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatStr,
printfOperands, rewriter);
#else
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
#endif
} else {
#ifdef USE_ROCM
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, printfOperands,
rewriter);
llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatting,
printfOperands, rewriter);
#else
llPrintf(formatStrValue, printfOperands, rewriter);
#endif
}
}
}

View File

@@ -186,6 +186,12 @@ public:
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
};
struct PrintFormatting
{
Value formatStrValue;
size_t formatStrSize;
};
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter)
: converter(&typeConverter) {}
@@ -1312,9 +1318,25 @@ private:
}
protected:
// Returns a Value for the format string, which you can reuse.
PrintFormatting llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
ValueRange args, ConversionPatternRewriter &rewriter,
bool stderr = false) const {
assert(!msg.empty() && "printf with empty string not supported");
PrintFormatting formatting;
llvm::SmallString<32> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
formatting.formatStrValue =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
formatting.formatStrSize = msgNewline.size_in_bytes();
llPrintfHIP(loc, moduleOp, formatting, args, rewriter);
return formatting;
}
// The code is borrowed from https://reviews.llvm.org/D110448
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, PrintFormatting formatting,
ValueRange args, ConversionPatternRewriter &rewriter,
bool stderr = false) const {
@@ -1353,19 +1375,11 @@ protected:
// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
SmallString<32> formatString(msg);
formatString.push_back('\n'); // Triton adds CR for each print.
formatString.push_back('\0'); // Null terminate for C
size_t formatStringSize = formatString.size_in_bytes();
Value prefixString =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatString);
auto prefixPtrType = ocklAppendStringN.getArgumentTypes()[1];
prefixString = bitcast(prefixString, prefixPtrType);
Value prefixString = bitcast(formatting.formatStrValue, prefixPtrType);
Value stringLen =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatting.formatStrSize);
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);