mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: lower FHE.add on eint64 to ops on smaller chunks
this is a first commit to support operations on U64 by decomposing them into smaller chunks (32 chunks of 2 bits). This commit introduce the lowering pass that will be later populated to support other operations.
This commit is contained in:
@@ -25,6 +25,11 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
|
||||
/*description=*/"Get whether the integer is unsigned.",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isUnsigned"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isChunked"
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return false; }
|
||||
bool isUnsigned() const { return true; }
|
||||
bool isChunked() const { return false; }
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -61,12 +62,43 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return true; }
|
||||
bool isUnsigned() const { return false; }
|
||||
bool isChunked() const { return false; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger",
|
||||
[MemRefElementTypeInterface, FheIntegerInterface]> {
|
||||
let mnemonic = "chunked_eint";
|
||||
|
||||
let summary = "An encrypted integer composed of multiple chunks";
|
||||
|
||||
let description = [{
|
||||
An encrypted integer composed of multiple chunks.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
!FHE.chunked_eint<64>
|
||||
!FHE.chunked_eint<32>
|
||||
```
|
||||
}];
|
||||
|
||||
let parameters = (ins "unsigned":$width);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return false; }
|
||||
bool isUnsigned() const { return true; }
|
||||
bool isChunked() const { return true; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_AnyEncryptedInteger : Type<Or<[
|
||||
FHE_EncryptedIntegerType.predicate,
|
||||
FHE_EncryptedSignedIntegerType.predicate
|
||||
FHE_EncryptedSignedIntegerType.predicate,
|
||||
FHE_ChunkedEncryptedIntegerType.predicate
|
||||
]>>;
|
||||
|
||||
def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_FHE_BIGINT_PASS_H
|
||||
#define CONCRETELANG_FHE_BIGINT_PASS_H
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>>
|
||||
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_FHE_BIGINT_PASS
|
||||
#define CONCRETELANG_FHE_BIGINT_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHEBigIntTransform : Pass<"fhe-big-int-transform"> {
|
||||
let summary = "Transform FHE operations on big integer into operations on chunks of small integer";
|
||||
let constructor = "mlir::concretelang::createFHEBigIntTransformPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS BigInt.td)
|
||||
mlir_tablegen(BigInt.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBigIntPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBigIntPassIncGen)
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h.inc>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Boolean.td)
|
||||
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
|
||||
@@ -1,8 +1,6 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Boolean.td)
|
||||
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
|
||||
set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td)
|
||||
mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
|
||||
add_subdirectory(BigInt)
|
||||
add_subdirectory(Boolean)
|
||||
|
||||
@@ -71,13 +71,20 @@ struct CompilationOptions {
|
||||
|
||||
optimizer::Config optimizerConfig;
|
||||
|
||||
/// When decomposing big integers into chunks, chunkSize is the total number
|
||||
/// of bits used for the message, including the carry, while chunkWidth is
|
||||
/// only the number of bits used during encoding and decoding of a big integer
|
||||
unsigned int chunkSize;
|
||||
unsigned int chunkWidth;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
|
||||
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4),
|
||||
chunkWidth(2){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
clientParametersFuncName = funcname;
|
||||
|
||||
@@ -48,6 +48,11 @@ mlir::LogicalResult
|
||||
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
unsigned int chunkSize, unsigned int chunkWidth);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -89,3 +89,33 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) {
|
||||
|
||||
return getChecked(loc, loc.getContext(), width);
|
||||
}
|
||||
|
||||
mlir::LogicalResult ChunkedEncryptedIntegerType::verify(
|
||||
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
|
||||
if (p == 0) {
|
||||
emitError() << "FHE.chunked_eint doesn't support precision of 0";
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<" << getWidth() << ">";
|
||||
}
|
||||
|
||||
mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) {
|
||||
if (p.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
int width;
|
||||
|
||||
if (p.parseInteger(width))
|
||||
return mlir::Type();
|
||||
|
||||
if (p.parseGreater())
|
||||
return mlir::Type();
|
||||
|
||||
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
|
||||
|
||||
return getChecked(loc, loc.getContext(), width);
|
||||
}
|
||||
|
||||
@@ -29,14 +29,25 @@ bool verifyEncryptedIntegerInputAndResultConsistency(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input.isChunked() != result.isChunked()) {
|
||||
op.emitOpError("should have the same composition (chunked or not) of "
|
||||
"encrypted input and result");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op,
|
||||
FheIntegerInterface &a,
|
||||
IntegerType &b) {
|
||||
|
||||
if (a.getWidth() + 1 != b.getWidth()) {
|
||||
if (a.isChunked()) {
|
||||
if (a.getWidth() != b.getWidth()) {
|
||||
op.emitOpError("should have the width of plain input equal to width of "
|
||||
"encrypted input");
|
||||
return false;
|
||||
}
|
||||
} else if (a.getWidth() + 1 != b.getWidth()) {
|
||||
op.emitOpError("should have the width of plain input equal to width of "
|
||||
"encrypted input + 1");
|
||||
return false;
|
||||
@@ -58,6 +69,12 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a.isChunked() != b.isChunked()) {
|
||||
op.emitOpError("should have the same composition (chunked or not) of "
|
||||
"encrypted inputs");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
222
compiler/lib/Dialect/FHE/Transforms/BigInt.cpp
Normal file
222
compiler/lib/Dialect/FHE/Transforms/BigInt.cpp
Normal file
@@ -0,0 +1,222 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Affine/IR/AffineOps.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <concretelang/Conversion/Utils/Legality.h>
|
||||
#include <concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
/// Construct a table lookup to extract the carry bit
|
||||
mlir::Value getTruthTableCarryExtract(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc,
|
||||
unsigned int chunkSize,
|
||||
unsigned int chunkWidth) {
|
||||
auto tableSize = 1 << chunkSize;
|
||||
std::vector<llvm::APInt> values;
|
||||
values.reserve(tableSize);
|
||||
for (auto i = 0; i < tableSize; i++) {
|
||||
if (i < 1 << chunkWidth)
|
||||
values.push_back(llvm::APInt(1, 0, false));
|
||||
else
|
||||
values.push_back(llvm::APInt(1, 1, false));
|
||||
}
|
||||
auto truthTableAttr = mlir::DenseElementsAttr::get(
|
||||
mlir::RankedTensorType::get({tableSize}, rewriter.getIntegerType(64)),
|
||||
values);
|
||||
auto truthTable =
|
||||
rewriter.create<mlir::arith::ConstantOp>(loc, truthTableAttr);
|
||||
return truthTable.getResult();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
namespace typing {
|
||||
|
||||
/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of
|
||||
/// `FHE::EncryptedInteger`.
|
||||
mlir::RankedTensorType
|
||||
convertChunkedEint(mlir::MLIRContext *context,
|
||||
FHE::ChunkedEncryptedIntegerType chunkedEint,
|
||||
unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
auto eint = FHE::EncryptedIntegerType::get(context, chunkSize);
|
||||
auto bigIntWidth = chunkedEint.getWidth();
|
||||
assert(bigIntWidth % chunkWidth == 0 &&
|
||||
"chunkWidth must divide width of the big integer");
|
||||
auto numberOfChunks = bigIntWidth / chunkWidth;
|
||||
std::vector<int64_t> shape({numberOfChunks});
|
||||
return mlir::RankedTensorType::get(shape, eint);
|
||||
}
|
||||
|
||||
/// The type converter used to transform `FHE` ops on chunked integers
|
||||
class TypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([chunkSize,
|
||||
chunkWidth](FHE::ChunkedEncryptedIntegerType type) {
|
||||
return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace typing
|
||||
|
||||
class AddEintPattern
|
||||
: public mlir::OpConversionPattern<mlir::concretelang::FHE::AddEintOp> {
|
||||
public:
|
||||
AddEintPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
unsigned int chunkSize, unsigned int chunkWidth)
|
||||
: mlir::OpConversionPattern<mlir::concretelang::FHE::AddEintOp>(
|
||||
converter, context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
|
||||
chunkSize(chunkSize), chunkWidth(chunkWidth) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
auto tensorType = adaptor.a().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto shape = tensorType.getShape();
|
||||
assert(shape.size() == 1 &&
|
||||
"chunked integer should be converted to flat tensors, but tensor "
|
||||
"have more than one dimension");
|
||||
auto eintChunkWidth = tensorType.getElementType()
|
||||
.dyn_cast<FHE::EncryptedIntegerType>()
|
||||
.getWidth();
|
||||
assert(eintChunkWidth == chunkSize && "wrong tensor elements width");
|
||||
auto numberOfChunks = shape[0];
|
||||
|
||||
mlir::Value carry =
|
||||
rewriter
|
||||
.create<FHE::ZeroEintOp>(op.getLoc(),
|
||||
FHE::EncryptedIntegerType::get(
|
||||
rewriter.getContext(), chunkSize))
|
||||
.getResult();
|
||||
|
||||
mlir::Value resultTensor =
|
||||
rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.a().getType())
|
||||
.getResult();
|
||||
// used to shift the carry bit to the left
|
||||
mlir::Value twoPowerChunkSizeCst =
|
||||
rewriter
|
||||
.create<mlir::arith::ConstantIntOp>(op.getLoc(), 1 << chunkWidth,
|
||||
chunkSize + 1)
|
||||
.getResult();
|
||||
// Create the loop
|
||||
int64_t lb = 0, step = 1;
|
||||
auto forOp = rewriter.create<mlir::AffineForOp>(
|
||||
op.getLoc(), lb, numberOfChunks, step, resultTensor,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
// add inputs with the previous carry (init to 0)
|
||||
mlir::Value leftEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.a(), iter);
|
||||
mlir::Value rightEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.b(), iter);
|
||||
mlir::Value result =
|
||||
builder.create<FHE::AddEintOp>(loc, leftEint, rightEint)
|
||||
.getResult();
|
||||
mlir::Value resultWithCarry =
|
||||
builder.create<FHE::AddEintOp>(loc, result, carry).getResult();
|
||||
// compute the new carry: either 1 or 0
|
||||
carry =
|
||||
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
op.getLoc(),
|
||||
FHE::EncryptedIntegerType::get(rewriter.getContext(),
|
||||
chunkSize),
|
||||
resultWithCarry,
|
||||
getTruthTableCarryExtract(rewriter, op.getLoc(), chunkSize,
|
||||
chunkWidth));
|
||||
// remove the carry bit from the result
|
||||
mlir::Value shiftedCarry =
|
||||
builder
|
||||
.create<FHE::MulEintIntOp>(loc, carry, twoPowerChunkSizeCst)
|
||||
.getResult();
|
||||
mlir::Value finalResult =
|
||||
builder.create<FHE::SubEintOp>(loc, resultWithCarry, shiftedCarry)
|
||||
.getResult();
|
||||
// insert the result in the result tensor
|
||||
mlir::Value tensorResult = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, finalResult, args[0], iter);
|
||||
builder.create<mlir::AffineYieldOp>(loc, tensorResult);
|
||||
});
|
||||
rewriter.replaceOp(op, forOp.getResult(0));
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned int chunkSize, chunkWidth;
|
||||
};
|
||||
|
||||
/// Perfoms the transformation of big integer operations
|
||||
class FHEBigIntTransformPass
|
||||
: public FHEBigIntTransformBase<FHEBigIntTransformPass> {
|
||||
public:
|
||||
FHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth)
|
||||
: chunkSize(chunkSize), chunkWidth(chunkWidth){};
|
||||
|
||||
void runOnOperation() override {
|
||||
mlir::Operation *op = getOperation();
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
typing::TypeConverter converter(chunkSize, chunkWidth);
|
||||
|
||||
// Legal ops created during pattern application
|
||||
target.addLegalOp<mlir::AffineForOp, mlir::AffineYieldOp,
|
||||
mlir::arith::ConstantOp, mlir::arith::ConstantIndexOp,
|
||||
FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintOp,
|
||||
FHE::MulEintIntOp, FHE::SubEintOp,
|
||||
FHE::ApplyLookupTableEintOp, mlir::tensor::ExtractOp,
|
||||
mlir::tensor::InsertOp>();
|
||||
concretelang::addDynamicallyLegalTypeOp<FHE::AddEintOp>(target, converter);
|
||||
// Func ops are only legal with converted types
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::func::ReturnOp>>(patterns.getContext(), converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
|
||||
converter);
|
||||
|
||||
patterns.add<AddEintPattern>(converter, &getContext(), chunkSize,
|
||||
chunkWidth);
|
||||
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned int chunkSize, chunkWidth;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>>
|
||||
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
assert(chunkSize >= chunkWidth + 1 &&
|
||||
"chunkSize must be greater than chunkWidth");
|
||||
return std::make_unique<FHEBigIntTransformPass>(chunkSize, chunkWidth);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
|
||||
#include <concretelang/Support/Constants.h>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_mlir_library(
|
||||
FHEDialectTransforms
|
||||
BigInt.cpp
|
||||
Boolean.cpp
|
||||
EncryptedMulToDoubleTLU.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
||||
@@ -296,6 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return StreamStringError("Rewriting of encrypted mul failed");
|
||||
}
|
||||
|
||||
if (mlir::concretelang::pipeline::transformFHEBigInt(
|
||||
mlirContext, module, enablePass, options.chunkSize,
|
||||
options.chunkWidth)
|
||||
.failed()) {
|
||||
return errorDiag("Transforming FHE big integer ops failed");
|
||||
}
|
||||
|
||||
// FHE High level pass to determine FHE parameters
|
||||
if (auto err = this->determineFHEParameters(res))
|
||||
return std::move(err);
|
||||
@@ -414,8 +421,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
mlirContext, module, enablePass,
|
||||
options.unrollLoopsWithSDFGConvertibleOps)
|
||||
.failed()) {
|
||||
return errorDiag(
|
||||
"Extraction of SDFG operations from BConcrete representation failed");
|
||||
return errorDiag("Extraction of SDFG operations from BConcrete "
|
||||
"representation failed");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,8 +434,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return errorDiag(
|
||||
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
|
||||
return errorDiag("Lowering from Bufferized Concrete to canonical MLIR "
|
||||
"dialects failed");
|
||||
}
|
||||
|
||||
// SDFG -> Canonical dialects
|
||||
@@ -595,16 +602,17 @@ CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) {
|
||||
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
|
||||
const std::string CompilerEngine::Library::LINKER = "ld";
|
||||
#ifdef __APPLE__
|
||||
// We need to tell the linker that some symbols will be missing during linking,
|
||||
// this symbols should be available during runtime however. This is the case
|
||||
// when JIT compiling, the JIT should either link to the runtime library that
|
||||
// has the missing symbols, or it would have been loaded even prior to that.
|
||||
// Starting from Mac 11 (Big Sur), it appears we need to add -L
|
||||
// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for the
|
||||
// sharedlib to link properly.
|
||||
// We need to tell the linker that some symbols will be missing during
|
||||
// linking, this symbols should be available during runtime however. This is
|
||||
// the case when JIT compiling, the JIT should either link to the runtime
|
||||
// library that has the missing symbols, or it would have been loaded even
|
||||
// prior to that. Starting from Mac 11 (Big Sur), it appears we need to add -L
|
||||
// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for
|
||||
// the sharedlib to link properly.
|
||||
const std::string CompilerEngine::Library::LINKER_SHARED_OPT =
|
||||
" -dylib -undefined dynamic_lookup -L "
|
||||
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem -o ";
|
||||
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem "
|
||||
"-o ";
|
||||
const std::string CompilerEngine::Library::DOT_SHARED_LIB_EXT = ".dylib";
|
||||
#else // Linux
|
||||
const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o ";
|
||||
@@ -798,7 +806,8 @@ llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
|
||||
std::vector<std::string> extraArgs;
|
||||
std::string fullRuntimeLibraryName = "";
|
||||
#ifdef __APPLE__
|
||||
// to issue the command for fixing the runtime dependency of the generated lib
|
||||
// to issue the command for fixing the runtime dependency of the generated
|
||||
// lib
|
||||
bool fixRuntimeDep = false;
|
||||
#endif
|
||||
if (!runtimeLibraryPath.empty()) {
|
||||
@@ -841,12 +850,13 @@ llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
|
||||
sharedLibraryPath = path.get();
|
||||
#ifdef __APPLE__
|
||||
// when dellocate is used to include dependencies in python wheels, the
|
||||
// runtime library will have an id that is prefixed with /DLC, and that path
|
||||
// doesn't exist. So when generated libraries won't be able to find it
|
||||
// during load time. To solve this, we change the dep in the generated
|
||||
// library to be relative to the rpath which should be set correctly during
|
||||
// linking. This shouldn't have an impact when /DLC/concrete/.dylibs/* isn't
|
||||
// a dependecy in the first place (when not using python).
|
||||
// runtime library will have an id that is prefixed with /DLC, and that
|
||||
// path doesn't exist. So when generated libraries won't be able to find
|
||||
// it during load time. To solve this, we change the dep in the generated
|
||||
// library to be relative to the rpath which should be set correctly
|
||||
// during linking. This shouldn't have an impact when
|
||||
// /DLC/concrete/.dylibs/* isn't a dependecy in the first place (when not
|
||||
// using python).
|
||||
if (fixRuntimeDep) {
|
||||
std::string fixRuntimeDepCmd = "install_name_tool -change "
|
||||
"/DLC/concrete/.dylibs/" +
|
||||
|
||||
@@ -35,7 +35,8 @@
|
||||
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
|
||||
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
|
||||
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
|
||||
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
|
||||
@@ -218,6 +219,23 @@ transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
mlir::PassManager pm(&context);
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createFHEBigIntTransformPass(chunkSize, chunkWidth),
|
||||
enablePass);
|
||||
// We want to fully unroll for loops introduced by the BigInt transform since
|
||||
// MANP doesn't support loops. This is a workaround that make the IR much
|
||||
// bigger than it should be
|
||||
addPotentiallyNestedPass(pm, mlir::createLoopUnrollPass(-1, false, true),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -206,6 +206,18 @@ llvm::cl::list<uint64_t>
|
||||
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
|
||||
llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::opt<unsigned int> chunkSize(
|
||||
"chunk-size",
|
||||
llvm::cl::desc(
|
||||
"Chunk size while decomposing big integers into chunks, default is 4"),
|
||||
llvm::cl::init<unsigned int>(4));
|
||||
|
||||
llvm::cl::opt<unsigned int> chunkWidth(
|
||||
"chunk-width",
|
||||
llvm::cl::desc(
|
||||
"Chunk width while decomposing big integers into chunks, default is 2"),
|
||||
llvm::cl::init<unsigned int>(2));
|
||||
|
||||
llvm::cl::opt<std::string> jitKeySetCachePath(
|
||||
"jit-keyset-cache-path",
|
||||
llvm::cl::desc("Path to cache KeySet content (unsecure)"));
|
||||
@@ -338,6 +350,8 @@ cmdlineCompilationOptions() {
|
||||
cmdline::unrollLoopsWithSDFGConvertibleOps;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
options.chunkSize = cmdline::chunkSize;
|
||||
options.chunkWidth = cmdline::chunkWidth;
|
||||
|
||||
if (!cmdline::v0Constraint.empty()) {
|
||||
if (cmdline::v0Constraint.size() != 2) {
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>>
|
||||
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5
|
||||
// CHECK-NEXT: %[[V2:.*]] = affine.for %arg2 = 0 to 32 iter_args(%arg3 = %[[V1]]) -> (tensor<32x!FHE.eint<4>>) {
|
||||
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %arg0[%arg2] : tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %arg1[%arg2] : tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint"(%[[V3]], %[[V4]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[V0]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]> : tensor<16xi64>
|
||||
// CHECK-NEXT: %[[V7:.*]] = "FHE.apply_lookup_table"(%[[V6]], %[[cst]]) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[c4_i5]]) : (!FHE.eint<4>, i5) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V9:.*]] = "FHE.sub_eint"(%[[V6]], %[[V8]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V10:.*]] = tensor.insert %[[V9]] into %arg3[%arg2] : tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: affine.yield %[[V10]] : tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
34
compiler/tests/check_tests/Dialect/FHE/big_int.mlir
Normal file
34
compiler/tests/check_tests/Dialect/FHE/big_int.mlir
Normal file
@@ -0,0 +1,34 @@
|
||||
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
|
||||
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
|
||||
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// TODO: max/min
|
||||
Reference in New Issue
Block a user