mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
feat(compiler): add support for dynamic luts in compiler.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
add_subdirectory(EncryptedMulToDoubleTLU)
|
||||
add_subdirectory(DynamicTLU)
|
||||
add_subdirectory(BigInt)
|
||||
add_subdirectory(Boolean)
|
||||
add_subdirectory(Max)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS DynamicTLU.td)
|
||||
mlir_tablegen(DynamicTLU.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(DynamicTLUPassIncGen)
|
||||
add_dependencies(mlir-headers DynamicTLUPassIncGen)
|
||||
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H
|
||||
#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
|
||||
#include <concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createDynamicTLUPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef CONCRETELANG_FHE_DYNAMIC_TLU_PASS
|
||||
#define CONCRETELANG_FHE_DYNAMIC_TLU_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def DynamicTLU : Pass<"DynamicTLU", "::mlir::func::FuncOp"> {
|
||||
let summary = "Enable table lookups with luts of arbitrary integer precision.";
|
||||
let constructor = "mlir::concretelang::createDynamicTLUPass()";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
@@ -270,8 +271,11 @@ mlir::LogicalResult GenGateOp::verify() {
|
||||
emitErrorBadLutSize(*this, "lut", "ct", expectedSize, width);
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!lut.getElementType().isInteger(64)) {
|
||||
this->emitOpError() << "should have the i64 constant";
|
||||
auto elmType = lut.getElementType();
|
||||
if (!elmType.isSignlessInteger() || elmType.getIntOrFloatBitWidth() > 64) {
|
||||
this->emitOpError() << "lut must have signless integer elements, with "
|
||||
"precision not bigger than 64.";
|
||||
this->emitOpError() << "got : " << elmType.getIntOrFloatBitWidth();
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
|
||||
@@ -4,6 +4,7 @@ add_mlir_library(
|
||||
Boolean.cpp
|
||||
Max.cpp
|
||||
EncryptedMulToDoubleTLU.cpp
|
||||
DynamicTLU.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Dialect/FHE/Analysis/utils.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <unordered_set>
|
||||
|
||||
using namespace mlir::concretelang::FHE;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace {
|
||||
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
: public mlir::OpConversionPattern<FHE::ApplyLookupTableEintOp> {
|
||||
|
||||
ApplyLookupTableEintOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpConversionPattern<FHE::ApplyLookupTableEintOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
|
||||
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// When lowered to the TFHE dialect, the table will need to be properly
|
||||
// encoded by a function specific to the kind of table lookup executed. This
|
||||
// function expects the input lut to use 64 bit integers. For this reason,
|
||||
// every lut that use integers of smaller precision needs to be extended to
|
||||
// 64 bits first.
|
||||
|
||||
bool outputIsSigned =
|
||||
op.getResult().getType().cast<FHE::FheIntegerInterface>().isSigned();
|
||||
auto inputLutType = op.getLut().getType();
|
||||
|
||||
mlir::Value extendedLut;
|
||||
if (inputLutType.getElementType().getIntOrFloatBitWidth() == 64) {
|
||||
extendedLut = adaptor.getLut();
|
||||
} else {
|
||||
// This is implemented as a map since the `arith.extsi` is not
|
||||
// bufferizable :(
|
||||
mlir::Value init = rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
op.getLoc(),
|
||||
mlir::RankedTensorType::get(inputLutType.getShape(),
|
||||
rewriter.getI64Type()),
|
||||
mlir::ValueRange{});
|
||||
|
||||
extendedLut =
|
||||
rewriter
|
||||
.create<mlir::linalg::MapOp>(
|
||||
op.getLoc(), mlir::ValueRange{adaptor.getLut()}, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extended;
|
||||
if (outputIsSigned) {
|
||||
extended = builder.create<mlir::arith::ExtSIOp>(
|
||||
loc, builder.getI64Type(), args[0]);
|
||||
} else {
|
||||
extended = builder.create<mlir::arith::ExtUIOp>(
|
||||
loc, builder.getI64Type(), args[0]);
|
||||
}
|
||||
builder.create<mlir::linalg::YieldOp>(
|
||||
loc, mlir::ValueRange{extended});
|
||||
})
|
||||
->getResult(0);
|
||||
}
|
||||
|
||||
auto newOp = rewriter.replaceOpWithNewOp<FHE::ApplyLookupTableEintOp>(
|
||||
op, op.getResult().getType(), op.getA(), extendedLut);
|
||||
|
||||
// Propagating the Oid if any ...
|
||||
auto optimizerIdAttr = op->getAttr("TFHE.OId");
|
||||
if (optimizerIdAttr != nullptr)
|
||||
newOp->setAttr("TFHE.OId", optimizerIdAttr);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class DynamicTLU : public DynamicTLUBase<DynamicTLU> {
|
||||
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
mlir::func::FuncOp funcOp = getOperation();
|
||||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
target.addLegalDialect<mlir::arith::ArithDialect>();
|
||||
target.addLegalOp<mlir::linalg::MapOp, mlir::linalg::YieldOp,
|
||||
mlir::bufferization::AllocTensorOp>();
|
||||
target.addLegalDialect<FHE::FHEDialect>();
|
||||
target.addDynamicallyLegalOp<FHE::ApplyLookupTableEintOp>(
|
||||
[&](FHE::ApplyLookupTableEintOp op) {
|
||||
return op.getLut()
|
||||
.getType()
|
||||
.getElementType()
|
||||
.getIntOrFloatBitWidth() == 64;
|
||||
});
|
||||
|
||||
mlir::RewritePatternSet patterns(funcOp->getContext());
|
||||
|
||||
patterns.add<ApplyLookupTableEintOpPattern>(funcOp->getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(funcOp, target, std::move(patterns))
|
||||
.failed()) {
|
||||
funcOp->emitError("Failed to extend dynamic luts.");
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>>
|
||||
createDynamicTLUPass() {
|
||||
return std::make_unique<DynamicTLU>();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -29,6 +29,7 @@
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include "concretelang/Conversion/TFHEKeyNormalization/Pass.h"
|
||||
#include "concretelang/Dialect/FHE/Transforms/DynamicTLU/DynamicTLU.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include <concretelang/Conversion/Passes.h>
|
||||
@@ -188,6 +189,7 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
addPotentiallyNestedPass(pm, createEncryptedMulToDoubleTLUPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, createFHEMaxTransformPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, createDynamicTLUPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user