mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
fix(compiler): Add a pass to unparametrize LowLFHE ciphertext to remove unrelized_convesrion_cast for linalg bufferization
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEUnparametrizePass();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
|
||||
#include "zamalang/Conversion/LowLFHEUnparametrize/Pass.h"
|
||||
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h"
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h"
|
||||
|
||||
@@ -38,6 +38,12 @@ def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp">
|
||||
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LowLFHEUnparametrize : Pass<"lowlfhe-unparametrize", "mlir::ModuleOp"> {
|
||||
let summary = "Unparametrize LowLFHE types and remove unrealized_conversion_cast";
|
||||
let constructor = "mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::zamalang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
|
||||
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";
|
||||
|
||||
@@ -4,3 +4,4 @@ add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(LowLFHEToConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LowLFHEUnparametrize)
|
||||
|
||||
16
compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt
Normal file
16
compiler/lib/Conversion/LowLFHEUnparametrize/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(LowLFHEUnparametrize
|
||||
LowLFHEUnparametrize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
|
||||
|
||||
DEPENDS
|
||||
LowLFHEDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(LowLFHEUnparametrize PUBLIC MLIRIR)
|
||||
@@ -0,0 +1,129 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
|
||||
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
|
||||
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
/// LowLFHE types
|
||||
class LowLFHEUnparametrizeTypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
static mlir::Type unparematrizeLowLFHEType(mlir::Type type) {
|
||||
if (type.isa<mlir::zamalang::LowLFHE::PlaintextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::CleartextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
return mlir::zamalang::LowLFHE::LweCiphertextType::get(type.getContext(),
|
||||
-1, -1);
|
||||
}
|
||||
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensorType != nullptr) {
|
||||
auto eltTy0 = tensorType.getElementType();
|
||||
auto eltTy1 = unparematrizeLowLFHEType(eltTy0);
|
||||
if (eltTy0 == eltTy1) {
|
||||
return type;
|
||||
}
|
||||
return mlir::RankedTensorType::get(tensorType.getShape(), eltTy1);
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
LowLFHEUnparametrizeTypeConverter() {
|
||||
addConversion(
|
||||
[](mlir::Type type) { return unparematrizeLowLFHEType(type); });
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or
|
||||
/// t1 are a LowLFHE type.
|
||||
struct LowLFHEUnrealizedCastReplacementPattern
|
||||
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
|
||||
LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (mlir::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
|
||||
op.getOperandTypes()[0].getDialect()) ||
|
||||
mlir::isa<mlir::zamalang::LowLFHE::LowLFHEDialect>(
|
||||
op.getType(0).getDialect())) {
|
||||
rewriter.replaceOp(op, op.getOperands());
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
};
|
||||
};
|
||||
|
||||
/// LowLFHEUnparametrizePass remove all parameters of LowLFHE types and remove
|
||||
/// the unrealized_conversion_cast operation that operates on parametrized
|
||||
/// LowLFHE types.
|
||||
struct LowLFHEUnparametrizePass
|
||||
: public LowLFHEUnparametrizeBase<LowLFHEUnparametrizePass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
LowLFHEUnparametrizeTypeConverter converter;
|
||||
|
||||
// Conversion of linalg.generic operation
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
|
||||
[&](mlir::linalg::GenericOp op) {
|
||||
return (converter.isLegal(op.getOperandTypes()) &&
|
||||
converter.isLegal(op.getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
patterns.add<
|
||||
LinalgGenericTypeConverterPattern<LowLFHEUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Conversion of function signature and arguments
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Replacement of unrealized_conversion_cast
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::UnrealizedConversionCastOp>(
|
||||
target, converter);
|
||||
patterns.add<LowLFHEUnrealizedCastReplacementPattern>(patterns.getContext());
|
||||
|
||||
// Conversion of tensor operators
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
// Conversion of CallOp
|
||||
patterns.add<mlir::zamalang::GenericTypeConverterPattern<mlir::CallOp>>(
|
||||
patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::CallOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEUnparametrizePass() {
|
||||
return std::make_unique<LowLFHEUnparametrizePass>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -14,6 +14,7 @@ add_mlir_library(ZamalangSupport
|
||||
LINK_LIBS PUBLIC
|
||||
HLFHETensorOpsToLinalg
|
||||
HLFHEToMidLFHE
|
||||
LowLFHEUnparametrize
|
||||
MLIRLowerableDialectsToLLVM
|
||||
|
||||
MLIRExecutionEngine
|
||||
|
||||
@@ -89,6 +89,11 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
pm.enableVerifier();
|
||||
}
|
||||
|
||||
// Unparametrize LowLFHE
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(),
|
||||
options.enablePass);
|
||||
|
||||
// Bufferize
|
||||
addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(),
|
||||
options.enablePass);
|
||||
@@ -105,6 +110,7 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(),
|
||||
options.enablePass);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
options.enablePass);
|
||||
|
||||
Reference in New Issue
Block a user