mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
Rebase onto llvm-project f69328049e9e with local changes
This commit rebases the compiler onto commit f69328049e9e from llvm-project. Changes: * Use of the one-shot bufferizer for improved memory management * A new pass `OneShotBufferizeDPSWrapper` that converts functions returning tensors to destination-passing-style as required by the one-shot bufferizer * A new pass `LinalgGenericOpWithTensorsToLoopsPass` that converts `linalg.generic` operations with value semantics to loop nests * Rebase onto a fork of llvm-project at f69328049e9e with local modifications to enable bufferization of `linalg.generic` operations with value semantics * Workaround for the absence of type propagation after type conversion via extra patterns in all dialect conversion passes * Printer, parser and verifier definitions moved from inline declarations in ODS to the respective source files as required by upstream changes * New tests for functions with a large number of inputs * Increase the number of allowed task inputs as required by new tests * Use upstream function `mlir_configure_python_dev_packages()` to locate Python development files for compatibility with various CMake versions Co-authored-by: Quentin Bourgerie <quentin.bourgerie@zama.ai> Co-authored-by: Ayoub Benaissa <ayoub.benaissa@zama.ai> Co-authored-by: Antoniu Pop <antoniu.pop@zama.ai>
This commit is contained in:
91
compiler/lib/Transforms/ForLoopToParallel.cpp
Normal file
91
compiler/lib/Transforms/ForLoopToParallel.cpp
Normal file
@@ -0,0 +1,91 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Transforms/Bufferize.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
namespace {
|
||||
class ForOpPattern : public mlir::OpRewritePattern<mlir::scf::ForOp> {
|
||||
public:
|
||||
ForOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::scf::ForOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::scf::ForOp forOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto attr = forOp->getAttrOfType<mlir::BoolAttr>("parallel");
|
||||
if (attr == nullptr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
assert(forOp.getRegionIterArgs().size() == 0 &&
|
||||
"unexpecting iter args when loops are bufferized");
|
||||
if (attr.getValue()) {
|
||||
rewriter.replaceOpWithNewOp<mlir::scf::ParallelOp>(
|
||||
forOp, mlir::ValueRange{forOp.getLowerBound()},
|
||||
mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), llvm::None,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location location,
|
||||
mlir::ValueRange indVar, mlir::ValueRange iterArgs) {
|
||||
mlir::BlockAndValueMapping map;
|
||||
map.map(forOp.getInductionVar(), indVar.front());
|
||||
for (auto &op : forOp.getRegion().front()) {
|
||||
auto newOp = builder.clone(op, map);
|
||||
map.map(op.getResults(), newOp->getResults());
|
||||
}
|
||||
});
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<mlir::scf::ForOp>(
|
||||
forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
||||
llvm::None,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location location,
|
||||
mlir::Value indVar, mlir::ValueRange iterArgs) {
|
||||
mlir::BlockAndValueMapping map;
|
||||
map.map(forOp.getInductionVar(), indVar);
|
||||
for (auto &op : forOp.getRegion().front()) {
|
||||
auto newOp = builder.clone(op, map);
|
||||
map.map(op.getResults(), newOp->getResults());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct ForLoopToParallelPass
|
||||
: public ForLoopToParallelBase<ForLoopToParallelPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
mlir::ConversionTarget target(*context);
|
||||
patterns.add<ForOpPattern>(context);
|
||||
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp op) {
|
||||
auto r = op->getAttrOfType<mlir::BoolAttr>("parallel") == nullptr;
|
||||
return r;
|
||||
});
|
||||
target.markUnknownOpDynamicallyLegal(
|
||||
[&](mlir::Operation *op) { return true; });
|
||||
if (mlir::applyPatternsAndFoldGreedily(func, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
};
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
mlir::concretelang::createForLoopToParallel() {
|
||||
return std::make_unique<ForLoopToParallelPass>();
|
||||
}
|
||||
Reference in New Issue
Block a user