refactor(compiler): Prepare the MidLFHE parameters injection

This commit is contained in:
Quentin Bourgerie
2021-08-13 18:01:10 +02:00
parent f948db1228
commit 8057ee7553
12 changed files with 111 additions and 80 deletions

View File

@@ -4,12 +4,14 @@
#include "mlir/Pass/Pass.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
namespace mlir {
namespace zamalang {
/// Create a pass to convert `LowLFHE` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize);
createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &context);
} // namespace zamalang
} // namespace mlir

View File

@@ -20,14 +20,14 @@ LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context,
return LweCiphertextType::get(context);
}
PlaintextType convertIntToPlaintextType(mlir::MLIRContext *context,
IntegerType &type) {
return PlaintextType::get(context, type.getWidth());
PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context,
GLWECipherTextType &type) {
return PlaintextType::get(context, type.getP() + 1);
}
CleartextType convertIntToCleartextType(mlir::MLIRContext *context,
IntegerType &type) {
return CleartextType::get(context, type.getWidth());
CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context,
GLWECipherTextType &type) {
return CleartextType::get(context, type.getP() + 1);
}
template <class Operator>
@@ -44,13 +44,11 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
return op.getODSResults(0).front();
}
mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
auto integer_type = arg1.getType().cast<IntegerType>();
mlir::Value createAddPlainLweCiphertextWithGlwe(
mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) {
PlaintextType encoded_type =
convertIntToPlaintextType(rewriter.getContext(), integer_type);
convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe);
// encode int into plaintext
mlir::Value encoded =
rewriter
@@ -67,6 +65,15 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
return op.getODSResults(0).front();
}
mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
auto glwe = arg0.getType().cast<GLWECipherTextType>();
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, arg0, arg1, result,
glwe);
}
mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result) {
@@ -76,16 +83,17 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1)
.result();
return createAddPlainLweCiphertext(rewriter, loc, negated_arg1, arg0, result);
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
result, arg1_type);
}
mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
auto integer_type = arg1.getType().cast<IntegerType>();
auto glwe = arg0.getType().cast<GLWECipherTextType>();
CleartextType encoded_type =
convertIntToCleartextType(rewriter.getContext(), integer_type);
convertCleartextTypeFromGlwe(rewriter.getContext(), glwe);
// encode int into plaintext
mlir::Value encoded = rewriter
.create<mlir::zamalang::LowLFHE::IntToCleartextOp>(

View File

@@ -0,0 +1,40 @@
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
#include <cstddef>
namespace mlir {
namespace zamalang {
struct V0FHEConstraint {
size_t norm2;
size_t p;
};
struct V0Parameter {
size_t k;
size_t polynomialSize;
size_t nSmall;
size_t brLevel;
size_t brLogBase;
size_t ksLevel;
size_t ksLogBase;
V0Parameter() {}
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
size_t getNBigGlweSize() { return k * (1 << polynomialSize); }
};
struct V0FHEContext {
V0FHEConstraint constraint;
V0Parameter parameter;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -73,8 +73,8 @@ struct ClientParameters {
};
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module);
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
mlir::ModuleOp module);
} // namespace zamalang
} // namespace mlir

View File

@@ -12,13 +12,6 @@
namespace mlir {
namespace zamalang {
/// For the v0 we compute a global constraint, this is defined here as the
/// high-level verification pass is not yet implemented.
struct FHECircuitConstraint {
size_t norm2;
size_t p;
};
class CompilerTools {
public:
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
@@ -26,7 +19,7 @@ public:
/// The given module MLIR operation would be modified and the constraints set.
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
V0FHEContext &fheContext,
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
return true;
});

View File

@@ -1,30 +1,13 @@
#ifndef ZAMALANG_SUPPORT_V0Parameter_H_
#define ZAMALANG_SUPPORT_V0Parameter_H_
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
#include <cstddef>
namespace mlir {
namespace zamalang {
typedef struct V0Parameter {
size_t k;
size_t polynomialSize;
size_t nSmall;
size_t brLevel;
size_t brLogBase;
size_t ksLevel;
size_t ksLogBase;
V0Parameter() {}
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
} V0Parameter;
V0Parameter *getV0Parameter(size_t norm, size_t p);
V0Parameter *getV0Parameter(V0FHEConstraint constraint);
} // namespace zamalang
} // namespace mlir

View File

@@ -202,9 +202,10 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns,
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
LowLFHEToConcreteCAPIPass(uint64_t lweSize) : lweSize(lweSize){};
LowLFHEToConcreteCAPIPass(mlir::zamalang::V0FHEContext &fheContext)
: fheContext(fheContext){};
void runOnOperation() final;
uint64_t lweSize;
mlir::zamalang::V0FHEContext &fheContext;
};
} // namespace
@@ -217,6 +218,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
// Setup rewrite patterns
mlir::RewritePatternSet patterns(&getContext());
auto lweSize = 1 << fheContext.parameter.polynomialSize;
populateLowLFHEToConcreteCAPICall(patterns, lweSize);
// Apply the conversion
@@ -229,8 +231,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) {
return std::make_unique<LowLFHEToConcreteCAPIPass>(lweSize);
createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &fheContext) {
return std::make_unique<LowLFHEToConcreteCAPIPass>(fheContext);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -45,13 +45,14 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module) {
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
mlir::ModuleOp module) {
auto v0Param = fheContext.parameter;
// Static client parameters from global parameters for v0
ClientParameters c{
.secretKeys{
{"small", {.size = v0Param.nSmall}},
{"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}},
{"big", {.size = v0Param.getNBigGlweSize()}},
},
.bootstrapKeys{
{
@@ -92,7 +93,8 @@ createClientParametersForV0(V0Parameter &v0Param, Precision precision,
}
// For the v0 the precision is global
Encoding encoding = {.precision = precision};
auto precision = fheContext.constraint.p;
Encoding encoding = {.precision = fheContext.constraint.p};
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getType();

View File

@@ -30,18 +30,17 @@ CompilerEngine::compileFHE(std::string mlir_input) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
}
mlir::zamalang::FHECircuitConstraint constraint;
mlir::zamalang::V0Parameter v0Parameter;
mlir::zamalang::V0FHEContext fheContext;
// Lower to MLIR Std
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
*context, module_ref.get(), constraint, v0Parameter)
*context, module_ref.get(), fheContext)
.failed()) {
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
llvm::inconvertibleErrorCode());
}
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
v0Parameter, constraint.p, "main", module_ref.get());
fheContext, "main", module_ref.get());
if (auto err = clientParameter.takeError()) {
return llvm::make_error<llvm::StringError>(
"cannot generate client parameters", llvm::inconvertibleErrorCode());

View File

@@ -9,7 +9,7 @@ namespace mlir {
namespace zamalang {
// This is temporary while we doesn't yet have the high-level verification pass
FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7};
V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7};
void initLLVMNativeTarget() {
// Initialize LLVM targets.
@@ -32,23 +32,23 @@ void addFilteredPassToPassManager(
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
V0FHEContext &fheContext,
llvm::function_ref<bool(std::string)> enablePass) {
mlir::PassManager pm(&context);
constraint = defaultGlobalFHECircuitConstraint;
v0Parameter = *getV0Parameter(constraint.norm2, constraint.p);
fheContext.constraint = defaultGlobalFHECircuitConstraint;
fheContext.parameter = *getV0Parameter(fheContext.constraint);
// Add all passes to lower from HLFHE to LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm,
mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(
1 << v0Parameter.polynomialSize),
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(fheContext),
enablePass);
// Run the passes

View File

@@ -115,15 +115,15 @@ V0Parameter parameters[NORM2_MAX][P_MAX] = {
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0),
V0Parameter(0, 0, 0, 0, 0, 0, 0)}};
V0Parameter *getV0Parameter(size_t norm, size_t p) {
if (norm > NORM2_MAX) {
V0Parameter *getV0Parameter(V0FHEConstraint constraint) {
if (constraint.norm2 > NORM2_MAX) {
return nullptr;
}
if (p > P_MAX) {
if (constraint.p > P_MAX) {
return nullptr;
}
// - 1 is an offset as norm and p are in [1, ...] and not [0, ...]
auto param = &parameters[norm - 1][p - 1];
auto param = &parameters[constraint.norm2 - 1][constraint.p - 1];
if (param->k == 0) {
return nullptr;
}

View File

@@ -175,30 +175,32 @@ processInputBuffer(mlir::MLIRContext &context,
};
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::FHECircuitConstraint constraint;
mlir::zamalang::V0Parameter v0Parameter;
mlir::zamalang::V0FHEContext fheContext;
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, constraint, v0Parameter, enablePass)
context, *module, fheContext, enablePass)
.failed()) {
return mlir::failure();
}
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
<< constraint.p << "}\n");
LOG_VERBOSE("### Global FHE constraint: {norm2:"
<< fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n");
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
<< v0Parameter.k
<< ", polynomialSize: " << v0Parameter.polynomialSize
<< ", nSmall: " << v0Parameter.nSmall << ", brLevel: "
<< v0Parameter.brLevel << ", brLogBase: " << v0Parameter.brLogBase
<< ", ksLevel: " << v0Parameter.ksLevel
<< ", polynomialSize: " << v0Parameter.ksLogBase << "}\n");
<< fheContext.parameter.k
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
<< ", nSmall: " << fheContext.parameter.nSmall
<< ", brLevel: " << fheContext.parameter.brLevel
<< ", brLogBase: " << fheContext.parameter.brLogBase
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", polynomialSize: " << fheContext.parameter.ksLogBase
<< "}\n");
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
v0Parameter, constraint.p, cmdline::jitFuncname, *module);
fheContext, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
LOG_ERROR("cannot generate client parameters: " << err << "\n");
return mlir::failure();