mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
feat(compiler): Integrate HLFHELinalg tiling passes into compilation pipeline
This commit is contained in:
@@ -148,12 +148,15 @@ public:
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
@@ -16,6 +16,15 @@ llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::ArrayRef<int64_t> tileSizes,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -84,6 +84,10 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
@@ -98,6 +98,9 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Conversion of function signature and arguments
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
|
||||
@@ -294,6 +294,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::tensor::GenerateOp, MidLFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::scf::ForOp, MidLFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
|
||||
@@ -58,14 +58,13 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp, mlir::scf::ForOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
@@ -83,6 +82,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -22,6 +22,7 @@ add_mlir_library(ZamalangSupport
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
HLFHELinalgDialect
|
||||
HLFHELinalgDialectTransforms
|
||||
HLFHETensorOpsToLinalg
|
||||
HLFHEToMidLFHE
|
||||
LowLFHEUnparametrize
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/SCF/SCF.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
@@ -53,6 +54,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
this->mlirContext->getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::scf::SCFDialect>();
|
||||
}
|
||||
|
||||
return this->mlirContext;
|
||||
@@ -94,6 +96,10 @@ void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
|
||||
this->clientParametersFuncName = name.str();
|
||||
}
|
||||
|
||||
void CompilerEngine::setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes) {
|
||||
this->hlfhelinalgTileSizes = sizes.vec();
|
||||
}
|
||||
|
||||
void CompilerEngine::setEnablePass(
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
this->enablePass = enablePass;
|
||||
@@ -192,6 +198,21 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
// HLFHE High level pass to determine FHE parameters
|
||||
if (auto err = this->determineFHEParameters(res))
|
||||
return std::move(err);
|
||||
|
||||
// HLFHELinalg tiling
|
||||
if (this->hlfhelinalgTileSizes) {
|
||||
if (mlir::zamalang::pipeline::markHLFHELinalgForTiling(
|
||||
mlirContext, module, *this->hlfhelinalgTileSizes, enablePass)
|
||||
.failed())
|
||||
return errorDiag("Marking of HLFHELinalg operations for tiling failed");
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::tileMarkedHLFHELinalg(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Tiling of HLFHELinalg operations failed");
|
||||
}
|
||||
|
||||
if (target == Target::HLFHE)
|
||||
return std::move(res);
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
|
||||
#include <mlir/Dialect/Linalg/Passes.h>
|
||||
#include <mlir/Dialect/SCF/Passes.h>
|
||||
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Tensor/Transforms/Passes.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
@@ -13,6 +14,7 @@
|
||||
|
||||
#include <zamalang/Conversion/Passes.h>
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h>
|
||||
#include <zamalang/Support/Pipeline.h>
|
||||
#include <zamalang/Support/logging.h>
|
||||
#include <zamalang/Support/math.h>
|
||||
@@ -100,6 +102,29 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return ret;
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("TileMarkedHLFHELinalg", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::zamalang::createHLFHELinalgTilingPass(),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::ArrayRef<int64_t> tileSizes,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("MarkHLFHELinalgForTiling", pm, context);
|
||||
addPotentiallyNestedPass(pm, createHLFHELinalgTilingMarkerPass(tileSizes),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
@@ -163,6 +188,7 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
|
||||
enablePass);
|
||||
|
||||
Reference in New Issue
Block a user