[BACKEND] Merge TT_ElementwisePureExtern and TT_ElementwiseImpureExtern (#2137)

Use getEffects (from MemoryEffectsOpInterface) instead to tell passes whether
the op has side effects. This doesn't otherwise change functionality.
This commit is contained in:
Thomas
2023-08-18 13:56:10 -07:00
committed by GitHub
parent bf351b9ba2
commit 23ef2615d2
6 changed files with 34 additions and 57 deletions

View File

@@ -463,33 +463,23 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
//
// External Elementwise op
//
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
TT_Op<mnemonic,
traits # [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
}
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
}
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>]> {
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
}
//
// Make Range Op
//

View File

@@ -762,15 +762,16 @@ struct CmpFOpConversion
}
};
template <class T>
struct ExternElementwiseOpConversion
: public ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>> {
using Base = ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>>;
: public ElementwiseOpConversionBase<ExternElementwiseOp,
ExternElementwiseOpConversion> {
using Base = ElementwiseOpConversionBase<ExternElementwiseOp,
ExternElementwiseOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
typedef typename Base::OpAdaptor OpAdaptor;
SmallVector<Value> createDestOps(T op, OpAdaptor adaptor,
SmallVector<Value> createDestOps(ExternElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
@@ -791,8 +792,9 @@ private:
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op,
StringRef funcName, Type funcType) const {
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
ExternElementwiseOp op, StringRef funcName,
Type funcType) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
@@ -1288,11 +1290,7 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<ExternElementwiseOpConversion<triton::PureExternElementwiseOp>>(
typeConverter, benefit);
patterns
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
typeConverter, benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and

View File

@@ -498,24 +498,6 @@ struct TritonAtomicRMWPattern
}
};
// Conversion pattern for an extern-elementwise op of type T.
//
// The matched op is replaced by a freshly created op of the same type T whose
// result type is obtained by running the original result type through the
// pattern's type converter. The operands and the libname / libpath / symbol
// values are taken from the adaptor.
template <class T>
struct TritonExternElementwisePattern : public OpConversionPattern<T> {
using OpConversionPattern<T>::OpConversionPattern;
using OpConversionPattern<T>::typeConverter;
typedef typename OpConversionPattern<T>::OpAdaptor OpAdaptor;
// Always succeeds: rebuilds `op` as a new T with the converted result type.
LogicalResult
matchAndRewrite(T op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// NOTE(review): addNamedAttrs is defined outside this view; presumably it
// attaches the adaptor's attributes to the newly created op — confirm.
addNamedAttrs(rewriter.replaceOpWithNewOp<T>(
op, typeConverter->convertType(op.getType()),
adaptor.getArgs(), adaptor.getLibname(),
adaptor.getLibpath(), adaptor.getSymbol()),
adaptor.getAttributes());
return success();
}
};
template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
@@ -710,11 +692,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern,
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
typeConverter, context);
TritonGenericPattern<triton::ExternElementwiseOp>, TritonPrintPattern,
TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern,
TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context);
}
//

View File

@@ -882,5 +882,17 @@ void ElementwiseInlineAsmOp::getEffects(
SideEffects::DefaultResource::get());
}
// -- ExternElementwiseOp --
// Reports the memory effects of this op to the MemoryEffectsOpInterface.
// When the `pure` attribute is set, no effects are appended; otherwise a
// write effect followed by a read effect on the default resource is appended.
void ExternElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!getPure()) {
    auto *defaultResource = SideEffects::DefaultResource::get();
    // Keep the Write-then-Read order of the appended effects.
    effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
    effects.emplace_back(MemoryEffects::Read::get(), defaultResource);
  }
}
} // namespace triton
} // namespace mlir

View File

@@ -644,9 +644,10 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
math::RsqrtOp, math::SqrtOp, math::TanhOp>(op))
return true;
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
triton::FpToFpOp, triton::AddPtrOp,
triton::PureExternElementwiseOp>(op))
triton::FpToFpOp, triton::AddPtrOp>(op))
return true;
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
return externElementwiseOp.getPure();
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
return true;
return false;

View File

@@ -1424,12 +1424,8 @@ void init_triton_ir(py::module &&m) {
const std::string &libPath, const std::string &symbol,
std::vector<mlir::Value> &argList, mlir::Type retType,
bool isPure) -> mlir::Value {
if (isPure)
return self.create<mlir::triton::PureExternElementwiseOp>(
retType, argList, libName, libPath, symbol);
else
return self.create<mlir::triton::ImpureExternElementwiseOp>(
retType, argList, libName, libPath, symbol);
return self.create<mlir::triton::ExternElementwiseOp>(
retType, argList, libName, libPath, symbol, isPure);
})
// Built-in instruction
.def("create_get_program_id",