mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Fix PrintfHIP
This commit is contained in:
@@ -207,6 +207,7 @@ struct PrintOpConversion
|
||||
// with " " and ends with ": ").
|
||||
|
||||
Value formatStrValue;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::PrintFormatting formatting;
|
||||
for (int i = 0; i < elems.size(); i++) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
@@ -266,17 +267,18 @@ struct PrintOpConversion
|
||||
// strings, so we cache the Value.
|
||||
if (i == 0) {
|
||||
#ifdef USE_ROCM
|
||||
formatStrVAlue = llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
printfOperands, rewriter);
|
||||
#else
|
||||
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
|
||||
#endif
|
||||
} else {
|
||||
#ifdef USE_ROCM
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, printfOperands,
|
||||
rewriter);
|
||||
llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatting,
|
||||
printfOperands, rewriter);
|
||||
#else
|
||||
llPrintf(formatStrValue, printfOperands, rewriter);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,6 +186,12 @@ public:
|
||||
OpBuilder::InsertPoint *indexInsertPoint = nullptr;
|
||||
};
|
||||
|
||||
struct PrintFormatting
|
||||
{
|
||||
Value formatStrValue;
|
||||
size_t formatStrSize;
|
||||
};
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter)
|
||||
: converter(&typeConverter) {}
|
||||
@@ -1312,9 +1318,25 @@ private:
|
||||
}
|
||||
|
||||
protected:
|
||||
// Returns a Value for the format string, which you can reuse.
|
||||
PrintFormatting llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
|
||||
ValueRange args, ConversionPatternRewriter &rewriter,
|
||||
bool stderr = false) const {
|
||||
assert(!msg.empty() && "printf with empty string not supported");
|
||||
PrintFormatting formatting;
|
||||
llvm::SmallString<32> msgNewline(msg);
|
||||
msgNewline.push_back('\n');
|
||||
msgNewline.push_back('\0');
|
||||
formatting.formatStrValue =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
formatting.formatStrSize = msgNewline.size_in_bytes();
|
||||
llPrintfHIP(loc, moduleOp, formatting, args, rewriter);
|
||||
return formatting;
|
||||
}
|
||||
|
||||
// The code is borrowed from https://reviews.llvm.org/D110448
|
||||
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
|
||||
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
|
||||
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, PrintFormatting formatting,
|
||||
ValueRange args, ConversionPatternRewriter &rewriter,
|
||||
bool stderr = false) const {
|
||||
|
||||
@@ -1353,19 +1375,11 @@ protected:
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
SmallString<32> formatString(msg);
|
||||
formatString.push_back('\n'); // Triton adds CR for each print.
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
|
||||
Value prefixString =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatString);
|
||||
|
||||
auto prefixPtrType = ocklAppendStringN.getArgumentTypes()[1];
|
||||
prefixString = bitcast(prefixString, prefixPtrType);
|
||||
Value prefixString = bitcast(formatting.formatStrValue, prefixPtrType);
|
||||
|
||||
Value stringLen =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatting.formatStrSize);
|
||||
|
||||
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
|
||||
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
|
||||
|
||||
Reference in New Issue
Block a user