Files
concrete/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp
Andi Drebes c8c969773e Rebase onto llvm-project 465ee9bfb26d with local changes
This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:

  * 5d8669d669ee: Fix the element alignment (size) for memrefCopy
  * 4239163ea337: fix: Do not fold the memref.subview if the offset are
                  != 0 and strides != 1
  * 72c5decfcc21: remove github stuff from llvm
  * 8d0ce8f9eca1: Support arbitrary element types in named operations
                  via attributes
  * 94f64805c38c: Copy attributes of scf.for on bufferization and make
                  it an allocation hoisting barrier

Main upstream changes from llvm-project that required modification of
concretecompiler:

  * Switch to C++17
  * Various changes in the interfaces for linalg named operations
  * Transition from `llvm::Optional` to `std::optional`
  * Use of enums instead of string values for iterator types in linalg
  * Changed default naming convention of getter methods in
    ODS-generated operation classes from `some_value()` to
    `getSomeValue()`
  * Renaming of Arithmetic dialect to Arith
  * Refactoring of side effect interfaces (i.e., renaming from
    `NoSideEffect` to `Pure`)
  * Re-design of the data flow analysis framework
  * Refactoring of build targets for Python bindings
  * Refactoring of array attributes with integer values
  * Renaming of `linalg.init_tensor` to `tensor.empty`
  * Emission of `linalg.map` operations in bufferization of the Tensor
    dialect requiring another linalg conversion pass and registration
    of the bufferization op interfaces for linalg operations
  * Refactoring of the one-shot bufferizer
  * Necessity to run the expand-strided-metadata, affine-to-std and
    finalize-memref-to-llvm passes before conversion to the LLVM
    dialect
  * Renaming of `BlockAndValueMapping` to `IRMapping`
  * Changes in the build function of `LLVM::CallOp`
  * Refactoring of the construction of `llvm::ArrayRef` and
    `llvm::MutableArrayRef` (direct invocation of constructor instead
    of builder functions for some cases)
  * New naming conventions for generated SSA values requiring rewrite
    of some check tests
  * Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
  * Interface changes in generated type parsers
  * New dependencies on mlir_float16_utils and
    MLIRSparseTensorRuntime for the runtime
  * Overhaul of MLIR-c deleting `mlir-c/Registration.h`
  * Deletion of library MLIRLinalgToSPIRV
  * Deletion of library MLIRLinalgAnalysis
  * Deletion of library MLIRMemRefUtils
  * Deletion of library MLIRQuantTransforms
  * Deletion of library MLIRVectorToROCDL
2023-03-09 17:47:16 +01:00

159 lines
6.7 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/IR/RTTypes.h>
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <concretelang/Conversion/Utils/FuncConstOpConversion.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <concretelang/Conversion/Utils/Legality.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Bufferization/Transforms/Bufferize.h>
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/Transforms/RegionUtils.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace concretelang {
namespace {
/// Type converter used by the bufferization of RT dataflow task operations.
///
/// In addition to the conversions inherited from
/// `mlir::bufferization::BufferizeTypeConverter`, it registers conversions
/// for `RT::FutureType`, `RT::PointerType` and `mlir::FunctionType`, each of
/// which rebuilds the type around its converted nested type(s).
class BufferizeRTTypesConverter
    : public mlir::bufferization::BufferizeTypeConverter {
public:
  BufferizeRTTypesConverter() {
    // future<T> -> future<convert(T)>
    addConversion([this](mlir::concretelang::RT::FutureType futureTy) {
      mlir::Type convertedElementTy =
          this->convertType(futureTy.getElementType());
      return mlir::concretelang::RT::FutureType::get(convertedElementTy);
    });

    // pointer<T> -> pointer<convert(T)>
    addConversion([this](mlir::concretelang::RT::PointerType pointerTy) {
      mlir::Type convertedElementTy =
          this->convertType(pointerTy.getElementType());
      return mlir::concretelang::RT::PointerType::get(convertedElementTy);
    });

    // Function types: convert every input and every result. If either
    // conversion fails, the original function type is returned unchanged.
    addConversion([this](mlir::FunctionType funcTy) {
      SignatureConversion convertedInputs(funcTy.getNumInputs());
      if (failed(
              this->convertSignatureArgs(funcTy.getInputs(), convertedInputs)))
        return funcTy;

      mlir::SmallVector<mlir::Type, 1> convertedResults;
      if (failed(this->convertTypes(funcTy.getResults(), convertedResults)))
        return funcTy;

      return mlir::FunctionType::get(funcTy.getContext(),
                                     convertedInputs.getConvertedTypes(),
                                     convertedResults);
    });
  }
};
} // namespace
namespace {
/// Pass applying a partial conversion that rewrites the types of `func`
/// dialect operations and of RT dialect operations using
/// `BufferizeRTTypesConverter`.
///
/// For documentation see Autopar.td
struct BufferizeDataflowTaskOpsPass
    : public BufferizeDataflowTaskOpsBase<BufferizeDataflowTaskOpsPass> {
  /// Runs the type conversion on the operation this pass is scheduled on and
  /// signals a pass failure if the partial conversion fails.
  void runOnOperation() override {
    auto module = getOperation();
    auto *context = &getContext();

    BufferizeRTTypesConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    // Conversion of functions and function constants
    populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
        patterns, typeConverter);
    patterns.add<FunctionConstantOpConversion<BufferizeRTTypesConverter>>(
        context, typeConverter);

    // `op` is never null here (the fall-through case dereferences it
    // unconditionally), so plain `dyn_cast` is sufficient.
    target.addDynamicallyLegalDialect<mlir::func::FuncDialect>(
        [&](Operation *op) {
          // A function is legal if both its signature and its body are.
          if (auto fun = dyn_cast<mlir::func::FuncOp>(op))
            return typeConverter.isSignatureLegal(fun.getFunctionType()) &&
                   typeConverter.isLegal(&fun.getBody());

          if (auto fun = dyn_cast<mlir::func::ConstantOp>(op))
            return FunctionConstantOpConversion<
                BufferizeRTTypesConverter>::isLegal(fun, typeConverter);

          return typeConverter.isLegal(op);
        });

    // Conversion of RT Dialect Ops. The list is given once so that the
    // rewrite patterns and the legality rules cannot get out of sync.
    addRTOpTypeConversions<
        mlir::concretelang::RT::DataflowTaskOp,
        mlir::concretelang::RT::DataflowYieldOp,
        mlir::concretelang::RT::MakeReadyFutureOp,
        mlir::concretelang::RT::AwaitFutureOp,
        mlir::concretelang::RT::CreateAsyncTaskOp,
        mlir::concretelang::RT::BuildReturnPtrPlaceholderOp,
        mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp,
        mlir::concretelang::RT::DerefReturnPtrPlaceholderOp,
        mlir::concretelang::RT::WorkFunctionReturnOp,
        mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(
        context, typeConverter, patterns, target);

    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();
  }

  /// \param debug Value stored in the `debug` member; no code in this struct
  ///        reads it.
  explicit BufferizeDataflowTaskOpsPass(bool debug) : debug(debug) {}

protected:
  bool debug;

private:
  /// For every operation type in `Ops`, in order, adds a
  /// `GenericTypeConverterPattern` to `patterns` and then registers the
  /// operation with `target` through `addDynamicallyLegalTypeOp`.
  template <typename... Ops>
  static void addRTOpTypeConversions(mlir::MLIRContext *context,
                                     BufferizeRTTypesConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     ConversionTarget &target) {
    patterns.add<mlir::concretelang::GenericTypeConverterPattern<Ops>...>(
        context, typeConverter);
    (mlir::concretelang::addDynamicallyLegalTypeOp<Ops>(target, typeConverter),
     ...);
  }
};
} // end anonymous namespace
/// Creates a new `BufferizeDataflowTaskOpsPass`, forwarding `debug` to its
/// constructor.
std::unique_ptr<mlir::Pass> createBufferizeDataflowTaskOpsPass(bool debug) {
  std::unique_ptr<mlir::Pass> pass =
      std::make_unique<BufferizeDataflowTaskOpsPass>(debug);
  return pass;
}
} // namespace concretelang
} // namespace mlir