diff --git a/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt b/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/Concrete/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt new file mode 100644 index 000000000..425267561 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Optimization.td) +mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangConcreteOptimizationPassIncGen) +add_dependencies(mlir-headers ConcretelangConcreteOptimizationPassIncGen) diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h b/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h new file mode 100644 index 000000000..d13407290 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.h @@ -0,0 +1,21 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H +#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { +std::unique_ptr> createConcreteOptimizationPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td b/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td new file mode 100644 index 000000000..42bc363e7 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Concrete/Transforms/Optimization.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_CONCRETE_OPTIMIZATION_PASS +#define CONCRETELANG_CONCRETE_OPTIMIZATION_PASS + +include "mlir/Pass/PassBase.td" + +def ConcreteOptimization : Pass<"concrete-optimization"> { + let summary = "Optimize Concrete operations"; + let constructor = "mlir::concretelang::createConcreteOptimizationPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::Concrete::ConcreteDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index e5e6cc891..bd58cebdf 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -47,6 +47,10 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, bool parallelizeLoops); +mlir::LogicalResult +optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Dialect/Concrete/CMakeLists.txt b/compiler/lib/Dialect/Concrete/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/lib/Dialect/Concrete/CMakeLists.txt +++ b/compiler/lib/Dialect/Concrete/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt new file mode 100644 index 000000000..311c2dacf --- /dev/null +++ b/compiler/lib/Dialect/Concrete/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_library(ConcreteDialectTransforms + Optimization.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Concrete + + DEPENDS + ConcreteDialect + mlir-headers + + LINK_LIBS PUBLIC + MLIRIR + ConcreteDialect +) diff --git a/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp b/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp new file mode 100644 index 000000000..afb024dc9 --- /dev/null +++ b/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp @@ -0,0 +1,112 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#include +#include +#include + +#include +#include +#include + +namespace mlir { +namespace concretelang { + +namespace { + +// Get the integer value that the cleartext was created from if it exists. +llvm::Optional +getIntegerFromCleartextIfExists(mlir::Value cleartext) { + assert( + cleartext.getType().isa()); + // Cleartext are supposed to be created from integers + auto intToCleartextOp = cleartext.getDefiningOp(); + if (intToCleartextOp == nullptr) + return {}; + if (llvm::isa(intToCleartextOp)) { + // We want to match when the integer value is constant + return intToCleartextOp->getOperand(0); + } + return {}; +} + +// Get the constant integer that the cleartext was created from if it exists. +llvm::Optional +getConstantIntFromCleartextIfExists(mlir::Value cleartext) { + auto cleartextInt = getIntegerFromCleartextIfExists(cleartext); + if (!cleartextInt.hasValue()) + return {}; + auto constantOp = cleartextInt.getValue().getDefiningOp(); + if (constantOp == nullptr) + return {}; + if (llvm::isa(constantOp)) { + auto constIntToMul = constantOp->getAttrOfType("value"); + if (constIntToMul != nullptr) + return constIntToMul; + } + return {}; +} + +// Rewrite a `Concrete.mul_cleartext_lwe_ciphertext` operation as a +// `Concrete.zero` operation if it's being multiplied with a constant 0, or as a +// `Concrete.negate_lwe_ciphertext` if multiplied with a constant -1. +class MulCleartextLweCiphertextOpPattern + : public mlir::OpRewritePattern< + mlir::concretelang::Concrete::MulCleartextLweCiphertextOp> { +public: + MulCleartextLweCiphertextOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern< + mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::MulCleartextLweCiphertextOp op, + mlir::PatternRewriter &rewriter) const override { + auto cleartext = op.getOperand(1); + auto constIntToMul = getConstantIntFromCleartextIfExists(cleartext); + // Constant integer + if (constIntToMul.hasValue()) { + auto toMul = constIntToMul.getValue().getInt(); + if (toMul == 0) { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType()); + return mlir::success(); + } + if (toMul == -1) { + rewriter.replaceOpWithNewOp< + mlir::concretelang::Concrete::NegateLweCiphertextOp>( + op, op.getResult().getType(), op.getOperand(0)); + return mlir::success(); + } + } + return mlir::failure(); + } +}; + +// Optimization pass that should choose more efficient ways of performing crypto +// operations. +class ConcreteOptimizationPass + : public ConcreteOptimizationBase { +public: + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(op->getContext()); + patterns.add(op->getContext()); + + if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { + this->signalPassFailure(); + } + } +}; + +} // end anonymous namespace + +std::unique_ptr> createConcreteOptimizationPass() { + return std::make_unique(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index c3f315af9..2d16787cd 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -29,6 +29,7 @@ add_mlir_library(ConcretelangSupport FHEDialectAnalysis RTDialectAnalysis ConcretelangTransforms + ConcreteDialectTransforms concrete_optimizer MLIRExecutionEngine diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index e399971db..4a632caed 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -246,6 +246,14 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { .failed()) { return errorDiag("Lowering from TFHE to Concrete failed"); } + + // Optimizing Concrete + if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module, + this->enablePass) + .failed()) { + return errorDiag("Optimizing Concrete failed"); + } + if (target == Target::CONCRETE) return std::move(res); @@ -280,6 +288,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { } } + // Optimize Concrete + if (mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module, + this->enablePass) + .failed()) { + return StreamStringError("Optimizing Concrete failed"); + } + // Concrete -> BConcrete if (mlir::concretelang::pipeline::lowerConcreteToBConcrete( mlirContext, module, this->enablePass, loopParallelize) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 6c63e09af..0e856e8ec 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -186,6 +187,17 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("ConcreteOptimization", pm, context); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConcreteOptimizationPass(), enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, diff --git a/compiler/tests/Dialect/Concrete/Concrete/optimization.mlir b/compiler/tests/Dialect/Concrete/Concrete/optimization.mlir new file mode 100644 index 000000000..14523d1d3 --- /dev/null +++ b/compiler/tests/Dialect/Concrete/Concrete/optimization.mlir @@ -0,0 +1,33 @@ +// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s + + +// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> + func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> + + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + return %1: !Concrete.lwe_ciphertext<2048,7> + } + +// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> +func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.zero"() : () -> !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> + + %0 = arith.constant 0 : i7 + %1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7> + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + return %2: !Concrete.lwe_ciphertext<2048,7> +} + +// CHECK-LABEL: func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> +func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> + + %0 = arith.constant -1 : i7 + %1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7> + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + return %2: !Concrete.lwe_ciphertext<2048,7> +}