mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Overwrite NVPTX converters for fp16<->fp32 and int16<->int32 to avoid ptxas problems (#1267)
This commit is contained in:
@@ -269,6 +269,17 @@ struct FpToFpOpConversion
|
||||
return builder.launch(rewriter, loc, f32_ty, false);
|
||||
}
|
||||
|
||||
// Widen an f16 value to f32 via a hand-written `cvt.f32.f16` instruction.
// Emitting the PTX ourselves bypasses the NVPTX backend lowering, which
// has shown ptxas problems for f16 (see commit message).
static Value convertFp16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
  PTXBuilder ptxBuilder;
  auto &cvtOp = *ptxBuilder.create("cvt.f32.f16");
  auto dstOpr = ptxBuilder.newOperand("=r");  // f32 result, 32-bit register
  auto srcOpr = ptxBuilder.newOperand(v, "h"); // f16 input, 16-bit register
  cvtOp(dstOpr, srcOpr);
  return ptxBuilder.launch(rewriter, loc, f32_ty, false);
}
|
||||
|
||||
static Value convertFp32ToBf16(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Value &v) {
|
||||
@@ -282,6 +293,17 @@ struct FpToFpOpConversion
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
}
|
||||
|
||||
// Narrow an f32 value to f16 via a hand-written `cvt.rn.f16.f32`
// (round-to-nearest-even). Emitted as inline PTX to bypass the flaky
// NVPTX lowering for f16 (see commit message).
static Value convertFp32ToFp16(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
  PTXBuilder ptxBuilder;
  auto &cvtOp = *ptxBuilder.create("cvt.rn.f16.f32");
  auto dstOpr = ptxBuilder.newOperand("=h");  // f16 result, 16-bit register
  auto srcOpr = ptxBuilder.newOperand(v, "r"); // f32 input, 32-bit register
  cvtOp(dstOpr, srcOpr);
  return ptxBuilder.launch(rewriter, loc, f16_ty, false);
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
@@ -336,6 +358,12 @@ struct FpToFpOpConversion
|
||||
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
|
||||
resultVals.emplace_back(
|
||||
convertFp32ToBf16(loc, rewriter, adaptor.getFrom()));
|
||||
} else if (srcEltType.isF16() && dstEltType.isF32()) {
|
||||
resultVals.emplace_back(
|
||||
convertFp16ToFp32(loc, rewriter, adaptor.getFrom()));
|
||||
} else if (srcEltType.isF32() && dstEltType.isF16()) {
|
||||
resultVals.emplace_back(
|
||||
convertFp32ToFp16(loc, rewriter, adaptor.getFrom()));
|
||||
} else {
|
||||
assert(false && "unsupported type casting");
|
||||
}
|
||||
@@ -860,3 +888,154 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
// __nv_expf for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
struct FPExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::FPExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isF32() && srcTy.isF16()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0]);
|
||||
}
|
||||
};
|
||||
|
||||
struct FPTruncOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::FPTruncOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isF16() && srcTy.isF32()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0]);
|
||||
}
|
||||
};
|
||||
|
||||
struct TruncOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::TruncOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(16) && srcTy.isInteger(32)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.u16.u32");
|
||||
auto res = builder.newOperand("=h");
|
||||
auto operand = builder.newOperand(operands[0], "r");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, i16_ty, false);
|
||||
}
|
||||
};
|
||||
|
||||
struct SExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::SExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.s32.s16");
|
||||
auto res = builder.newOperand("=r");
|
||||
auto operand = builder.newOperand(operands[0], "h");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, i32_ty, false);
|
||||
}
|
||||
};
|
||||
|
||||
struct ZExtOpConversion
|
||||
: ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
static bool isLegalOp(LLVM::ZExtOp op) {
|
||||
auto retTy = op.getResult().getType();
|
||||
auto srcTy = op.getOperand().getType();
|
||||
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
PTXBuilder builder;
|
||||
auto &cvt = *builder.create("cvt.u32.u16");
|
||||
auto res = builder.newOperand("=r");
|
||||
auto operand = builder.newOperand(operands[0], "h");
|
||||
cvt(res, operand);
|
||||
return builder.launch(rewriter, loc, i32_ty, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Returns true when `op` requires no custom PTX lowering (i.e. it is legal
// for the TritonPTXConversionTarget). Ops with a hand-written PTX conversion
// report illegal so the conversion framework rewrites them.
bool isLegalElementwiseOp(Operation *op) {
  // dyn_cast performs the type test and the cast in one step, avoiding the
  // redundant isa<> + cast<> double check the original chain performed.
  if (auto fpExt = dyn_cast<LLVM::FPExtOp>(op))
    return FPExtOpConversion::isLegalOp(fpExt);
  if (auto fpTrunc = dyn_cast<LLVM::FPTruncOp>(op))
    return FPTruncOpConversion::isLegalOp(fpTrunc);
  if (auto trunc = dyn_cast<LLVM::TruncOp>(op))
    return TruncOpConversion::isLegalOp(trunc);
  if (auto sext = dyn_cast<LLVM::SExtOp>(op))
    return SExtOpConversion::isLegalOp(sext);
  if (auto zext = dyn_cast<LLVM::ZExtOp>(op))
    return ZExtOpConversion::isLegalOp(zext);
  // Any other op is legal by default.
  return true;
}
|
||||
|
||||
// Registers the LLVM -> PTX overriding patterns defined above.
void populateElementwiseOpToPTXPatterns(mlir::LLVMTypeConverter &typeConverter,
                                        RewritePatternSet &patterns,
                                        PatternBenefit benefit) {
  // One variadic add<> call registers every pattern with the same benefit.
  patterns.add<FPExtOpConversion, FPTruncOpConversion, TruncOpConversion,
               SExtOpConversion, ZExtOpConversion>(typeConverter, benefit);
}
|
||||
|
||||
@@ -13,4 +13,10 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
const Allocation *allocation,
|
||||
Value smem, PatternBenefit benefit);
|
||||
|
||||
bool isLegalElementwiseOp(Operation *op);
|
||||
|
||||
void populateElementwiseOpToPTXPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -56,6 +56,16 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion target for the follow-up "LLVM -> PTX" fixup conversion: LLVM
// dialect ops are dynamically legal, so only the ops for which we provide a
// custom PTX lowering (see isLegalElementwiseOp) get rewritten.
class TritonPTXConversionTarget : public ConversionTarget {
public:
  explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
    // Legality of each LLVM op is decided per-op by isLegalElementwiseOp.
    addDynamicallyLegalDialect<LLVM::LLVMDialect>(
        [&](Operation *op) { return isLegalElementwiseOp(op); });
    // NVVM ops and unrealized casts are always left untouched.
    addLegalDialect<NVVM::NVVMDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
@@ -202,6 +212,19 @@ public:
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Use our custom converters to convert some operations to PTX to avoid
|
||||
// using NVPTX for two reasons:
|
||||
// 1. NVPTX backend is flaky on data types like float16 and bfloat16
|
||||
// 2. In some cases, we may generate faster PTX code than NVPTX backend
|
||||
TritonPTXConversionTarget ptxTarget(*context);
|
||||
RewritePatternSet ptxPatterns(context);
|
||||
// Add patterns to convert LLVM to PTX
|
||||
populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns,
|
||||
/*benefits=*/10);
|
||||
|
||||
if (failed(applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -1994,3 +1994,42 @@ def test_load_scalar_with_mask():
|
||||
Out = torch.empty_like(Index, device='cuda')
|
||||
kernel[(1,)](Input, Index, Out, Index.numel())
|
||||
assert Out.data[0] == 0
|
||||
|
||||
|
||||
# This test is used to test our own PTX codegen for float16 and int16 conversions
# maybe delete it later after ptxas has been fixed
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
def test_ptx_cast(dtype_str):
    # Regression test for the custom fp16<->fp32 / int16<->int32 PTX
    # converters: a running-max reduction whose loads/stores force
    # narrow->wide and wide->narrow casts.
    # NOTE(review): indentation reconstructed from a flattened diff view;
    # with RBLOCK (256) >= rnumel (197) the reduction loop runs exactly once,
    # so the placement of the final store (in vs. after the loop) does not
    # change the result — confirm against the upstream file.
    @triton.jit
    def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        # Running-max accumulator, initialized well below any input value.
        _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            # The .to(dtype) on the load exercises the narrow -> wide cast.
            tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
            tmp1 = 2
            tmp2 = tmp0 * tmp1
            tmp3 = tmp2.to(dtype)
            tmp5 = _tmp4 < tmp3
            _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
        # Store narrows back to the output buffer's dtype.
        tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)

    torch.manual_seed(123)
    # Map the parametrized string to the storage dtype (torch) and the wider
    # compute dtype (triton) used inside the kernel.
    if dtype_str == 'int16':
        torch_dtype = torch.int16
        triton_dtype = tl.int32
    else:
        torch_dtype = torch.float16
        triton_dtype = tl.float32

    s0 = 4
    buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
    buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
    # Every input element is -1, so each doubled value is -2 and the row max
    # is -2; hence the output mean must be exactly -2.0.
    kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
    assert buf14.to(torch.float32).mean() == -2.0
|
||||
|
||||
Reference in New Issue
Block a user