mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
enhance(compiler): Add custom finalize bufferize pass to handle memref.tensor_load op
This commit is contained in:
committed by
Quentin Bourgerie
parent
626493dda7
commit
8a9cce64e3
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Transforms)
|
||||
21
compiler/include/concretelang/Transforms/Bufferize.h
Normal file
21
compiler/include/concretelang/Transforms/Bufferize.h
Normal file
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BUFFERIZE_PASS_H
|
||||
#define CONCRETELANG_BUFFERIZE_PASS_H
|
||||
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Transforms/Bufferize.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::FunctionPass> createFinalizingBufferizePass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
14
compiler/include/concretelang/Transforms/Bufferize.td
Normal file
14
compiler/include/concretelang/Transforms/Bufferize.td
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef CONCRETELANG_FHELINALG_TILING_PASS
|
||||
#define CONCRETELANG_FHELINALG_TILING_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FinalizingBufferize : FunctionPass<"concretelang-bufferize"> {
|
||||
let summary =
|
||||
"Marks FHELinalg operations for tiling using a vector of tile sizes";
|
||||
let constructor = "mlir::concretelang::createBufferizePass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::memref::MemRefDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
3
compiler/include/concretelang/Transforms/CMakeLists.txt
Normal file
3
compiler/include/concretelang/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Bufferize.td)
|
||||
mlir_tablegen(Bufferize.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangTransformsPassIncGen)
|
||||
@@ -1,5 +1,6 @@
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(Runtime)
|
||||
add_subdirectory(ClientLib)
|
||||
|
||||
@@ -26,6 +26,7 @@ add_mlir_library(ConcretelangSupport
|
||||
MLIRLowerableDialectsToLLVM
|
||||
FHEDialectAnalysis
|
||||
RTDialectAnalysis
|
||||
ConcretelangTransforms
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include <concretelang/Support/Pipeline.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
#include <concretelang/Support/math.h>
|
||||
#include <concretelang/Transforms/Bufferize.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -234,8 +235,9 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFinalizingBufferizePass(), enablePass);
|
||||
|
||||
if (parallelizeLoops)
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertSCFToOpenMPPass(),
|
||||
enablePass);
|
||||
|
||||
73
compiler/lib/Transforms/Bufferize.cpp
Normal file
73
compiler/lib/Transforms/Bufferize.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
#include "concretelang/Transforms/Bufferize.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// In a finalizing bufferize conversion, we know that all tensors have been
|
||||
// converted to memrefs, thus, this op becomes an identity.
|
||||
class BufferizeTensorStoreOp
|
||||
: public OpConversionPattern<memref::TensorStoreOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::TensorStoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<memref::CopyOp>(op, op.tensor(), op.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void populatePatterns(BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
mlir::populateEliminateBufferizeMaterializationsPatterns(typeConverter,
|
||||
patterns);
|
||||
patterns.add<BufferizeTensorStoreOp>(typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct FinalizingBufferizePass
|
||||
: public FinalizingBufferizeBase<FinalizingBufferizePass> {
|
||||
using FinalizingBufferizeBase<
|
||||
FinalizingBufferizePass>::FinalizingBufferizeBase;
|
||||
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
|
||||
BufferizeTypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
populatePatterns(typeConverter, patterns);
|
||||
|
||||
// If all result types are legal, and all block arguments are legal (ensured
|
||||
// by func conversion above), then all types in the program are legal.
|
||||
//
|
||||
// We also check that the operand types are legal to avoid creating invalid
|
||||
// IR. For example, this prevents
|
||||
// populateEliminateBufferizeMaterializationsPatterns from updating the
|
||||
// types of the operands to a return op without updating the enclosing
|
||||
// function.
|
||||
target.markUnknownOpDynamicallyLegal(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
target.addLegalOp<memref::CopyOp>();
|
||||
|
||||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<FunctionPass>
|
||||
mlir::concretelang::createFinalizingBufferizePass() {
|
||||
return std::make_unique<FinalizingBufferizePass>();
|
||||
}
|
||||
16
compiler/lib/Transforms/CMakeLists.txt
Normal file
16
compiler/lib/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_library(ConcretelangTransforms
|
||||
Bufferize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRTransforms
|
||||
ConcretelangTransformsPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRMemRef
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(FHELinalgDialectTransforms PUBLIC MLIRIR)
|
||||
Reference in New Issue
Block a user