feat: lower FHE.add on eint64 to ops on smaller chunks

this is a first commit to support operations on U64 by decomposing them
into smaller chunks (32 chunks of 2 bits). This commit introduce the
lowering pass that will be later populated to support other operations.
This commit is contained in:
youben11
2023-02-01 09:46:41 +01:00
committed by Ayoub Benaissa
parent fb680340f9
commit d41d14dbb8
21 changed files with 492 additions and 30 deletions

View File

@@ -25,6 +25,11 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
/*description=*/"Get whether the integer is unsigned.",
/*retTy=*/"bool",
/*methodName=*/"isUnsigned"
>,
InterfaceMethod<
/*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).",
/*retTy=*/"bool",
/*methodName=*/"isChunked"
>
];
}

View File

@@ -33,6 +33,7 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
let extraClassDeclaration = [{
bool isSigned() const { return false; }
bool isUnsigned() const { return true; }
bool isChunked() const { return false; }
}];
}
@@ -61,12 +62,43 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
let extraClassDeclaration = [{
bool isSigned() const { return true; }
bool isUnsigned() const { return false; }
bool isChunked() const { return false; }
}];
}
def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger",
[MemRefElementTypeInterface, FheIntegerInterface]> {
let mnemonic = "chunked_eint";
let summary = "An encrypted integer composed of multiple chunks";
let description = [{
An encrypted integer composed of multiple chunks.
Examples:
```mlir
!FHE.chunked_eint<64>
!FHE.chunked_eint<32>
```
}];
let parameters = (ins "unsigned":$width);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let extraClassDeclaration = [{
bool isSigned() const { return false; }
bool isUnsigned() const { return true; }
bool isChunked() const { return true; }
}];
}
def FHE_AnyEncryptedInteger : Type<Or<[
FHE_EncryptedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate
FHE_EncryptedSignedIntegerType.predicate,
FHE_ChunkedEncryptedIntegerType.predicate
]>>;
def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",

View File

@@ -0,0 +1,24 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_FHE_BIGINT_PASS_H
#define CONCRETELANG_FHE_BIGINT_PASS_H
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>>
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_FHE_BIGINT_PASS
#define CONCRETELANG_FHE_BIGINT_PASS
include "mlir/Pass/PassBase.td"
def FHEBigIntTransform : Pass<"fhe-big-int-transform"> {
let summary = "Transform FHE operations on big integer into operations on chunks of small integer";
let constructor = "mlir::concretelang::createFHEBigIntTransformPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ];
}
#endif

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS BigInt.td)
mlir_tablegen(BigInt.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEBigIntPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEBigIntPassIncGen)

View File

@@ -10,7 +10,7 @@
#include <mlir/Pass/Pass.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/FHE/Transforms/Boolean.h.inc>
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h.inc>
namespace mlir {
namespace concretelang {

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Boolean.td)
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)

View File

@@ -1,8 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Boolean.td)
mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen)
add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen)
set(LLVM_TARGET_DEFINITIONS EncryptedMulToDoubleTLU.td)
mlir_tablegen(EncryptedMulToDoubleTLU.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(EncryptedMulToDoubleTLUPassIncGen)
add_dependencies(mlir-headers EncryptedMulToDoubleTLUPassIncGen)
add_subdirectory(BigInt)
add_subdirectory(Boolean)

View File

@@ -71,13 +71,20 @@ struct CompilationOptions {
optimizer::Config optimizerConfig;
/// When decomposing big integers into chunks, chunkSize is the total number
/// of bits used for the message, including the carry, while chunkWidth is
/// only the number of bits used during encoding and decoding of a big integer
unsigned int chunkSize;
unsigned int chunkWidth;
CompilationOptions()
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
clientParametersFuncName(llvm::None),
optimizerConfig(optimizer::DEFAULT_CONFIG){};
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkSize(4),
chunkWidth(2){};
CompilationOptions(std::string funcname) : CompilationOptions() {
clientParametersFuncName = funcname;

View File

@@ -48,6 +48,11 @@ mlir::LogicalResult
transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
unsigned int chunkSize, unsigned int chunkWidth);
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -89,3 +89,33 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) {
return getChecked(loc, loc.getContext(), width);
}
mlir::LogicalResult ChunkedEncryptedIntegerType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
if (p == 0) {
emitError() << "FHE.chunked_eint doesn't support precision of 0";
return mlir::failure();
}
return mlir::success();
}
void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const {
p << "<" << getWidth() << ">";
}
mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) {
if (p.parseLess())
return mlir::Type();
int width;
if (p.parseInteger(width))
return mlir::Type();
if (p.parseGreater())
return mlir::Type();
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
return getChecked(loc, loc.getContext(), width);
}

View File

@@ -29,14 +29,25 @@ bool verifyEncryptedIntegerInputAndResultConsistency(
return false;
}
if (input.isChunked() != result.isChunked()) {
op.emitOpError("should have the same composition (chunked or not) of "
"encrypted input and result");
return false;
}
return true;
}
bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op,
FheIntegerInterface &a,
IntegerType &b) {
if (a.getWidth() + 1 != b.getWidth()) {
if (a.isChunked()) {
if (a.getWidth() != b.getWidth()) {
op.emitOpError("should have the width of plain input equal to width of "
"encrypted input");
return false;
}
} else if (a.getWidth() + 1 != b.getWidth()) {
op.emitOpError("should have the width of plain input equal to width of "
"encrypted input + 1");
return false;
@@ -58,6 +69,12 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
return false;
}
if (a.isChunked() != b.isChunked()) {
op.emitOpError("should have the same composition (chunked or not) of "
"encrypted inputs");
return false;
}
return true;
}

View File

@@ -0,0 +1,222 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <concretelang/Conversion/Utils/Legality.h>
#include <concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
#include <concretelang/Support/Constants.h>
namespace mlir {
namespace concretelang {
/// Construct a table lookup to extract the carry bit
mlir::Value getTruthTableCarryExtract(mlir::PatternRewriter &rewriter,
mlir::Location loc,
unsigned int chunkSize,
unsigned int chunkWidth) {
auto tableSize = 1 << chunkSize;
std::vector<llvm::APInt> values;
values.reserve(tableSize);
for (auto i = 0; i < tableSize; i++) {
if (i < 1 << chunkWidth)
values.push_back(llvm::APInt(1, 0, false));
else
values.push_back(llvm::APInt(1, 1, false));
}
auto truthTableAttr = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({tableSize}, rewriter.getIntegerType(64)),
values);
auto truthTable =
rewriter.create<mlir::arith::ConstantOp>(loc, truthTableAttr);
return truthTable.getResult();
}
namespace {
namespace typing {
/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of
/// `FHE::EncryptedInteger`.
mlir::RankedTensorType
convertChunkedEint(mlir::MLIRContext *context,
FHE::ChunkedEncryptedIntegerType chunkedEint,
unsigned int chunkSize, unsigned int chunkWidth) {
auto eint = FHE::EncryptedIntegerType::get(context, chunkSize);
auto bigIntWidth = chunkedEint.getWidth();
assert(bigIntWidth % chunkWidth == 0 &&
"chunkWidth must divide width of the big integer");
auto numberOfChunks = bigIntWidth / chunkWidth;
std::vector<int64_t> shape({numberOfChunks});
return mlir::RankedTensorType::get(shape, eint);
}
/// The type converter used to transform `FHE` ops on chunked integers
class TypeConverter : public mlir::TypeConverter {
public:
TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) {
addConversion([](mlir::Type type) { return type; });
addConversion([chunkSize,
chunkWidth](FHE::ChunkedEncryptedIntegerType type) {
return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth);
});
}
};
} // namespace typing
class AddEintPattern
: public mlir::OpConversionPattern<mlir::concretelang::FHE::AddEintOp> {
public:
AddEintPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
unsigned int chunkSize, unsigned int chunkWidth)
: mlir::OpConversionPattern<mlir::concretelang::FHE::AddEintOp>(
converter, context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT),
chunkSize(chunkSize), chunkWidth(chunkWidth) {}
mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto tensorType = adaptor.a().getType().dyn_cast<mlir::RankedTensorType>();
auto shape = tensorType.getShape();
assert(shape.size() == 1 &&
"chunked integer should be converted to flat tensors, but tensor "
"have more than one dimension");
auto eintChunkWidth = tensorType.getElementType()
.dyn_cast<FHE::EncryptedIntegerType>()
.getWidth();
assert(eintChunkWidth == chunkSize && "wrong tensor elements width");
auto numberOfChunks = shape[0];
mlir::Value carry =
rewriter
.create<FHE::ZeroEintOp>(op.getLoc(),
FHE::EncryptedIntegerType::get(
rewriter.getContext(), chunkSize))
.getResult();
mlir::Value resultTensor =
rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), adaptor.a().getType())
.getResult();
// used to shift the carry bit to the left
mlir::Value twoPowerChunkSizeCst =
rewriter
.create<mlir::arith::ConstantIntOp>(op.getLoc(), 1 << chunkWidth,
chunkSize + 1)
.getResult();
// Create the loop
int64_t lb = 0, step = 1;
auto forOp = rewriter.create<mlir::AffineForOp>(
op.getLoc(), lb, numberOfChunks, step, resultTensor,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
mlir::ValueRange args) {
// add inputs with the previous carry (init to 0)
mlir::Value leftEint =
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.a(), iter);
mlir::Value rightEint =
builder.create<mlir::tensor::ExtractOp>(loc, adaptor.b(), iter);
mlir::Value result =
builder.create<FHE::AddEintOp>(loc, leftEint, rightEint)
.getResult();
mlir::Value resultWithCarry =
builder.create<FHE::AddEintOp>(loc, result, carry).getResult();
// compute the new carry: either 1 or 0
carry =
rewriter.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
op.getLoc(),
FHE::EncryptedIntegerType::get(rewriter.getContext(),
chunkSize),
resultWithCarry,
getTruthTableCarryExtract(rewriter, op.getLoc(), chunkSize,
chunkWidth));
// remove the carry bit from the result
mlir::Value shiftedCarry =
builder
.create<FHE::MulEintIntOp>(loc, carry, twoPowerChunkSizeCst)
.getResult();
mlir::Value finalResult =
builder.create<FHE::SubEintOp>(loc, resultWithCarry, shiftedCarry)
.getResult();
// insert the result in the result tensor
mlir::Value tensorResult = builder.create<mlir::tensor::InsertOp>(
loc, finalResult, args[0], iter);
builder.create<mlir::AffineYieldOp>(loc, tensorResult);
});
rewriter.replaceOp(op, forOp.getResult(0));
return mlir::success();
}
private:
unsigned int chunkSize, chunkWidth;
};
/// Perfoms the transformation of big integer operations
class FHEBigIntTransformPass
: public FHEBigIntTransformBase<FHEBigIntTransformPass> {
public:
FHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth)
: chunkSize(chunkSize), chunkWidth(chunkWidth){};
void runOnOperation() override {
mlir::Operation *op = getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
typing::TypeConverter converter(chunkSize, chunkWidth);
// Legal ops created during pattern application
target.addLegalOp<mlir::AffineForOp, mlir::AffineYieldOp,
mlir::arith::ConstantOp, mlir::arith::ConstantIndexOp,
FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintOp,
FHE::MulEintIntOp, FHE::SubEintOp,
FHE::ApplyLookupTableEintOp, mlir::tensor::ExtractOp,
mlir::tensor::InsertOp>();
concretelang::addDynamicallyLegalTypeOp<FHE::AddEintOp>(target, converter);
// Func ops are only legal with converted types
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
[&](mlir::func::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::func::ReturnOp>>(patterns.getContext(), converter);
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
converter);
patterns.add<AddEintPattern>(converter, &getContext(), chunkSize,
chunkWidth);
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
private:
unsigned int chunkSize, chunkWidth;
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<>>
createFHEBigIntTransformPass(unsigned int chunkSize, unsigned int chunkWidth) {
assert(chunkSize >= chunkWidth + 1 &&
"chunkSize must be greater than chunkWidth");
return std::make_unique<FHEBigIntTransformPass>(chunkSize, chunkWidth);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -9,7 +9,7 @@
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
#include <concretelang/Support/Constants.h>
namespace mlir {

View File

@@ -1,5 +1,6 @@
add_mlir_library(
FHEDialectTransforms
BigInt.cpp
Boolean.cpp
EncryptedMulToDoubleTLU.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -296,6 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return StreamStringError("Rewriting of encrypted mul failed");
}
if (mlir::concretelang::pipeline::transformFHEBigInt(
mlirContext, module, enablePass, options.chunkSize,
options.chunkWidth)
.failed()) {
return errorDiag("Transforming FHE big integer ops failed");
}
// FHE High level pass to determine FHE parameters
if (auto err = this->determineFHEParameters(res))
return std::move(err);
@@ -414,8 +421,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
mlirContext, module, enablePass,
options.unrollLoopsWithSDFGConvertibleOps)
.failed()) {
return errorDiag(
"Extraction of SDFG operations from BConcrete representation failed");
return errorDiag("Extraction of SDFG operations from BConcrete "
"representation failed");
}
}
@@ -427,8 +434,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
enablePass)
.failed()) {
return errorDiag(
"Lowering from Bufferized Concrete to canonical MLIR dialects failed");
return errorDiag("Lowering from Bufferized Concrete to canonical MLIR "
"dialects failed");
}
// SDFG -> Canonical dialects
@@ -595,16 +602,17 @@ CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) {
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
const std::string CompilerEngine::Library::LINKER = "ld";
#ifdef __APPLE__
// We need to tell the linker that some symbols will be missing during linking,
// this symbols should be available during runtime however. This is the case
// when JIT compiling, the JIT should either link to the runtime library that
// has the missing symbols, or it would have been loaded even prior to that.
// Starting from Mac 11 (Big Sur), it appears we need to add -L
// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for the
// sharedlib to link properly.
// We need to tell the linker that some symbols will be missing during
// linking, this symbols should be available during runtime however. This is
// the case when JIT compiling, the JIT should either link to the runtime
// library that has the missing symbols, or it would have been loaded even
// prior to that. Starting from Mac 11 (Big Sur), it appears we need to add -L
// /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem for
// the sharedlib to link properly.
const std::string CompilerEngine::Library::LINKER_SHARED_OPT =
" -dylib -undefined dynamic_lookup -L "
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem -o ";
"/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib -lSystem "
"-o ";
const std::string CompilerEngine::Library::DOT_SHARED_LIB_EXT = ".dylib";
#else // Linux
const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o ";
@@ -798,7 +806,8 @@ llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
std::vector<std::string> extraArgs;
std::string fullRuntimeLibraryName = "";
#ifdef __APPLE__
// to issue the command for fixing the runtime dependency of the generated lib
// to issue the command for fixing the runtime dependency of the generated
// lib
bool fixRuntimeDep = false;
#endif
if (!runtimeLibraryPath.empty()) {
@@ -841,12 +850,13 @@ llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
sharedLibraryPath = path.get();
#ifdef __APPLE__
// when dellocate is used to include dependencies in python wheels, the
// runtime library will have an id that is prefixed with /DLC, and that path
// doesn't exist. So when generated libraries won't be able to find it
// during load time. To solve this, we change the dep in the generated
// library to be relative to the rpath which should be set correctly during
// linking. This shouldn't have an impact when /DLC/concrete/.dylibs/* isn't
// a dependecy in the first place (when not using python).
// runtime library will have an id that is prefixed with /DLC, and that
// path doesn't exist. So when generated libraries won't be able to find
// it during load time. To solve this, we change the dep in the generated
// library to be relative to the rpath which should be set correctly
// during linking. This shouldn't have an impact when
// /DLC/concrete/.dylibs/* isn't a dependecy in the first place (when not
// using python).
if (fixRuntimeDep) {
std::string fixRuntimeDepCmd = "install_name_tool -change "
"/DLC/concrete/.dylibs/" +

View File

@@ -35,7 +35,8 @@
#include <concretelang/Dialect/Concrete/Transforms/Optimization.h>
#include <concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h>
#include <concretelang/Dialect/FHE/Analysis/MANP.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean.h>
#include <concretelang/Dialect/FHE/Transforms/BigInt/BigInt.h>
#include <concretelang/Dialect/FHE/Transforms/Boolean/Boolean.h>
#include <concretelang/Dialect/FHE/Transforms/EncryptedMulToDoubleTLU.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
@@ -218,6 +219,23 @@ transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
transformFHEBigInt(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
unsigned int chunkSize, unsigned int chunkWidth) {
mlir::PassManager pm(&context);
addPotentiallyNestedPass(
pm,
mlir::concretelang::createFHEBigIntTransformPass(chunkSize, chunkWidth),
enablePass);
// We want to fully unroll for loops introduced by the BigInt transform since
// MANP doesn't support loops. This is a workaround that make the IR much
// bigger than it should be
addPotentiallyNestedPass(pm, mlir::createLoopUnrollPass(-1, false, true),
enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -206,6 +206,18 @@ llvm::cl::list<uint64_t>
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::opt<unsigned int> chunkSize(
"chunk-size",
llvm::cl::desc(
"Chunk size while decomposing big integers into chunks, default is 4"),
llvm::cl::init<unsigned int>(4));
llvm::cl::opt<unsigned int> chunkWidth(
"chunk-width",
llvm::cl::desc(
"Chunk width while decomposing big integers into chunks, default is 2"),
llvm::cl::init<unsigned int>(2));
llvm::cl::opt<std::string> jitKeySetCachePath(
"jit-keyset-cache-path",
llvm::cl::desc("Path to cache KeySet content (unsecure)"));
@@ -338,6 +350,8 @@ cmdlineCompilationOptions() {
cmdline::unrollLoopsWithSDFGConvertibleOps;
options.optimizeConcrete = cmdline::optimizeConcrete;
options.emitGPUOps = cmdline::emitGPUOps;
options.chunkSize = cmdline::chunkSize;
options.chunkWidth = cmdline::chunkWidth;
if (!cmdline::v0Constraint.empty()) {
if (cmdline::v0Constraint.size() != 2) {

View File

@@ -0,0 +1,24 @@
// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>>
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4>
// CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>>
// CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5
// CHECK-NEXT: %[[V2:.*]] = affine.for %arg2 = 0 to 32 iter_args(%arg3 = %[[V1]]) -> (tensor<32x!FHE.eint<4>>) {
// CHECK-NEXT: %[[V3:.*]] = tensor.extract %arg0[%arg2] : tensor<32x!FHE.eint<4>>
// CHECK-NEXT: %[[V4:.*]] = tensor.extract %arg1[%arg2] : tensor<32x!FHE.eint<4>>
// CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint"(%[[V3]], %[[V4]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[V0]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]> : tensor<16xi64>
// CHECK-NEXT: %[[V7:.*]] = "FHE.apply_lookup_table"(%[[V6]], %[[cst]]) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
// CHECK-NEXT: %[[V8:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[c4_i5]]) : (!FHE.eint<4>, i5) -> !FHE.eint<4>
// CHECK-NEXT: %[[V9:.*]] = "FHE.sub_eint"(%[[V6]], %[[V8]]) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: %[[V10:.*]] = tensor.insert %[[V9]] into %arg3[%arg2] : tensor<32x!FHE.eint<4>>
// CHECK-NEXT: affine.yield %[[V10]] : tensor<32x!FHE.eint<4>>
// CHECK-NEXT: }
// CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}

View File

@@ -0,0 +1,34 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
%0 = arith.constant 1 : i64
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
%0 = arith.constant 1 : i64
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// TODO: max/min