ROCM IFU: Fix print and assert

This commit is contained in:
Jason Furmanek
2023-12-12 19:30:01 +00:00
parent 50a6db3afd
commit 160dfe838e
4 changed files with 37 additions and 30 deletions

View File

@@ -17,7 +17,13 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
assert(axis >= 0);
assert(axis < 3);
assert(moduleOp);
#ifdef USE_ROCM
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]);
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
#else
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantics of GetProgramIdOp. If numCTAs = 1, then
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
@@ -27,6 +33,7 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
return getSRegValue(rewriter, loc, sreg);
#endif
}
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
@@ -137,12 +144,14 @@ struct PrintOpConversion
if (op.getNumOperands() == 0) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
#ifdef USE_ROCM
os << "pid (" << getFormatSubstr(pid[0]) << ", "
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << op.getPrefix().str();
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
{pid[0], pid[1], pid[2]}, rewriter);
#else
os << "pid (" << getFormatSubstr(pid[0]) << ", "
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
#ifdef USE_ROCM
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
{pid[0], pid[1], pid[2], prefixStr}, rewriter);
#else
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
#endif
} else {
@@ -250,8 +259,12 @@ struct PrintOpConversion
}
os << ")";
#if USE_ROCM
os << op.getPrefix().str();
#else
os << "%s";
printfOperands.push_back(prefixStr);
#endif
if (numOperands > 1) {
os << "(operand " << operand << ") ";
@@ -266,7 +279,7 @@ struct PrintOpConversion
// printfOperands. But we don't want to create BLOCK_SIZE duplicate
// strings, so we cache the Value.
if (i == 0) {
#ifdef USE_ROCM
#if USE_ROCM
formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatStr,
printfOperands, rewriter);
#else
@@ -616,20 +629,14 @@ struct GetProgramIdOpConversion
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
Location loc = op->getLoc();
assert(op.getAxisAsInt() < 3);
Value blockId =
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
return success();
#else
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
op->getParentOfType<ModuleOp>(), rewriter);
#ifdef USE_ROCM
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
#else
rewriter.replaceOp(op, programId);
return success();
#endif
return success();
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,

View File

@@ -1330,7 +1330,7 @@ protected:
formatting.formatStrValue =
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
formatting.formatStrSize = msgNewline.size_in_bytes();
llPrintfHIP(loc, moduleOp, formatting, args, rewriter);
llPrintfHIP(loc, moduleOp, formatting, args, rewriter, stderr);
return formatting;
}
@@ -1342,7 +1342,7 @@ protected:
auto typeConverter = getTypeConverter();
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
@@ -1364,7 +1364,7 @@ protected:
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
LLVM::LLVMFunctionType::get(
llvmI64,
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
{llvmI64, {ptrType}, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
@@ -1405,12 +1405,11 @@ protected:
Value arg = args[i];
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
if (!floatType.isF64())
arg = rewriter.create<LLVM::FPExtOp>(
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
arg = fpext(typeConverter->convertType(rewriter.getF64Type()), arg);
arg = bitcast(arg, llvmI64);
}
if (arg.getType().getIntOrFloatBitWidth() != 64)
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
arg = zext(llvmI64, arg);
arguments.push_back(arg);
}
@@ -1421,7 +1420,7 @@ protected:
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
arguments.push_back(isLast);
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
auto call = call(ocklAppendArgs, arguments);
printfDesc = call.getResult();
}
}