mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Fix print and assert
This commit is contained in:
@@ -17,7 +17,13 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
|
||||
assert(axis >= 0);
|
||||
assert(axis < 3);
|
||||
assert(moduleOp);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]);
|
||||
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
|
||||
#else
|
||||
// It is not easy to get the compute capability here, so we use numCTAs to
|
||||
// decide the semantic of GetProgramIdOp. If numCTAs = 1, then
|
||||
// GetProgramIdOp is converted to "%ctaid", otherwise it is converted to
|
||||
@@ -27,6 +33,7 @@ Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
|
||||
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
|
||||
sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
|
||||
return getSRegValue(rewriter, loc, sreg);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
@@ -137,12 +144,14 @@ struct PrintOpConversion
|
||||
if (op.getNumOperands() == 0) {
|
||||
std::string formatStr;
|
||||
llvm::raw_string_ostream os(formatStr);
|
||||
#ifdef USE_ROCM
|
||||
os << "pid (" << getFormatSubstr(pid[0]) << ", "
|
||||
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << op.getPrefix().str();
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
{pid[0], pid[1], pid[2]}, rewriter);
|
||||
#else
|
||||
os << "pid (" << getFormatSubstr(pid[0]) << ", "
|
||||
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
|
||||
#ifdef USE_ROCM
|
||||
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
{pid[0], pid[1], pid[2], prefixStr}, rewriter);
|
||||
#else
|
||||
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
|
||||
#endif
|
||||
} else {
|
||||
@@ -250,8 +259,12 @@ struct PrintOpConversion
|
||||
}
|
||||
os << ")";
|
||||
|
||||
#if USE_ROCM
|
||||
os << op.getPrefix().str();
|
||||
#else
|
||||
os << "%s";
|
||||
printfOperands.push_back(prefixStr);
|
||||
#endif
|
||||
|
||||
if (numOperands > 1) {
|
||||
os << "(operand " << operand << ") ";
|
||||
@@ -266,7 +279,7 @@ struct PrintOpConversion
|
||||
// printfOperands. But we don't want to create BLOCK_SIZE duplicate
|
||||
// strings, so we cache the Value.
|
||||
if (i == 0) {
|
||||
#ifdef USE_ROCM
|
||||
#if USE_ROCM
|
||||
formatting = llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatStr,
|
||||
printfOperands, rewriter);
|
||||
#else
|
||||
@@ -616,20 +629,14 @@ struct GetProgramIdOpConversion
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
Location loc = op->getLoc();
|
||||
assert(op.getAxisAsInt() < 3);
|
||||
|
||||
Value blockId =
|
||||
rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]);
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
|
||||
return success();
|
||||
#else
|
||||
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
|
||||
op->getParentOfType<ModuleOp>(), rewriter);
|
||||
#ifdef USE_ROCM
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
|
||||
#else
|
||||
rewriter.replaceOp(op, programId);
|
||||
return success();
|
||||
#endif
|
||||
return success();
|
||||
}
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
|
||||
@@ -1330,7 +1330,7 @@ protected:
|
||||
formatting.formatStrValue =
|
||||
LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
|
||||
formatting.formatStrSize = msgNewline.size_in_bytes();
|
||||
llPrintfHIP(loc, moduleOp, formatting, args, rewriter);
|
||||
llPrintfHIP(loc, moduleOp, formatting, args, rewriter, stderr);
|
||||
return formatting;
|
||||
}
|
||||
|
||||
@@ -1342,7 +1342,7 @@ protected:
|
||||
|
||||
auto typeConverter = getTypeConverter();
|
||||
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
|
||||
mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
|
||||
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
|
||||
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
|
||||
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
|
||||
|
||||
@@ -1364,7 +1364,7 @@ protected:
|
||||
moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
|
||||
LLVM::LLVMFunctionType::get(
|
||||
llvmI64,
|
||||
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
{llvmI64, {ptrType}, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
|
||||
|
||||
/// Start the printf hostcall
|
||||
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
|
||||
@@ -1405,12 +1405,11 @@ protected:
|
||||
Value arg = args[i];
|
||||
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
|
||||
if (!floatType.isF64())
|
||||
arg = rewriter.create<LLVM::FPExtOp>(
|
||||
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
|
||||
arg = fpext(typeConverter->convertType(rewriter.getF64Type()), arg);
|
||||
arg = bitcast(arg, llvmI64);
|
||||
}
|
||||
if (arg.getType().getIntOrFloatBitWidth() != 64)
|
||||
arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
|
||||
arg = zext(llvmI64, arg);
|
||||
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
@@ -1421,7 +1420,7 @@ protected:
|
||||
|
||||
auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
|
||||
arguments.push_back(isLast);
|
||||
auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
|
||||
auto call = call(ocklAppendArgs, arguments);
|
||||
printfDesc = call.getResult();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user