feat: optimize concrete mul with constant values

This commit is contained in:
youben11
2022-05-19 14:19:49 +01:00
committed by Ayoub Benaissa
parent c42f0a1ada
commit 8d0f20390c
12 changed files with 231 additions and 0 deletions

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Optimization.td)
mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangConcreteOptimizationPassIncGen)
add_dependencies(mlir-headers ConcretelangConcreteOptimizationPassIncGen)

View File

@@ -0,0 +1,21 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H
#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>> createConcreteOptimizationPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS
#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS
include "mlir/Pass/PassBase.td"
def ConcreteOptimization : Pass<"concrete-optimization"> {
let summary = "Optimize Concrete operations";
let constructor = "mlir::concretelang::createConcreteOptimizationPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::Concrete::ConcreteDialect" ];
}
#endif

View File

@@ -47,6 +47,10 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops);
mlir::LogicalResult
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,14 @@
add_mlir_library(ConcreteDialectTransforms
Optimization.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete
DEPENDS
ConcreteDialect
mlir-headers
LINK_LIBS PUBLIC
MLIRIR
ConcreteDialect
)

View File

@@ -0,0 +1,112 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
#include <concretelang/Support/Constants.h>
namespace mlir {
namespace concretelang {
namespace {
// Get the integer value that the cleartext was created from if it exists.
llvm::Optional<mlir::Value>
getIntegerFromCleartextIfExists(mlir::Value cleartext) {
assert(
cleartext.getType().isa<mlir::concretelang::Concrete::CleartextType>());
// Cleartext are supposed to be created from integers
auto intToCleartextOp = cleartext.getDefiningOp();
if (intToCleartextOp == nullptr)
return {};
if (llvm::isa<Concrete::IntToCleartextOp>(intToCleartextOp)) {
// We want to match when the integer value is constant
return intToCleartextOp->getOperand(0);
}
return {};
}
// Get the constant integer that the cleartext was created from if it exists.
llvm::Optional<IntegerAttr>
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
auto cleartextInt = getIntegerFromCleartextIfExists(cleartext);
if (!cleartextInt.hasValue())
return {};
auto constantOp = cleartextInt.getValue().getDefiningOp();
if (constantOp == nullptr)
return {};
if (llvm::isa<arith::ConstantOp>(constantOp)) {
auto constIntToMul = constantOp->getAttrOfType<mlir::IntegerAttr>("value");
if (constIntToMul != nullptr)
return constIntToMul;
}
return {};
}
// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a
// `Concrete.zero` operation if it's being multiplied with a constant 0, or as a
// `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1.
class MulCleartextLweCiphertextOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp> {
public:
MulCleartextLweCiphertextOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::MulCleartextLweCiphertextOp op,
mlir::PatternRewriter &rewriter) const override {
auto cleartext = op.getOperand(1);
auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext);
// Constant integer
if (constIntToMul.hasValue()) {
auto toMul = constIntToMul.getValue().getInt();
if (toMul == 0) {
rewriter.replaceOpWithNewOp<mlir::concretelang::Concrete::ZeroLWEOp>(
op, op.getResult().getType());
return mlir::success();
}
if (toMul == -1) {
rewriter.replaceOpWithNewOp<
mlir::concretelang::Concrete::NegateLweCiphertextOp>(
op, op.getResult().getType(), op.getOperand(0));
return mlir::success();
}
}
return mlir::failure();
}
};
// Optimization pass that should choose more efficient ways of performing crypto
// operations.
class ConcreteOptimizationPass
: public ConcreteOptimizationBase<ConcreteOptimizationPass> {
public:
void runOnOperation() override {
mlir::Operation *op = getOperation();
mlir::RewritePatternSet patterns(op->getContext());
patterns.add<MulCleartextLweCiphertextOpPattern>(op->getContext());
if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) {
this->signalPassFailure();
}
}
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<>> createConcreteOptimizationPass() {
return std::make_unique<ConcreteOptimizationPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -29,6 +29,7 @@ add_mlir_library(ConcretelangSupport
FHEDialectAnalysis
RTDialectAnalysis
ConcretelangTransforms
ConcreteDialectTransforms
concrete_optimizer
MLIRExecutionEngine

View File

@@ -246,6 +246,14 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
.failed()) {
return errorDiag("Lowering from TFHE to Concrete failed");
}
// Optimizing Concrete
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
this->enablePass)
.failed()) {
return errorDiag("Optimizing Concrete failed");
}
if (target == Target::CONCRETE)
return std::move(res);
@@ -280,6 +288,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
}
}
// Optimize Concrete
if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
this->enablePass)
.failed()) {
return StreamStringError("Optimizing Concrete failed");
}
// Concrete -> BConcrete
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
mlirContext, module, this->enablePass, loopParallelize)

View File

@@ -20,6 +20,7 @@
#include <mlir/Transforms/Passes.h>
#include <concretelang/Conversion/Passes.h>
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
@@ -186,6 +187,17 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteOptimization", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConcreteOptimizationPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,

View File

@@ -0,0 +1,33 @@
// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.zero"() : () -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%0 = arith.constant 0 : i7
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %2: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%0 = arith.constant -1 : i7
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %2: !Concrete.lwe_ciphertext<2048,7>
}