feat(compiler): add support for dynamic luts in compiler.

This commit is contained in:
aPere3
2023-08-29 17:32:47 +02:00
committed by Alexandre Péré
parent 083ab1103f
commit 1e726a50ed
8 changed files with 177 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
add_subdirectory(EncryptedMulToDoubleTLU)
add_subdirectory(DynamicTLU)
add_subdirectory(BigInt)
add_subdirectory(Boolean)
add_subdirectory(Max)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS DynamicTLU.td)
mlir_tablegen(DynamicTLU.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(DynamicTLUPassIncGen)
add_dependencies(mlir-headers DynamicTLUPassIncGen)

View File

@@ -0,0 +1,23 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H
#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createDynamicTLUPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,11 @@
#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS
#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS
include "mlir/Pass/PassBase.td"
def DynamicTLU : Pass<"DynamicTLU", "::mlir::func::FuncOp"> {
let summary = "Enable table lookups with luts of arbitrary integer precision.";
let constructor = "mlir::concretelang::createDynamicTLUPass()";
}
#endif

View File

@@ -7,6 +7,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
@@ -270,8 +271,11 @@ mlir::LogicalResult GenGateOp::verify() {
emitErrorBadLutSize(*this, "lut", "ct", expectedSize, width);
return mlir::failure();
}
if (!lut.getElementType().isInteger(64)) {
this->emitOpError() << "should have the i64 constant";
auto elmType = lut.getElementType();
if (!elmType.isSignlessInteger() || elmType.getIntOrFloatBitWidth() > 64) {
this->emitOpError() << "lut must have signless integer elements, with "
"precision not bigger than 64.";
this->emitOpError() << "got : " << elmType.getIntOrFloatBitWidth();
return mlir::failure();
}
return mlir::success();

View File

@@ -4,6 +4,7 @@ add_mlir_library(
Boolean.cpp
Max.cpp
EncryptedMulToDoubleTLU.cpp
DynamicTLU.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS

View File

@@ -0,0 +1,129 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/Dialect/FHE/Analysis/utils.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h>
#include <concretelang/Support/Constants.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Transforms/DialectConversion.h>
#include <unordered_set>
using namespace mlir::concretelang::FHE;
namespace mlir {
namespace concretelang {
namespace {
struct ApplyLookupTableEintOpPattern
: public mlir::OpConversionPattern<FHE::ApplyLookupTableEintOp> {
ApplyLookupTableEintOpPattern(mlir::MLIRContext *context)
: mlir::OpConversionPattern<FHE::ApplyLookupTableEintOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// When lowered to the TFHE dialect, the table will need to be properly
// encoded by a function specific to the kind of table lookup executed. This
// function expects the input lut to use 64 bit integers. For this reason,
// every lut that use integers of smaller precision needs to be extended to
// 64 bits first.
bool outputIsSigned =
op.getResult().getType().cast<FHE::FheIntegerInterface>().isSigned();
auto inputLutType = op.getLut().getType();
mlir::Value extendedLut;
if (inputLutType.getElementType().getIntOrFloatBitWidth() == 64) {
extendedLut = adaptor.getLut();
} else {
// This is implemented as a map since the `arith.extsi` is not
// bufferizable :(
mlir::Value init = rewriter.create<mlir::bufferization::AllocTensorOp>(
op.getLoc(),
mlir::RankedTensorType::get(inputLutType.getShape(),
rewriter.getI64Type()),
mlir::ValueRange{});
extendedLut =
rewriter
.create<mlir::linalg::MapOp>(
op.getLoc(), mlir::ValueRange{adaptor.getLut()}, init,
[&](mlir::OpBuilder &builder, mlir::Location loc,
mlir::ValueRange args) {
mlir::Value extended;
if (outputIsSigned) {
extended = builder.create<mlir::arith::ExtSIOp>(
loc, builder.getI64Type(), args[0]);
} else {
extended = builder.create<mlir::arith::ExtUIOp>(
loc, builder.getI64Type(), args[0]);
}
builder.create<mlir::linalg::YieldOp>(
loc, mlir::ValueRange{extended});
})
->getResult(0);
}
auto newOp = rewriter.replaceOpWithNewOp<FHE::ApplyLookupTableEintOp>(
op, op.getResult().getType(), op.getA(), extendedLut);
// Propagating the Oid if any ...
auto optimizerIdAttr = op->getAttr("TFHE.OId");
if (optimizerIdAttr != nullptr)
newOp->setAttr("TFHE.OId", optimizerIdAttr);
return mlir::success();
};
};
} // namespace
class DynamicTLU : public DynamicTLUBase<DynamicTLU> {
public:
void runOnOperation() override {
mlir::func::FuncOp funcOp = getOperation();
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::linalg::MapOp, mlir::linalg::YieldOp,
mlir::bufferization::AllocTensorOp>();
target.addLegalDialect<FHE::FHEDialect>();
target.addDynamicallyLegalOp<FHE::ApplyLookupTableEintOp>(
[&](FHE::ApplyLookupTableEintOp op) {
return op.getLut()
.getType()
.getElementType()
.getIntOrFloatBitWidth() == 64;
});
mlir::RewritePatternSet patterns(funcOp->getContext());
patterns.add<ApplyLookupTableEintOpPattern>(funcOp->getContext());
if (mlir::applyPartialConversion(funcOp, target, std::move(patterns))
.failed()) {
funcOp->emitError("Failed to extend dynamic luts.");
this->signalPassFailure();
}
}
};
std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>>
createDynamicTLUPass() {
return std::make_unique<DynamicTLU>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -29,6 +29,7 @@
#include <mlir/Transforms/Passes.h>
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
#include "concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include <concretelang/Conversion/Passes.h>
@@ -188,6 +189,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass);
addPotentiallyNestedPass(pm, createDynamicTLUPass(), enablePass);
return pm.run(module.getOperation());
}