mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Merge TT_ElementwisePureExtern and TT_ElementwiseImpureExtern (#2137)
Use `getEffects` instead to tell passes whether the op has side effects. This doesn't otherwise change functionality.
This commit is contained in:
@@ -463,33 +463,23 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
|
||||
//
|
||||
// External Elementwise op
|
||||
//
|
||||
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
|
||||
TT_Op<mnemonic,
|
||||
traits # [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
|
||||
let description = [{
|
||||
call an external function $symbol implemented in $libpath/$libname with $args
|
||||
return $libpath/$libname:$symbol($args...)
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
||||
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
|
||||
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
//
|
||||
// Make Range Op
|
||||
//
|
||||
|
||||
@@ -762,15 +762,16 @@ struct CmpFOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct ExternElementwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>> {
|
||||
using Base = ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>>;
|
||||
: public ElementwiseOpConversionBase<ExternElementwiseOp,
|
||||
ExternElementwiseOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<ExternElementwiseOp,
|
||||
ExternElementwiseOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
typedef typename Base::OpAdaptor OpAdaptor;
|
||||
|
||||
SmallVector<Value> createDestOps(T op, OpAdaptor adaptor,
|
||||
SmallVector<Value> createDestOps(ExternElementwiseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, MultipleOperandsRange operands,
|
||||
Location loc) const {
|
||||
@@ -791,8 +792,9 @@ private:
|
||||
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
|
||||
}
|
||||
|
||||
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op,
|
||||
StringRef funcName, Type funcType) const {
|
||||
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
|
||||
ExternElementwiseOp op, StringRef funcName,
|
||||
Type funcType) const {
|
||||
using LLVM::LLVMFuncOp;
|
||||
|
||||
auto funcAttr = StringAttr::get(op->getContext(), funcName);
|
||||
@@ -1288,11 +1290,7 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
|
||||
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<ExternElementwiseOpConversion<triton::PureExternElementwiseOp>>(
|
||||
typeConverter, benefit);
|
||||
patterns
|
||||
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
|
||||
typeConverter, benefit);
|
||||
patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is
|
||||
// FP32. For other input types, ExpOpConversionApprox will return failure and
|
||||
|
||||
@@ -498,24 +498,6 @@ struct TritonAtomicRMWPattern
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern for an extern elementwise op of type T.
// matchAndRewrite replaces the matched op with a new T whose result type is
// obtained from the type converter, passing through the adaptor's args,
// libname, libpath and symbol unchanged.
template <class T>
|
||||
struct TritonExternElementwisePattern : public OpConversionPattern<T> {
|
||||
using OpConversionPattern<T>::OpConversionPattern;
|
||||
using OpConversionPattern<T>::typeConverter;
|
||||
typedef typename OpConversionPattern<T>::OpAdaptor OpAdaptor;
|
||||
|
||||
// Always reports success; there is no failure path in this rewrite.
LogicalResult
|
||||
matchAndRewrite(T op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// The newly created op and the adaptor's attributes are handed to
// addNamedAttrs (defined outside this view; presumably it copies the
// attributes onto the new op -- TODO confirm).
addNamedAttrs(rewriter.replaceOpWithNewOp<T>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getArgs(), adaptor.getLibname(),
|
||||
adaptor.getLibpath(), adaptor.getSymbol()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
@@ -710,11 +692,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern,
|
||||
TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern,
|
||||
TritonDotPattern, TritonLoadPattern, TritonStorePattern,
|
||||
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
|
||||
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
|
||||
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
|
||||
typeConverter, context);
|
||||
TritonGenericPattern<triton::ExternElementwiseOp>, TritonPrintPattern,
|
||||
TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern,
|
||||
TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -882,5 +882,17 @@ void ElementwiseInlineAsmOp::getEffects(
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
// -- ExternElementwiseOp --
|
||||
// Appends this op's memory effects to `effects`.
// When getPure() is true nothing is appended; otherwise one Write and one
// Read effect, both on the default side-effect resource, are appended.
void ExternElementwiseOp::getEffects(
|
||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||
&effects) {
|
||||
// Pure case: leave `effects` untouched.
if (getPure())
|
||||
return;
|
||||
// Non-pure case: record both a write and a read on the default resource.
effects.emplace_back(MemoryEffects::Write::get(),
|
||||
SideEffects::DefaultResource::get());
|
||||
effects.emplace_back(MemoryEffects::Read::get(),
|
||||
SideEffects::DefaultResource::get());
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -644,9 +644,10 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
math::RsqrtOp, math::SqrtOp, math::TanhOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
|
||||
triton::FpToFpOp, triton::AddPtrOp,
|
||||
triton::PureExternElementwiseOp>(op))
|
||||
triton::FpToFpOp, triton::AddPtrOp>(op))
|
||||
return true;
|
||||
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
|
||||
return externElementwiseOp.getPure();
|
||||
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
|
||||
@@ -1424,12 +1424,8 @@ void init_triton_ir(py::module &&m) {
|
||||
const std::string &libPath, const std::string &symbol,
|
||||
std::vector<mlir::Value> &argList, mlir::Type retType,
|
||||
bool isPure) -> mlir::Value {
|
||||
if (isPure)
|
||||
return self.create<mlir::triton::PureExternElementwiseOp>(
|
||||
retType, argList, libName, libPath, symbol);
|
||||
else
|
||||
return self.create<mlir::triton::ImpureExternElementwiseOp>(
|
||||
retType, argList, libName, libPath, symbol);
|
||||
return self.create<mlir::triton::ExternElementwiseOp>(
|
||||
retType, argList, libName, libPath, symbol, isPure);
|
||||
})
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id",
|
||||
|
||||
Reference in New Issue
Block a user