From 1e726a50edab8cf7df737d7658f6a7aacfe70648 Mon Sep 17 00:00:00 2001 From: aPere3 Date: Tue, 29 Aug 2023 17:32:47 +0200 Subject: [PATCH] feat(compiler): add support for dynamic luts in compiler. --- .../Dialect/FHE/Transforms/CMakeLists.txt | 1 + .../FHE/Transforms/DynamicTLU/CMakeLists.txt | 4 + .../FHE/Transforms/DynamicTLU/DynamicTLU.h | 23 ++++ .../FHE/Transforms/DynamicTLU/DynamicTLU.td | 11 ++ .../compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 8 +- .../lib/Dialect/FHE/Transforms/CMakeLists.txt | 1 + .../lib/Dialect/FHE/Transforms/DynamicTLU.cpp | 129 ++++++++++++++++++ .../compiler/lib/Support/Pipeline.cpp | 2 + 8 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.td create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/DynamicTLU.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt index 1d893582b..9c798458c 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(EncryptedMulToDoubleTLU) +add_subdirectory(DynamicTLU) add_subdirectory(BigInt) add_subdirectory(Boolean) add_subdirectory(Max) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/CMakeLists.txt new file mode 100644 index 000000000..c2b6e7981 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS DynamicTLU.td) +mlir_tablegen(DynamicTLU.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(DynamicTLUPassIncGen) +add_dependencies(mlir-headers DynamicTLUPassIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h new file mode 100644 index 000000000..31e961ee3 --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h @@ -0,0 +1,23 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H +#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H + +#include +#include +#include + +#define GEN_PASS_CLASSES + +#include + +namespace mlir { +namespace concretelang { +std::unique_ptr> createDynamicTLUPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.td new file mode 100644 index 000000000..c2b3adfcd --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.td @@ -0,0 +1,11 @@ +#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS +#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS + +include "mlir/Pass/PassBase.td" + +def DynamicTLU : Pass<"DynamicTLU", "::mlir::func::FuncOp"> { + let summary = "Enable table lookups with luts of arbitrary integer precision."; + let constructor = "mlir::concretelang::createDynamicTLUPass()"; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 4a4c20d9d..aadb6d0b0 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -7,6 +7,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" @@ -270,8 +271,11 @@ mlir::LogicalResult GenGateOp::verify() { emitErrorBadLutSize(*this, "lut", "ct", expectedSize, width); return mlir::failure(); } - if (!lut.getElementType().isInteger(64)) { - this->emitOpError() << "should have the i64 constant"; + auto elmType = lut.getElementType(); + if (!elmType.isSignlessInteger() || elmType.getIntOrFloatBitWidth() > 64) { + this->emitOpError() << "lut must have signless integer elements, with " + "precision not bigger than 64."; + this->emitOpError() << "got : " << elmType.getIntOrFloatBitWidth(); return mlir::failure(); } return mlir::success(); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt index dc2b5cd41..e15a7fa85 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_library( Boolean.cpp Max.cpp EncryptedMulToDoubleTLU.cpp + DynamicTLU.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE DEPENDS diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/DynamicTLU.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/DynamicTLU.cpp new file mode 100644 index 000000000..871c6d087 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Transforms/DynamicTLU.cpp @@ -0,0 +1,129 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir::concretelang::FHE; + +namespace mlir { +namespace concretelang { +namespace { + +struct ApplyLookupTableEintOpPattern + : public mlir::OpConversionPattern { + + ApplyLookupTableEintOpPattern(mlir::MLIRContext *context) + : mlir::OpConversionPattern( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ApplyLookupTableEintOp op, + FHE::ApplyLookupTableEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // When lowered to the TFHE dialect, the table will need to be properly + // encoded by a function specific to the kind of table lookup executed. This + // function expects the input lut to use 64 bit integers. For this reason, + // every lut that use integers of smaller precision needs to be extended to + // 64 bits first. + + bool outputIsSigned = + op.getResult().getType().cast().isSigned(); + auto inputLutType = op.getLut().getType(); + + mlir::Value extendedLut; + if (inputLutType.getElementType().getIntOrFloatBitWidth() == 64) { + extendedLut = adaptor.getLut(); + } else { + // This is implemented as a map since the `arith.extsi` is not + // bufferizable :( + mlir::Value init = rewriter.create( + op.getLoc(), + mlir::RankedTensorType::get(inputLutType.getShape(), + rewriter.getI64Type()), + mlir::ValueRange{}); + + extendedLut = + rewriter + .create( + op.getLoc(), mlir::ValueRange{adaptor.getLut()}, init, + [&](mlir::OpBuilder &builder, mlir::Location loc, + mlir::ValueRange args) { + mlir::Value extended; + if (outputIsSigned) { + extended = builder.create( + loc, builder.getI64Type(), args[0]); + } else { + extended = builder.create( + loc, builder.getI64Type(), args[0]); + } + builder.create( + loc, mlir::ValueRange{extended}); + }) + ->getResult(0); + } + + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), op.getA(), extendedLut); + + // Propagating the Oid if any ... + auto optimizerIdAttr = op->getAttr("TFHE.OId"); + if (optimizerIdAttr != nullptr) + newOp->setAttr("TFHE.OId", optimizerIdAttr); + + return mlir::success(); + }; +}; + +} // namespace + +class DynamicTLU : public DynamicTLUBase { + +public: + void runOnOperation() override { + mlir::func::FuncOp funcOp = getOperation(); + mlir::ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](FHE::ApplyLookupTableEintOp op) { + return op.getLut() + .getType() + .getElementType() + .getIntOrFloatBitWidth() == 64; + }); + + mlir::RewritePatternSet patterns(funcOp->getContext()); + + patterns.add(funcOp->getContext()); + + if (mlir::applyPartialConversion(funcOp, target, std::move(patterns)) + .failed()) { + funcOp->emitError("Failed to extend dynamic luts."); + this->signalPassFailure(); + } + } +}; + +std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> +createDynamicTLUPass() { + return std::make_unique(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index a966def1b..228fc232a 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -29,6 +29,7 @@ #include #include "concretelang/Conversion/TFHEKeyNormalization/Pass.h" +#include "concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" #include @@ -188,6 +189,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass); addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass); + addPotentiallyNestedPass(pm, createDynamicTLUPass(), enablePass); return pm.run(module.getOperation()); }