mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: optimize concrete mul with constant values
This commit is contained in:
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Optimization.td)
|
||||
mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangConcreteOptimizationPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangConcreteOptimizationPassIncGen)
|
||||
@@ -0,0 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H
|
||||
#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::OperationPass<>> createConcreteOptimizationPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS
|
||||
#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConcreteOptimization : Pass<"concrete-optimization"> {
|
||||
let summary = "Optimize Concrete operations";
|
||||
let constructor = "mlir::concretelang::createConcreteOptimizationPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::Concrete::ConcreteDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -47,6 +47,10 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops);
|
||||
|
||||
mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
14
compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt
Normal file
14
compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_mlir_library(ConcreteDialectTransforms
|
||||
Optimization.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
mlir-headers
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
ConcreteDialect
|
||||
)
|
||||
112
compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp
Normal file
112
compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
|
||||
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace {
|
||||
|
||||
// Get the integer value that the cleartext was created from if it exists.
|
||||
llvm::Optional<mlir::Value>
|
||||
getIntegerFromCleartextIfExists(mlir::Value cleartext) {
|
||||
assert(
|
||||
cleartext.getType().isa<mlir::concretelang::Concrete::CleartextType>());
|
||||
// Cleartext are supposed to be created from integers
|
||||
auto intToCleartextOp = cleartext.getDefiningOp();
|
||||
if (intToCleartextOp == nullptr)
|
||||
return {};
|
||||
if (llvm::isa<Concrete::IntToCleartextOp>(intToCleartextOp)) {
|
||||
// We want to match when the integer value is constant
|
||||
return intToCleartextOp->getOperand(0);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// Get the constant integer that the cleartext was created from if it exists.
|
||||
llvm::Optional<IntegerAttr>
|
||||
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
|
||||
auto cleartextInt = getIntegerFromCleartextIfExists(cleartext);
|
||||
if (!cleartextInt.hasValue())
|
||||
return {};
|
||||
auto constantOp = cleartextInt.getValue().getDefiningOp();
|
||||
if (constantOp == nullptr)
|
||||
return {};
|
||||
if (llvm::isa<arith::ConstantOp>(constantOp)) {
|
||||
auto constIntToMul = constantOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
if (constIntToMul != nullptr)
|
||||
return constIntToMul;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a
|
||||
// `Concrete.zero` operation if it's being multiplied with a constant 0, or as a
|
||||
// `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1.
|
||||
class MulCleartextLweCiphertextOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp> {
|
||||
public:
|
||||
MulCleartextLweCiphertextOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::MulCleartextLweCiphertextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cleartext = op.getOperand(1);
|
||||
auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext);
|
||||
// Constant integer
|
||||
if (constIntToMul.hasValue()) {
|
||||
auto toMul = constIntToMul.getValue().getInt();
|
||||
if (toMul == 0) {
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::Concrete::ZeroLWEOp>(
|
||||
op, op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
if (toMul == -1) {
|
||||
rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::Concrete::NegateLweCiphertextOp>(
|
||||
op, op.getResult().getType(), op.getOperand(0));
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
// Optimization pass that should choose more efficient ways of performing crypto
|
||||
// operations.
|
||||
class ConcreteOptimizationPass
|
||||
: public ConcreteOptimizationBase<ConcreteOptimizationPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
mlir::Operation *op = getOperation();
|
||||
|
||||
mlir::RewritePatternSet patterns(op->getContext());
|
||||
patterns.add<MulCleartextLweCiphertextOpPattern>(op->getContext());
|
||||
|
||||
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createConcreteOptimizationPass() {
|
||||
return std::make_unique<ConcreteOptimizationPass>();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -29,6 +29,7 @@ add_mlir_library(ConcretelangSupport
|
||||
FHEDialectAnalysis
|
||||
RTDialectAnalysis
|
||||
ConcretelangTransforms
|
||||
ConcreteDialectTransforms
|
||||
concrete_optimizer
|
||||
|
||||
MLIRExecutionEngine
|
||||
|
||||
@@ -246,6 +246,14 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from TFHE to Concrete failed");
|
||||
}
|
||||
|
||||
// Optimizing Concrete
|
||||
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETE)
|
||||
return std::move(res);
|
||||
|
||||
@@ -280,6 +288,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
}
|
||||
|
||||
// Optimize Concrete
|
||||
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <concretelang/Conversion/Passes.h>
|
||||
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
|
||||
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
|
||||
@@ -186,6 +187,17 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteOptimization", pm, context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConcreteOptimizationPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
|
||||
33
compiler/tests/Dialect/Concrete/Concrete/optimization.mlir
Normal file
33
compiler/tests/Dialect/Concrete/Concrete/optimization.mlir
Normal file
@@ -0,0 +1,33 @@
|
||||
// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
|
||||
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.zero"() : () -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant 0 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant -1 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
Reference in New Issue
Block a user