mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(compiler): Prepare the MidLFHE parameters injection
This commit is contained in:
@@ -4,12 +4,14 @@
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass to convert `LowLFHE` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize);
|
||||
createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &context);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -20,14 +20,14 @@ LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context,
|
||||
return LweCiphertextType::get(context);
|
||||
}
|
||||
|
||||
PlaintextType convertIntToPlaintextType(mlir::MLIRContext *context,
|
||||
IntegerType &type) {
|
||||
return PlaintextType::get(context, type.getWidth());
|
||||
PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &type) {
|
||||
return PlaintextType::get(context, type.getP() + 1);
|
||||
}
|
||||
|
||||
CleartextType convertIntToCleartextType(mlir::MLIRContext *context,
|
||||
IntegerType &type) {
|
||||
return CleartextType::get(context, type.getWidth());
|
||||
CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context,
|
||||
GLWECipherTextType &type) {
|
||||
return CleartextType::get(context, type.getP() + 1);
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
@@ -44,13 +44,11 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter,
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto integer_type = arg1.getType().cast<IntegerType>();
|
||||
mlir::Value createAddPlainLweCiphertextWithGlwe(
|
||||
mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) {
|
||||
PlaintextType encoded_type =
|
||||
convertIntToPlaintextType(rewriter.getContext(), integer_type);
|
||||
convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded =
|
||||
rewriter
|
||||
@@ -67,6 +65,15 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto glwe = arg0.getType().cast<GLWECipherTextType>();
|
||||
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, arg0, arg1, result,
|
||||
glwe);
|
||||
}
|
||||
|
||||
mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
@@ -76,16 +83,17 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
.create<mlir::zamalang::LowLFHE::NegateLweCiphertextOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1)
|
||||
.result();
|
||||
return createAddPlainLweCiphertext(rewriter, loc, negated_arg1, arg0, result);
|
||||
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
|
||||
result, arg1_type);
|
||||
}
|
||||
|
||||
mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto integer_type = arg1.getType().cast<IntegerType>();
|
||||
auto glwe = arg0.getType().cast<GLWECipherTextType>();
|
||||
CleartextType encoded_type =
|
||||
convertIntToCleartextType(rewriter.getContext(), integer_type);
|
||||
convertCleartextTypeFromGlwe(rewriter.getContext(), glwe);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded = rewriter
|
||||
.create<mlir::zamalang::LowLFHE::IntToCleartextOp>(
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
struct V0FHEConstraint {
|
||||
size_t norm2;
|
||||
size_t p;
|
||||
};
|
||||
|
||||
struct V0Parameter {
|
||||
size_t k;
|
||||
size_t polynomialSize;
|
||||
size_t nSmall;
|
||||
size_t brLevel;
|
||||
size_t brLogBase;
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
V0Parameter() {}
|
||||
|
||||
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
|
||||
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
|
||||
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
|
||||
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
|
||||
|
||||
size_t getNBigGlweSize() { return k * (1 << polynomialSize); }
|
||||
};
|
||||
|
||||
struct V0FHEContext {
|
||||
V0FHEConstraint constraint;
|
||||
V0Parameter parameter;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -73,8 +73,8 @@ struct ClientParameters {
|
||||
};
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module);
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
|
||||
mlir::ModuleOp module);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -12,13 +12,6 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// For the v0 we compute a global constraint, this is defined here as the
|
||||
/// high-level verification pass is not yet implemented.
|
||||
struct FHECircuitConstraint {
|
||||
size_t norm2;
|
||||
size_t p;
|
||||
};
|
||||
|
||||
class CompilerTools {
|
||||
public:
|
||||
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
|
||||
@@ -26,7 +19,7 @@ public:
|
||||
/// The given module MLIR operation would be modified and the constraints set.
|
||||
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
|
||||
V0FHEContext &fheContext,
|
||||
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
|
||||
return true;
|
||||
});
|
||||
|
||||
@@ -1,30 +1,13 @@
|
||||
#ifndef ZAMALANG_SUPPORT_V0Parameter_H_
|
||||
#define ZAMALANG_SUPPORT_V0Parameter_H_
|
||||
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
typedef struct V0Parameter {
|
||||
size_t k;
|
||||
size_t polynomialSize;
|
||||
size_t nSmall;
|
||||
size_t brLevel;
|
||||
size_t brLogBase;
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
V0Parameter() {}
|
||||
|
||||
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
|
||||
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
|
||||
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
|
||||
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
|
||||
|
||||
} V0Parameter;
|
||||
|
||||
V0Parameter *getV0Parameter(size_t norm, size_t p);
|
||||
V0Parameter *getV0Parameter(V0FHEConstraint constraint);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -202,9 +202,10 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns,
|
||||
namespace {
|
||||
struct LowLFHEToConcreteCAPIPass
|
||||
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
|
||||
LowLFHEToConcreteCAPIPass(uint64_t lweSize) : lweSize(lweSize){};
|
||||
LowLFHEToConcreteCAPIPass(mlir::zamalang::V0FHEContext &fheContext)
|
||||
: fheContext(fheContext){};
|
||||
void runOnOperation() final;
|
||||
uint64_t lweSize;
|
||||
mlir::zamalang::V0FHEContext &fheContext;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -217,6 +218,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
|
||||
// Setup rewrite patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
auto lweSize = 1 << fheContext.parameter.polynomialSize;
|
||||
populateLowLFHEToConcreteCAPICall(patterns, lweSize);
|
||||
|
||||
// Apply the conversion
|
||||
@@ -229,8 +231,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>(lweSize);
|
||||
createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &fheContext) {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>(fheContext);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -45,13 +45,14 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module) {
|
||||
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
mlir::ModuleOp module) {
|
||||
auto v0Param = fheContext.parameter;
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c{
|
||||
.secretKeys{
|
||||
{"small", {.size = v0Param.nSmall}},
|
||||
{"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}},
|
||||
{"big", {.size = v0Param.getNBigGlweSize()}},
|
||||
},
|
||||
.bootstrapKeys{
|
||||
{
|
||||
@@ -92,7 +93,8 @@ createClientParametersForV0(V0Parameter &v0Param, Precision precision,
|
||||
}
|
||||
|
||||
// For the v0 the precision is global
|
||||
Encoding encoding = {.precision = precision};
|
||||
auto precision = fheContext.constraint.p;
|
||||
Encoding encoding = {.precision = fheContext.constraint.p};
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getType();
|
||||
|
||||
@@ -30,18 +30,17 @@ CompilerEngine::compileFHE(std::string mlir_input) {
|
||||
return llvm::make_error<llvm::StringError>("mlir parsing failed",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
mlir::zamalang::V0Parameter v0Parameter;
|
||||
mlir::zamalang::V0FHEContext fheContext;
|
||||
// Lower to MLIR Std
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
*context, module_ref.get(), constraint, v0Parameter)
|
||||
*context, module_ref.get(), fheContext)
|
||||
.failed()) {
|
||||
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
v0Parameter, constraint.p, "main", module_ref.get());
|
||||
fheContext, "main", module_ref.get());
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot generate client parameters", llvm::inconvertibleErrorCode());
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// This is temporary while we doesn't yet have the high-level verification pass
|
||||
FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7};
|
||||
V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7};
|
||||
|
||||
void initLLVMNativeTarget() {
|
||||
// Initialize LLVM targets.
|
||||
@@ -32,23 +32,23 @@ void addFilteredPassToPassManager(
|
||||
|
||||
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
|
||||
V0FHEContext &fheContext,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
constraint = defaultGlobalFHECircuitConstraint;
|
||||
v0Parameter = *getV0Parameter(constraint.norm2, constraint.p);
|
||||
fheContext.constraint = defaultGlobalFHECircuitConstraint;
|
||||
fheContext.parameter = *getV0Parameter(fheContext.constraint);
|
||||
// Add all passes to lower from HLFHE to LLVM Dialect
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm,
|
||||
mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(
|
||||
1 << v0Parameter.polynomialSize),
|
||||
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(fheContext),
|
||||
enablePass);
|
||||
|
||||
// Run the passes
|
||||
|
||||
@@ -115,15 +115,15 @@ V0Parameter parameters[NORM2_MAX][P_MAX] = {
|
||||
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0),
|
||||
V0Parameter(0, 0, 0, 0, 0, 0, 0)}};
|
||||
|
||||
V0Parameter *getV0Parameter(size_t norm, size_t p) {
|
||||
if (norm > NORM2_MAX) {
|
||||
V0Parameter *getV0Parameter(V0FHEConstraint constraint) {
|
||||
if (constraint.norm2 > NORM2_MAX) {
|
||||
return nullptr;
|
||||
}
|
||||
if (p > P_MAX) {
|
||||
if (constraint.p > P_MAX) {
|
||||
return nullptr;
|
||||
}
|
||||
// - 1 is an offset as norm and p are in [1, ...] and not [0, ...]
|
||||
auto param = ¶meters[norm - 1][p - 1];
|
||||
auto param = ¶meters[constraint.norm2 - 1][constraint.p - 1];
|
||||
if (param->k == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -175,30 +175,32 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
};
|
||||
|
||||
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
mlir::zamalang::V0Parameter v0Parameter;
|
||||
mlir::zamalang::V0FHEContext fheContext;
|
||||
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
context, *module, constraint, v0Parameter, enablePass)
|
||||
context, *module, fheContext, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
|
||||
<< constraint.p << "}\n");
|
||||
LOG_VERBOSE("### Global FHE constraint: {norm2:"
|
||||
<< fheContext.constraint.norm2
|
||||
<< ", p:" << fheContext.constraint.p << "}\n");
|
||||
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
|
||||
<< v0Parameter.k
|
||||
<< ", polynomialSize: " << v0Parameter.polynomialSize
|
||||
<< ", nSmall: " << v0Parameter.nSmall << ", brLevel: "
|
||||
<< v0Parameter.brLevel << ", brLogBase: " << v0Parameter.brLogBase
|
||||
<< ", ksLevel: " << v0Parameter.ksLevel
|
||||
<< ", polynomialSize: " << v0Parameter.ksLogBase << "}\n");
|
||||
<< fheContext.parameter.k
|
||||
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
|
||||
<< ", nSmall: " << fheContext.parameter.nSmall
|
||||
<< ", brLevel: " << fheContext.parameter.brLevel
|
||||
<< ", brLogBase: " << fheContext.parameter.brLogBase
|
||||
<< ", ksLevel: " << fheContext.parameter.ksLevel
|
||||
<< ", polynomialSize: " << fheContext.parameter.ksLogBase
|
||||
<< "}\n");
|
||||
|
||||
// Generate the keySet
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
if (cmdline::generateKeySet || cmdline::runJit) {
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
v0Parameter, constraint.p, cmdline::jitFuncname, *module);
|
||||
fheContext, cmdline::jitFuncname, *module);
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
LOG_ERROR("cannot generate client parameters: " << err << "\n");
|
||||
return mlir::failure();
|
||||
|
||||
Reference in New Issue
Block a user