mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
fix(compiler/lowlfhe): for the v0 give the lweSize of ciphertext as a global parameter of the lowering pass to concrete api call (#62)
This commit is contained in:
@@ -9,7 +9,7 @@ namespace zamalang {
|
||||
/// Create a pass to convert `LowLFHE` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass();
|
||||
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ struct ClientParameters {
|
||||
};
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -26,7 +26,7 @@ public:
|
||||
/// The given module MLIR operation would be modified and the constraints set.
|
||||
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint,
|
||||
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
|
||||
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
|
||||
return true;
|
||||
});
|
||||
|
||||
@@ -15,6 +15,8 @@ typedef struct V0Parameter {
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
V0Parameter() {}
|
||||
|
||||
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
|
||||
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
|
||||
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
|
||||
|
||||
@@ -30,9 +30,10 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
|
||||
mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
uint64_t lweSize,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
|
||||
allocName(allocName) {}
|
||||
allocName(allocName), lweSize(lweSize) {}
|
||||
|
||||
mlir::LogicalResult static insertForwardDeclaration(
|
||||
Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName,
|
||||
@@ -92,17 +93,16 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
// Create the err value
|
||||
auto err = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
// Add the call to the allocation
|
||||
// TODO - 2018
|
||||
auto lweSize = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(2048));
|
||||
mlir::SmallVector<mlir::Value, 1> allocOperands{err, lweSize};
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(lweSize));
|
||||
mlir::SmallVector<mlir::Value, 1> allocOperands{errOp, lweSizeOp};
|
||||
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, allocName, op.getType(), allocOperands);
|
||||
|
||||
// Add err and allocated value to operands
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{err, alloc.getResult(0)};
|
||||
mlir::SmallVector<mlir::Value, 4> newOperands{errOp, alloc.getResult(0)};
|
||||
for (auto operand : op->getOperands()) {
|
||||
newOperands.push_back(operand);
|
||||
}
|
||||
@@ -115,21 +115,24 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
private:
|
||||
std::string funcName;
|
||||
std::string allocName;
|
||||
uint64_t lweSize;
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns, uint64_t lweSize) {
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::zamalang::LowLFHE::AddLweCiphertextsOp>>(
|
||||
patterns.getContext(), "add_lwe_ciphertexts_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
"allocate_lwe_ciphertext_u64", lweSize);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowLFHEToConcreteCAPIPass
|
||||
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
|
||||
LowLFHEToConcreteCAPIPass(uint64_t lweSize): lweSize(lweSize){};
|
||||
void runOnOperation() final;
|
||||
uint64_t lweSize;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -142,7 +145,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
|
||||
// Setup rewrite patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
populateLowLFHEToConcreteCAPICall(patterns, lweSize);
|
||||
|
||||
// Apply the conversion
|
||||
mlir::ModuleOp op = getOperation();
|
||||
@@ -154,8 +157,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass() {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>();
|
||||
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>(lweSize);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -45,13 +45,13 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module) {
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c{
|
||||
.secretKeys{
|
||||
{"small", {.size = v0Param->nSmall}},
|
||||
{"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}},
|
||||
{"small", {.size = v0Param.nSmall}},
|
||||
{"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}},
|
||||
},
|
||||
.bootstrapKeys{
|
||||
{
|
||||
@@ -59,9 +59,9 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
{
|
||||
.inputSecretKeyID = "small",
|
||||
.outputSecretKeyID = "big",
|
||||
.level = v0Param->brLevel,
|
||||
.baseLog = v0Param->brLogBase,
|
||||
.k = v0Param->k,
|
||||
.level = v0Param.brLevel,
|
||||
.baseLog = v0Param.brLogBase,
|
||||
.k = v0Param.k,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
},
|
||||
@@ -73,8 +73,8 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
{
|
||||
.inputSecretKeyID = "big",
|
||||
.outputSecretKeyID = "small",
|
||||
.level = v0Param->ksLevel,
|
||||
.baseLog = v0Param->ksLogBase,
|
||||
.level = v0Param.ksLevel,
|
||||
.baseLog = v0Param.ksLogBase,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
},
|
||||
|
||||
@@ -32,18 +32,19 @@ void addFilteredPassToPassManager(
|
||||
|
||||
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint,
|
||||
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
constraint = defaultGlobalFHECircuitConstraint;
|
||||
v0Parameter = *getV0Parameter(constraint.norm2, constraint.p);
|
||||
// Add all passes to lower from HLFHE to LLVM Dialect
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
|
||||
constraint = defaultGlobalFHECircuitConstraint;
|
||||
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(1 << v0Parameter.polynomialSize), enablePass);
|
||||
|
||||
// Run the passes
|
||||
if (pm.run(module).failed()) {
|
||||
|
||||
@@ -175,33 +175,30 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
|
||||
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
mlir::zamalang::V0Parameter v0Parameter;
|
||||
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
context, *module, constraint, enablePass)
|
||||
context, *module, constraint, v0Parameter, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
|
||||
<< constraint.p << "}\n");
|
||||
|
||||
// Retreive the parameters for the v0 approach
|
||||
mlir::zamalang::V0Parameter *fheParameter =
|
||||
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
|
||||
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
|
||||
<< fheParameter->k
|
||||
<< ", polynomialSize: " << fheParameter->polynomialSize
|
||||
<< ", nSmall: " << fheParameter->nSmall
|
||||
<< ", brLevel: " << fheParameter->brLevel
|
||||
<< ", brLogBase: " << fheParameter->brLogBase
|
||||
<< ", ksLevel: " << fheParameter->ksLevel
|
||||
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
|
||||
<< v0Parameter.k
|
||||
<< ", polynomialSize: " << v0Parameter.polynomialSize
|
||||
<< ", nSmall: " << v0Parameter.nSmall
|
||||
<< ", brLevel: " << v0Parameter.brLevel
|
||||
<< ", brLogBase: " << v0Parameter.brLogBase
|
||||
<< ", ksLevel: " << v0Parameter.ksLevel
|
||||
<< ", polynomialSize: " << v0Parameter.ksLogBase << "}\n");
|
||||
|
||||
// Generate the keySet
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
if (cmdline::generateKeySet || cmdline::runJit) {
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheParameter, constraint.p, cmdline::jitFuncname, *module);
|
||||
v0Parameter, constraint.p, cmdline::jitFuncname, *module);
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
LOG_ERROR("cannot generate client parameters: " << err << "\n");
|
||||
return mlir::failure();
|
||||
|
||||
Reference in New Issue
Block a user