[BACKEND] Overwrite NVPTX converters for fp16<->fp32 and int16<->int32 to avoid ptxas problems (#1267)

This commit is contained in:
Keren Zhou
2023-03-01 18:26:06 -08:00
committed by GitHub
parent cb7b315a17
commit 90fcb38c7b
4 changed files with 247 additions and 0 deletions

View File

@@ -269,6 +269,17 @@ struct FpToFpOpConversion
return builder.launch(rewriter, loc, f32_ty, false);
}
// Converts one fp16 value to fp32 by emitting the PTX instruction
// `cvt.f32.f16` through PTXBuilder.
//   loc      - forwarded to PTXBuilder::launch
//   rewriter - forwarded to PTXBuilder::launch
//   v        - the value bound to the instruction's source operand
// Returns the value produced by launching the builder with result type f32_ty.
static Value convertFp16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
  PTXBuilder ptxBuilder;
  auto &cvtInstr = *ptxBuilder.create("cvt.f32.f16");
  // "=r" and "h" are the constraint strings handed to the builder for the
  // destination and the source operand respectively.
  auto dst = ptxBuilder.newOperand("=r");
  auto src = ptxBuilder.newOperand(v, "h");
  cvtInstr(dst, src);
  return ptxBuilder.launch(rewriter, loc, f32_ty, false);
}
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -282,6 +293,17 @@ struct FpToFpOpConversion
return builder.launch(rewriter, loc, i16_ty, false);
}
// Converts one fp32 value to fp16 by emitting the PTX instruction
// `cvt.rn.f16.f32` through PTXBuilder.
//   loc      - forwarded to PTXBuilder::launch
//   rewriter - forwarded to PTXBuilder::launch
//   v        - the value bound to the instruction's source operand
// Returns the value produced by launching the builder with result type f16_ty.
static Value convertFp32ToFp16(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v) {
  PTXBuilder ptxBuilder;
  auto &cvtInstr = *ptxBuilder.create("cvt.rn.f16.f32");
  // "=h" and "r" are the constraint strings handed to the builder for the
  // destination and the source operand respectively.
  auto dst = ptxBuilder.newOperand("=h");
  auto src = ptxBuilder.newOperand(v, "r");
  cvtInstr(dst, src);
  return ptxBuilder.launch(rewriter, loc, f16_ty, false);
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -336,6 +358,12 @@ struct FpToFpOpConversion
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
resultVals.emplace_back(
convertFp32ToBf16(loc, rewriter, adaptor.getFrom()));
} else if (srcEltType.isF16() && dstEltType.isF32()) {
resultVals.emplace_back(
convertFp16ToFp32(loc, rewriter, adaptor.getFrom()));
} else if (srcEltType.isF32() && dstEltType.isF16()) {
resultVals.emplace_back(
convertFp32ToFp16(loc, rewriter, adaptor.getFrom()));
} else {
assert(false && "unsupported type casting");
}
@@ -860,3 +888,154 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}
// Conversion pattern for LLVM::FPExtOp that replaces the f16 -> f32 case
// with the inline-PTX conversion FpToFpOpConversion::convertFp16ToFp32.
struct FPExtOpConversion
    : ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Returns false only when `op` extends f16 to f32; every other FPExtOp is
  // reported as legal.
  static bool isLegalOp(LLVM::FPExtOp op) {
    const auto dstTy = op.getResult().getType();
    const auto srcTy = op.getOperand().getType();
    const bool isF16ToF32 = dstTy.isF32() && srcTy.isF16();
    return !isF16ToF32;
  }

  // Builds the replacement value from the first operand. `op`, `adaptor` and
  // `elemTy` are part of the expected signature but are not consulted here.
  Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    const Value source = operands[0];
    return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, source);
  }
};
// Conversion pattern for LLVM::FPTruncOp that replaces the f32 -> f16 case
// with the inline-PTX conversion FpToFpOpConversion::convertFp32ToFp16.
struct FPTruncOpConversion
    : ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
  using Base =
      ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Returns false only when `op` truncates f32 to f16; every other FPTruncOp
  // is reported as legal.
  static bool isLegalOp(LLVM::FPTruncOp op) {
    const auto dstTy = op.getResult().getType();
    const auto srcTy = op.getOperand().getType();
    const bool isF32ToF16 = dstTy.isF16() && srcTy.isF32();
    return !isF32ToF16;
  }

  // Builds the replacement value from the first operand. `op`, `adaptor` and
  // `elemTy` are part of the expected signature but are not consulted here.
  Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    const Value source = operands[0];
    return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, source);
  }
};
// Conversion pattern for LLVM::TruncOp that replaces the i32 -> i16 case
// with the PTX instruction `cvt.u16.u32`.
struct TruncOpConversion
    : ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Returns false only when `op` truncates a 32-bit integer to a 16-bit
  // integer; every other TruncOp is reported as legal.
  static bool isLegalOp(LLVM::TruncOp op) {
    const auto dstTy = op.getResult().getType();
    const auto srcTy = op.getOperand().getType();
    const bool isI32ToI16 = dstTy.isInteger(16) && srcTy.isInteger(32);
    return !isI32ToI16;
  }

  // Emits `cvt.u16.u32` on the first operand and returns the i16_ty result.
  // `op`, `adaptor` and `elemTy` are not consulted here.
  Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.u16.u32");
    // Destination uses constraint "=h", source uses constraint "r".
    auto dst = ptxBuilder.newOperand("=h");
    auto src = ptxBuilder.newOperand(operands[0], "r");
    cvtInstr(dst, src);
    return ptxBuilder.launch(rewriter, loc, i16_ty, false);
  }
};
// Conversion pattern for LLVM::SExtOp that replaces the i16 -> i32 case
// with the PTX instruction `cvt.s32.s16`.
struct SExtOpConversion
    : ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Returns false only when `op` sign-extends a 16-bit integer to a 32-bit
  // integer; every other SExtOp is reported as legal.
  static bool isLegalOp(LLVM::SExtOp op) {
    const auto dstTy = op.getResult().getType();
    const auto srcTy = op.getOperand().getType();
    const bool isI16ToI32 = dstTy.isInteger(32) && srcTy.isInteger(16);
    return !isI16ToI32;
  }

  // Emits `cvt.s32.s16` on the first operand and returns the i32_ty result.
  // `op`, `adaptor` and `elemTy` are not consulted here.
  Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.s32.s16");
    // Destination uses constraint "=r", source uses constraint "h".
    auto dst = ptxBuilder.newOperand("=r");
    auto src = ptxBuilder.newOperand(operands[0], "h");
    cvtInstr(dst, src);
    return ptxBuilder.launch(rewriter, loc, i32_ty, false);
  }
};
// Conversion pattern for LLVM::ZExtOp that replaces the i16 -> i32 case
// with the PTX instruction `cvt.u32.u16`.
struct ZExtOpConversion
    : ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // Returns false only when `op` zero-extends a 16-bit integer to a 32-bit
  // integer; every other ZExtOp is reported as legal.
  static bool isLegalOp(LLVM::ZExtOp op) {
    const auto dstTy = op.getResult().getType();
    const auto srcTy = op.getOperand().getType();
    const bool isI16ToI32 = dstTy.isInteger(32) && srcTy.isInteger(16);
    return !isI16ToI32;
  }

  // Emits `cvt.u32.u16` on the first operand and returns the i32_ty result.
  // `op`, `adaptor` and `elemTy` are not consulted here.
  Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.u32.u16");
    // Destination uses constraint "=r", source uses constraint "h".
    auto dst = ptxBuilder.newOperand("=r");
    auto src = ptxBuilder.newOperand(operands[0], "h");
    cvtInstr(dst, src);
    return ptxBuilder.launch(rewriter, loc, i32_ty, false);
  }
};
// Legality predicate for the PTX conversion stage.
//
// Dispatches on the concrete type of `op` to the matching pattern's
// isLegalOp(): FPExtOp, FPTruncOp, TruncOp, SExtOp and ZExtOp are checked;
// any other operation is reported as legal.
//
// Returns false exactly when the matching pattern's isLegalOp() returns
// false, true otherwise.
bool isLegalElementwiseOp(Operation *op) {
  // A single dyn_cast<> per case replaces the isa<> test followed by cast<>,
  // so the type check runs once and the typed handle is reused directly.
  if (auto fpExt = dyn_cast<LLVM::FPExtOp>(op)) {
    return FPExtOpConversion::isLegalOp(fpExt);
  }
  if (auto fpTrunc = dyn_cast<LLVM::FPTruncOp>(op)) {
    return FPTruncOpConversion::isLegalOp(fpTrunc);
  }
  if (auto trunc = dyn_cast<LLVM::TruncOp>(op)) {
    return TruncOpConversion::isLegalOp(trunc);
  }
  if (auto sExt = dyn_cast<LLVM::SExtOp>(op)) {
    return SExtOpConversion::isLegalOp(sExt);
  }
  if (auto zExt = dyn_cast<LLVM::ZExtOp>(op)) {
    return ZExtOpConversion::isLegalOp(zExt);
  }
  return true;
}
// Registers the five cast-rewriting patterns (FPExt, FPTrunc, Trunc, SExt,
// ZExt) on `patterns`, each constructed with `typeConverter` and `benefit`.
void populateElementwiseOpToPTXPatterns(mlir::LLVMTypeConverter &typeConverter,
                                        RewritePatternSet &patterns,
                                        PatternBenefit benefit) {
  // One variadic add<> registers the patterns in the listed order with the
  // same constructor arguments.
  patterns.add<FPExtOpConversion, FPTruncOpConversion, TruncOpConversion,
               SExtOpConversion, ZExtOpConversion>(typeConverter, benefit);
}

View File

@@ -13,4 +13,10 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem, PatternBenefit benefit);
// Returns whether `op` may be left as-is by the PTX conversion stage.
// Used as the dynamic legality callback for the LLVM dialect.
bool isLegalElementwiseOp(Operation *op);
// Adds the patterns that rewrite selected LLVM cast operations into PTX,
// each constructed with `typeConverter` and registered with `benefit`.
void populateElementwiseOpToPTXPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);
#endif

View File

@@ -56,6 +56,16 @@ public:
}
};
// Conversion target for the LLVM -> PTX rewriting stage.
//
// LLVM-dialect operations are legal only when isLegalElementwiseOp() accepts
// them; the NVVM dialect and UnrealizedConversionCastOp are always legal.
class TritonPTXConversionTarget : public ConversionTarget {
public:
  explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
    // The callback uses no local state, so nothing is captured.
    auto isLegalLLVMOp = [](Operation *op) { return isLegalElementwiseOp(op); };
    addDynamicallyLegalDialect<LLVM::LLVMDialect>(isLegalLLVMOp);
    addLegalDialect<NVVM::NVVMDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};
} // namespace mlir
namespace {
@@ -202,6 +212,19 @@ public:
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// Use our custom converters to convert some operations to PTX to avoid
// using NVPTX for two reasons:
// 1. NVPTX backend is flaky on data types like float16 and bfloat16
// 2. In some cases, we may generate faster PTX code than NVPTX backend
TritonPTXConversionTarget ptxTarget(*context);
RewritePatternSet ptxPatterns(context);
// Add patterns to convert LLVM to PTX
populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns,
/*benefits=*/10);
if (failed(applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns))))
return signalPassFailure();
}
private:

View File

@@ -1994,3 +1994,42 @@ def test_load_scalar_with_mask():
Out = torch.empty_like(Index, device='cuda')
kernel[(1,)](Input, Index, Out, Index.numel())
assert Out.data[0] == 0
# Exercises Triton's own PTX code generation for float16 and int16
# conversions. It may be removed once the underlying ptxas problem is fixed.
# Parametrized over the 16-bit storage dtype of the input/output buffers.
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
def test_ptx_cast(dtype_str):
# Kernel: loads values, converts them to `dtype`, doubles them, and keeps
# the larger of the doubled value and an accumulator initialised to -10000.
@triton.jit
def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
# Accumulator starts at -10000 in the compute dtype.
_tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
# Masked load followed by a conversion to the compute dtype.
tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
tmp1 = 2
tmp2 = tmp0 * tmp1
tmp3 = tmp2.to(dtype)
tmp5 = _tmp4 < tmp3
# Replace the accumulator only where the mask holds and the new value is larger.
_tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
# NOTE(review): presumably the store converts back to the buffer's 16-bit
# element type -- confirm against Triton's store semantics.
tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
torch.manual_seed(123)
# Buffers use the 16-bit torch dtype; the kernel receives the 32-bit
# Triton dtype through its `dtype` argument.
if dtype_str == 'int16':
torch_dtype = torch.int16
triton_dtype = tl.int32
else:
torch_dtype = torch.float16
triton_dtype = tl.float32
s0 = 4
# Input and output are both filled with -1.
buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
# Grid size 4728 equals xnumel (1182 * s0) because XBLOCK is 1; RBLOCK is 256.
kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
# The mean is exactly -2.0 only if every output element is the doubled input.
assert buf14.to(torch.float32).mean() == -2.0