mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-14 15:27:58 -05:00
This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:
* 5d8669d669ee: Fix the element alignment (size) for memrefCopy
* 4239163ea337: fix: Do not fold the memref.subview if the offset are
!= 0 and strides != 1
* 72c5decfcc21: remove github stuff from llvm
* 8d0ce8f9eca1: Support arbitrary element types in named operations
via attributes
* 94f64805c38c: Copy attributes of scf.for on bufferization and make
it an allocation hoisting barrier
Main upstream changes from llvm-project that required modification of
concretecompiler:
* Switch to C++17
* Various changes in the interfaces for linalg named operations
* Transition from `llvm::Optional` to `std::optional`
* Use of enums instead of string values for iterator types in linalg
* Changed default naming convention of getter methods in
ODS-generated operation classes from `some_value()` to
`getSomeValue()`
* Renaming of Arithmetic dialect to Arith
* Refactoring of side effect interfaces (i.e., renaming from
`NoSideEffect` to `Pure`)
* Re-design of the data flow analysis framework
* Refactoring of build targets for Python bindings
* Refactoring of array attributes with integer values
* Renaming of `linalg.init_tensor` to `tensor.empty`
* Emission of `linalg.map` operations in bufferization of the Tensor
dialect requiring another linalg conversion pass and registration
of the bufferization op interfaces for linalg operations
* Refactoring of the one-shot bufferizer
* Necessity to run the expand-strided-metadata, affine-to-std and
finalize-memref-to-llvm passes before conversion to the LLVM
dialect
* Renaming of `BlockAndValueMapping` to `IRMapping`
* Changes in the build function of `LLVM::CallOp`
* Refactoring of the construction of `llvm::ArrayRef` and
`llvm::MutableArrayRef` (direct invocation of constructor instead
of builder functions for some cases)
* New naming conventions for generated SSA values requiring rewrite
of some check tests
* Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
* Interface changes in generated type parsers
* New dependencies on mlir_float16_utils and
MLIRSparseTensorRuntime for the runtime
* Overhaul of MLIR-c deleting `mlir-c/Registration.h`
* Deletion of library MLIRLinalgToSPIRV
* Deletion of library MLIRLinalgAnalysis
* Deletion of library MLIRMemRefUtils
* Deletion of library MLIRQuantTransforms
* Deletion of library MLIRVectorToROCDL
93 lines
3.3 KiB
C++
93 lines
3.3 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
|
|
|
namespace {
|
|
class ForOpPattern : public mlir::OpRewritePattern<mlir::scf::ForOp> {
|
|
public:
|
|
ForOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
|
: ::mlir::OpRewritePattern<mlir::scf::ForOp>(context, benefit) {}
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(mlir::scf::ForOp forOp,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
auto attr = forOp->getAttrOfType<mlir::BoolAttr>("parallel");
|
|
if (attr == nullptr) {
|
|
return mlir::failure();
|
|
}
|
|
assert(forOp.getRegionIterArgs().size() == 0 &&
|
|
"unexpecting iter args when loops are bufferized");
|
|
if (attr.getValue()) {
|
|
rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
|
|
forOp, mlir::ValueRange{forOp.getLowerBound()},
|
|
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(),
|
|
std::nullopt,
|
|
[&](mlir::OpBuilder &builder, mlir::Location location,
|
|
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
|
|
mlir::IRMapping map;
|
|
map.map(forOp.getInductionVar(), indVar.front());
|
|
for (auto &op : forOp.getRegion().front()) {
|
|
auto newOp = builder.clone(op, map);
|
|
map.map(op.getResults(), newOp->getResults());
|
|
}
|
|
});
|
|
} else {
|
|
rewriter.replaceOpWithNewOp<mlir::scf::ForOp>(
|
|
forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
|
std::nullopt,
|
|
[&](mlir::OpBuilder &builder, mlir::Location location,
|
|
mlir::Value indVar, mlir::ValueRange iterArgs) {
|
|
mlir::IRMapping map;
|
|
map.map(forOp.getInductionVar(), indVar);
|
|
for (auto &op : forOp.getRegion().front()) {
|
|
auto newOp = builder.clone(op, map);
|
|
map.map(op.getResults(), newOp->getResults());
|
|
}
|
|
});
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
struct ForLoopToParallelPass
|
|
: public ForLoopToParallelBase<ForLoopToParallelPass> {
|
|
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto *context = &getContext();
|
|
mlir::RewritePatternSet patterns(context);
|
|
mlir::ConversionTarget target(*context);
|
|
patterns.add<ForOpPattern>(context);
|
|
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp op) {
|
|
auto r = op->getAttrOfType<mlir::BoolAttr>("parallel") == nullptr;
|
|
return r;
|
|
});
|
|
target.markUnknownOpDynamicallyLegal(
|
|
[&](mlir::Operation *op) { return true; });
|
|
if (mlir::applyPatternsAndFoldGreedily(func, std::move(patterns))
|
|
.failed()) {
|
|
this->signalPassFailure();
|
|
};
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
|
mlir::concretelang::createForLoopToParallel() {
|
|
return std::make_unique<ForLoopToParallelPass>();
|
|
}
|