[ROCM] Implement device_assert functionality. (#207)

Triton firstly prints assert message into stderr stream with the same
(refactored) helper function as `device_print` and then ends the thread
execution.

Note: s_endpgm instruction is used, since s_trap (generated from LLVM::Trap or LLVM::DebugTrap) has some issues on different HW.

Also got back fix in `python/triton/compiler/compiler.py` lost after one
of IFU.
This commit is contained in:
Daniil Fukalov
2023-05-15 16:16:14 +02:00
committed by GitHub
parent aabc2511f6
commit 7acc1cb707
4 changed files with 202 additions and 151 deletions

View File

@@ -90,35 +90,6 @@ struct BroadcastOpConversion
}
};
#ifdef USE_ROCM
static SmallString<16> getUniqueFormatGlobalName(mlir::ModuleOp moduleOp) {
const char formatStringPrefix[] = "printfFormat_";
// Get a unique global name.
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
return stringConstName;
}
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
LLVM::Linkage::External);
}
return ret;
}
#endif // USE_ROCM
struct PrintOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -147,7 +118,8 @@ struct PrintOpConversion
os << ", " << getFormatSubstr(operands[i]);
}
#ifdef USE_ROCM
llPrintfHIP(op, formatStr, operands, rewriter);
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, operands,
rewriter);
#else
llPrintf(formatStr, operands, rewriter);
#endif
@@ -170,123 +142,6 @@ struct PrintOpConversion
return "";
}
#ifdef USE_ROCM
// The code is borrowed from https://reviews.llvm.org/D110448
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
void llPrintfHIP(triton::PrintOp op, StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) const {
mlir::Location loc = op->getLoc();
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
mlir::Type i8Ptr = getTypeConverter()->getPointerType(llvmI8);
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
// Original code from llvm-project needs gpu::GPUModuleOp here to check
// gpu.printf is in gpu.module scope.
auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
auto ocklBegin =
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
LLVM::LLVMFuncOp ocklAppendArgs;
if (!args.empty()) {
ocklAppendArgs = getOrDefineFunction(
moduleOp, loc, rewriter, "__ockl_printf_append_args",
LLVM::LLVMFunctionType::get(llvmI64,
{llvmI64, /*numArgs*/ llvmI32, llvmI64,
llvmI64, llvmI64, llvmI64, llvmI64,
llvmI64, llvmI64, /*isLast*/ llvmI32}));
}
auto ocklAppendStringN = getOrDefineFunction(
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
LLVM::LLVMFunctionType::get(
llvmI64,
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall =
rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
llvm::SmallString<20> formatString(msg);
formatString.push_back('\n'); // Triton adds CR for each print.
formatString.push_back('\0'); // Null terminate for C
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
// Get a pointer to the format string's first element and pass it to
// printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
args.empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
// __ockl_printf_append_args takes 7 values per append call
constexpr size_t argsPerAppend = 7;
size_t nArgs = args.size();
for (size_t group = 0; group < nArgs; group += argsPerAppend) {
size_t bound = std::min(group + argsPerAppend, nArgs);
size_t numArgsThisCall = bound - group;
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = args[i];
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
if (!floatType.isF64())
arg = rewriter.create<LLVM::FPExtOp>(
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
arguments.push_back(arg);
}
// Pad out to 7 arguments since the hostcall always needs 7
for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
arguments.push_back(zeroI64);
}
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
}
#endif // USE_ROCM
// declare vprintf(i8*, i8*) as external function
static LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
@@ -413,14 +268,64 @@ struct AssertOpConversion
return failure();
}
}
#ifdef USE_ROCM
llAssertHIP(op, condition, adaptor.getMessage(), adaptor.getFile(),
adaptor.getFunc(), adaptor.getLine(), rewriter);
#else
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
adaptor.getFunc(), adaptor.getLine(), rewriter);
#endif
rewriter.eraseOp(op);
return success();
}
// op: the op at which the assert is inserted. Unlike printf, we need to
// know about the op to split the block.
#ifdef USE_ROCM
void llAssertHIP(Operation *op, Value condition, StringRef message,
StringRef file, StringRef func, int line,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
auto ctx = rewriter.getContext();
auto loc = op->getLoc();
// #prevBlock
// if (condition) {
// #ifBlock
// print(message);
// halt;
// }
// #endBlock
Block *prevBlock = op->getBlock();
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToStart(ifBlock);
SmallString<256> tmpBuf;
message =
llvm::Twine("Assertion failed: " + message + ", File: " + file +
", Function: " + func + ", Line: " + llvm::Twine(line))
.toStringRef(tmpBuf);
// Print assert message.
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), message,
ValueRange(), rewriter, /*stderr*/ true);
// Perform the trap.
GCNBuilder BuilderTrap;
// TODO: LLVM::Trap LLVM::DebugTrap instructions don't work here.
BuilderTrap.create<>("s_endpgm")->operator()();
BuilderTrap.launch(rewriter, loc, void_ty(ctx));
// Split a block after the call.
Block *endBlock = rewriter.splitBlock(ifBlock, op->getIterator());
rewriter.setInsertionPointToEnd(ifBlock);
rewriter.create<cf::BranchOp>(loc, endBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, endBlock);
}
#else // USE_ROCM
static void llAssert(Operation *op, Value condition, StringRef message,
StringRef file, StringRef func, int line,
ConversionPatternRewriter &rewriter) {
@@ -485,6 +390,7 @@ struct AssertOpConversion
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
funcType);
}
#endif // USE_ROCM
};
struct MakeRangeOpConversion

View File

@@ -911,6 +911,150 @@ private:
return resultIndices;
}
#ifdef USE_ROCM
private:
static SmallString<16> getUniqueFormatGlobalName(mlir::ModuleOp moduleOp) {
const char formatStringPrefix[] = "printfFormat_";
// Get a unique global name.
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));
return stringConstName;
}
template <typename T>
static LLVM::LLVMFuncOp
getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter, StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
LLVM::Linkage::External);
}
return ret;
}
protected:
// The code is borrowed from https://reviews.llvm.org/D110448
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
ValueRange args, ConversionPatternRewriter &rewriter,
bool stderr = false) const {
auto typeConverter = getTypeConverter();
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
auto ocklBegin = getOrDefineFunction(
moduleOp, loc, rewriter,
(stderr ? "__ockl_fprintf_stderr_begin" : "__ockl_printf_begin"),
(LLVM::LLVMFunctionType::get(llvmI64, stderr ? ArrayRef<mlir::Type>()
: llvmI64)));
LLVM::LLVMFuncOp ocklAppendArgs;
if (!args.empty()) {
ocklAppendArgs = getOrDefineFunction(
moduleOp, loc, rewriter, "__ockl_printf_append_args",
LLVM::LLVMFunctionType::get(llvmI64,
{llvmI64, /*numArgs*/ llvmI32, llvmI64,
llvmI64, llvmI64, llvmI64, llvmI64,
llvmI64, llvmI64, /*isLast*/ llvmI32}));
}
auto ocklAppendStringN = getOrDefineFunction(
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
LLVM::LLVMFunctionType::get(
llvmI64,
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall = rewriter.create<LLVM::CallOp>(
loc, ocklBegin, stderr ? ValueRange() : zeroI64);
Value printfDesc = printfBeginCall.getResult();
// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
SmallString<32> formatString(msg);
formatString.push_back('\n'); // Triton adds CR for each print.
formatString.push_back('\0'); // Null terminate for C
size_t formatStringSize = formatString.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
// Get a pointer to the format string's first element and pass it to
// printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
loc, ocklAppendStringN,
ValueRange{printfDesc, stringStart, stringLen,
args.empty() ? oneI32 : zeroI32});
printfDesc = appendFormatCall.getResult();
// __ockl_printf_append_args takes 7 values per append call
constexpr size_t argsPerAppend = 7;
size_t nArgs = args.size();
for (size_t group = 0; group < nArgs; group += argsPerAppend) {
size_t bound = std::min(group + argsPerAppend, nArgs);
size_t numArgsThisCall = bound - group;
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = args[i];
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
if (!floatType.isF64())
arg = rewriter.create<LLVM::FPExtOp>(
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
arguments.push_back(arg);
}
// Pad out to 7 arguments since the hostcall always needs 7
for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
arguments.push_back(zeroI64);
}
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
}
#endif // USE_ROCM
protected:
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;

View File

@@ -33,9 +33,9 @@ def test_assert(func: str):
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, BLOCK=shape[0])
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "static_assert":
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)

View File

@@ -174,7 +174,8 @@ def get_amdgcn_bitcode_paths(arch):
"oclc_daz_opt_off.bc",
"oclc_correctly_rounded_sqrt_on.bc",
"oclc_unsafe_math_off.bc",
"oclc_wavefrontsize64_on.bc"]
"oclc_wavefrontsize64_on.bc",
"oclc_abi_version_400.bc",]
gfx_arch = arch[1]
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
@@ -183,7 +184,7 @@ def get_amdgcn_bitcode_paths(arch):
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
amdgcn_bitcode_paths = {}
i = 1
i = 0
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
bc_path = bitcode_path_dir + bc_lib
if os.path.exists(bc_path):