mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ROCM] Implement device_assert functionality. (#207)
Triton first prints the assert message to the stderr stream, using the same (refactored) helper function as `device_print`, and then ends the thread's execution. Note: the s_endpgm instruction is used because s_trap (generated from LLVM::Trap or LLVM::DebugTrap) has some issues on various hardware. This commit also restores a fix in `python/triton/compiler/compiler.py` that was lost after one of the IFUs.
This commit is contained in:
@@ -90,35 +90,6 @@ struct BroadcastOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static SmallString<16> getUniqueFormatGlobalName(mlir::ModuleOp moduleOp) {
|
||||
const char formatStringPrefix[] = "printfFormat_";
|
||||
// Get a unique global name.
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
return stringConstName;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
StringRef name,
|
||||
LLVM::LLVMFunctionType type) {
|
||||
LLVM::LLVMFuncOp ret;
|
||||
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
|
||||
LLVM::Linkage::External);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
struct PrintOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
@@ -147,7 +118,8 @@ struct PrintOpConversion
|
||||
os << ", " << getFormatSubstr(operands[i]);
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
llPrintfHIP(op, formatStr, operands, rewriter);
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr, operands,
|
||||
rewriter);
|
||||
#else
|
||||
llPrintf(formatStr, operands, rewriter);
|
||||
#endif
|
||||
@@ -170,123 +142,6 @@ struct PrintOpConversion
|
||||
return "";
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// The code is borrowed from https://reviews.llvm.org/D110448
|
||||
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
|
||||
void llPrintfHIP(triton::PrintOp op, StringRef msg, ValueRange args,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
mlir::Location loc = op->getLoc();
|
||||
|
||||
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
|
||||
mlir::Type i8Ptr = getTypeConverter()->getPointerType(llvmI8);
|
||||
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
|
||||
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
|
||||
|
||||
// Original code from llvm-project needs gpu::GPUModuleOp here to check
|
||||
// gpu.printf is in gpu.module scope.
|
||||
auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
|
||||
|
||||
auto ocklBegin =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
|
||||
LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
|
||||
LLVM::LLVMFuncOp ocklAppendArgs;
|
||||
if (!args.empty()) {
|
||||
ocklAppendArgs = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_args",
|
||||
LLVM::LLVMFunctionType::get(llvmI64,
|
||||
{llvmI64, /*numArgs*/ llvmI32, llvmI64,
|
||||
llvmI64, llvmI64, llvmI64, llvmI64,
|
||||
llvmI64, llvmI64, /*isLast*/ llvmI32}));
|
||||
}
|
||||
auto ocklAppendStringN = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
|
||||
LLVM::LLVMFunctionType::get(
|
||||
llvmI64,
|
||||
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
|
||||
/// Start the printf hostcall
|
||||
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
|
||||
auto printfBeginCall =
|
||||
rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
|
||||
Value printfDesc = printfBeginCall.getResult();
|
||||
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
llvm::SmallString<20> formatString(msg);
|
||||
formatString.push_back('\n'); // Triton adds CR for each print.
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
|
||||
auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(formatString));
|
||||
}
|
||||
|
||||
// Get a pointer to the format string's first element and pass it to
|
||||
// printf()
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc,
|
||||
getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value stringStart = rewriter.create<LLVM::GEPOp>(
|
||||
loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
Value stringLen =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
|
||||
|
||||
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
|
||||
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
|
||||
|
||||
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, ocklAppendStringN,
|
||||
ValueRange{printfDesc, stringStart, stringLen,
|
||||
args.empty() ? oneI32 : zeroI32});
|
||||
printfDesc = appendFormatCall.getResult();
|
||||
|
||||
// __ockl_printf_append_args takes 7 values per append call
|
||||
constexpr size_t argsPerAppend = 7;
|
||||
size_t nArgs = args.size();
|
||||
for (size_t group = 0; group < nArgs; group += argsPerAppend) {
|
||||
size_t bound = std::min(group + argsPerAppend, nArgs);
|
||||
size_t numArgsThisCall = bound - group;
|
||||
|
||||
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
|
||||
arguments.push_back(printfDesc);
|
||||
arguments.push_back(
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
|
||||
for (size_t i = group; i < bound; ++i) {
|
||||
Value arg = args[i];
|
||||
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
|
||||
if (!floatType.isF64())
|
||||
arg = rewriter.create<LLVM::FPExtOp>(
|
||||
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
|
||||
}
|
||||
if (arg.getType().getIntOrFloatBitWidth() != 64)
|
||||
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
|
||||
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
// Pad out to 7 arguments since the hostcall always needs 7
|
||||
for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
|
||||
arguments.push_back(zeroI64);
|
||||
}
|
||||
|
||||
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
|
||||
arguments.push_back(isLast);
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
|
||||
printfDesc = call.getResult();
|
||||
}
|
||||
}
|
||||
|
||||
#endif // USE_ROCM
|
||||
|
||||
// declare vprintf(i8*, i8*) as external function
|
||||
static LLVM::LLVMFuncOp
|
||||
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
|
||||
@@ -413,14 +268,64 @@ struct AssertOpConversion
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
llAssertHIP(op, condition, adaptor.getMessage(), adaptor.getFile(),
|
||||
adaptor.getFunc(), adaptor.getLine(), rewriter);
|
||||
#else
|
||||
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
|
||||
adaptor.getFunc(), adaptor.getLine(), rewriter);
|
||||
#endif
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
// op: the op at which the assert is inserted. Unlike printf, we need to
|
||||
// know about the op to split the block.
|
||||
#ifdef USE_ROCM
|
||||
void llAssertHIP(Operation *op, Value condition, StringRef message,
|
||||
StringRef file, StringRef func, int line,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
auto ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// #prevBlock
|
||||
// if (condition) {
|
||||
// #ifBlock
|
||||
// print(message);
|
||||
// halt;
|
||||
// }
|
||||
// #endBlock
|
||||
Block *prevBlock = op->getBlock();
|
||||
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToStart(ifBlock);
|
||||
|
||||
SmallString<256> tmpBuf;
|
||||
message =
|
||||
llvm::Twine("Assertion failed: " + message + ", File: " + file +
|
||||
", Function: " + func + ", Line: " + llvm::Twine(line))
|
||||
.toStringRef(tmpBuf);
|
||||
|
||||
// Print assert message.
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), message,
|
||||
ValueRange(), rewriter, /*stderr*/ true);
|
||||
|
||||
// Perform the trap.
|
||||
GCNBuilder BuilderTrap;
|
||||
// TODO: LLVM::Trap LLVM::DebugTrap instructions don't work here.
|
||||
BuilderTrap.create<>("s_endpgm")->operator()();
|
||||
BuilderTrap.launch(rewriter, loc, void_ty(ctx));
|
||||
|
||||
// Split a block after the call.
|
||||
Block *endBlock = rewriter.splitBlock(ifBlock, op->getIterator());
|
||||
rewriter.setInsertionPointToEnd(ifBlock);
|
||||
rewriter.create<cf::BranchOp>(loc, endBlock);
|
||||
rewriter.setInsertionPointToEnd(prevBlock);
|
||||
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, endBlock);
|
||||
}
|
||||
|
||||
#else // USE_ROCM
|
||||
|
||||
static void llAssert(Operation *op, Value condition, StringRef message,
|
||||
StringRef file, StringRef func, int line,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
@@ -485,6 +390,7 @@ struct AssertOpConversion
|
||||
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
|
||||
funcType);
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
};
|
||||
|
||||
struct MakeRangeOpConversion
|
||||
|
||||
@@ -911,6 +911,150 @@ private:
|
||||
return resultIndices;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
private:
|
||||
static SmallString<16> getUniqueFormatGlobalName(mlir::ModuleOp moduleOp) {
|
||||
const char formatStringPrefix[] = "printfFormat_";
|
||||
// Get a unique global name.
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
return stringConstName;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LLVM::LLVMFuncOp
|
||||
getOrDefineFunction(T &moduleOp, const Location loc,
|
||||
ConversionPatternRewriter &rewriter, StringRef name,
|
||||
LLVM::LLVMFunctionType type) {
|
||||
LLVM::LLVMFuncOp ret;
|
||||
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
|
||||
LLVM::Linkage::External);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
protected:
|
||||
// The code is borrowed from https://reviews.llvm.org/D110448
|
||||
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
|
||||
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
|
||||
ValueRange args, ConversionPatternRewriter &rewriter,
|
||||
bool stderr = false) const {
|
||||
|
||||
auto typeConverter = getTypeConverter();
|
||||
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
|
||||
mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
|
||||
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
|
||||
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
|
||||
|
||||
auto ocklBegin = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter,
|
||||
(stderr ? "__ockl_fprintf_stderr_begin" : "__ockl_printf_begin"),
|
||||
(LLVM::LLVMFunctionType::get(llvmI64, stderr ? ArrayRef<mlir::Type>()
|
||||
: llvmI64)));
|
||||
LLVM::LLVMFuncOp ocklAppendArgs;
|
||||
if (!args.empty()) {
|
||||
ocklAppendArgs = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_args",
|
||||
LLVM::LLVMFunctionType::get(llvmI64,
|
||||
{llvmI64, /*numArgs*/ llvmI32, llvmI64,
|
||||
llvmI64, llvmI64, llvmI64, llvmI64,
|
||||
llvmI64, llvmI64, /*isLast*/ llvmI32}));
|
||||
}
|
||||
auto ocklAppendStringN = getOrDefineFunction(
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
|
||||
LLVM::LLVMFunctionType::get(
|
||||
llvmI64,
|
||||
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
|
||||
/// Start the printf hostcall
|
||||
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
|
||||
auto printfBeginCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, ocklBegin, stderr ? ValueRange() : zeroI64);
|
||||
Value printfDesc = printfBeginCall.getResult();
|
||||
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
SmallString<32> formatString(msg);
|
||||
formatString.push_back('\n'); // Triton adds CR for each print.
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
size_t formatStringSize = formatString.size_in_bytes();
|
||||
|
||||
auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(formatString));
|
||||
}
|
||||
|
||||
// Get a pointer to the format string's first element and pass it to
|
||||
// printf()
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
|
||||
loc,
|
||||
getTypeConverter()->getPointerType(globalType, global.getAddrSpace()),
|
||||
global.getSymNameAttr());
|
||||
Value stringStart = rewriter.create<LLVM::GEPOp>(
|
||||
loc, i8Ptr, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
Value stringLen =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
|
||||
|
||||
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
|
||||
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
|
||||
|
||||
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
|
||||
loc, ocklAppendStringN,
|
||||
ValueRange{printfDesc, stringStart, stringLen,
|
||||
args.empty() ? oneI32 : zeroI32});
|
||||
printfDesc = appendFormatCall.getResult();
|
||||
|
||||
// __ockl_printf_append_args takes 7 values per append call
|
||||
constexpr size_t argsPerAppend = 7;
|
||||
size_t nArgs = args.size();
|
||||
for (size_t group = 0; group < nArgs; group += argsPerAppend) {
|
||||
size_t bound = std::min(group + argsPerAppend, nArgs);
|
||||
size_t numArgsThisCall = bound - group;
|
||||
|
||||
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
|
||||
arguments.push_back(printfDesc);
|
||||
arguments.push_back(
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
|
||||
for (size_t i = group; i < bound; ++i) {
|
||||
Value arg = args[i];
|
||||
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
|
||||
if (!floatType.isF64())
|
||||
arg = rewriter.create<LLVM::FPExtOp>(
|
||||
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
|
||||
}
|
||||
if (arg.getType().getIntOrFloatBitWidth() != 64)
|
||||
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
|
||||
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
// Pad out to 7 arguments since the hostcall always needs 7
|
||||
for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
|
||||
arguments.push_back(zeroI64);
|
||||
}
|
||||
|
||||
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
|
||||
arguments.push_back(isLast);
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
|
||||
printfDesc = call.getResult();
|
||||
}
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
protected:
|
||||
TritonGPUToLLVMTypeConverter *converter;
|
||||
const Allocation *allocation;
|
||||
|
||||
@@ -33,9 +33,9 @@ def test_assert(func: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
@@ -174,7 +174,8 @@ def get_amdgcn_bitcode_paths(arch):
|
||||
"oclc_daz_opt_off.bc",
|
||||
"oclc_correctly_rounded_sqrt_on.bc",
|
||||
"oclc_unsafe_math_off.bc",
|
||||
"oclc_wavefrontsize64_on.bc"]
|
||||
"oclc_wavefrontsize64_on.bc",
|
||||
"oclc_abi_version_400.bc",]
|
||||
|
||||
gfx_arch = arch[1]
|
||||
gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip()
|
||||
@@ -183,7 +184,7 @@ def get_amdgcn_bitcode_paths(arch):
|
||||
bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/")
|
||||
|
||||
amdgcn_bitcode_paths = {}
|
||||
i = 1
|
||||
i = 0
|
||||
for bc_lib in gpu_arch_agnostic_bitcode_libraries:
|
||||
bc_path = bitcode_path_dir + bc_lib
|
||||
if os.path.exists(bc_path):
|
||||
|
||||
Reference in New Issue
Block a user