fix(compiler/lowlfhe): for the v0 give the lweSize of ciphertext as a global parameter of the lowering pass to concrete api call (#62)

This commit is contained in:
Quentin Bourgerie
2021-08-11 14:56:57 +02:00
parent b22f585380
commit 03297fd50d
8 changed files with 42 additions and 39 deletions

View File

@@ -9,7 +9,7 @@ namespace zamalang {
/// Create a pass to convert `LowLFHE` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass();
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize);
} // namespace zamalang
} // namespace mlir

View File

@@ -73,7 +73,7 @@ struct ClientParameters {
};
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module);
} // namespace zamalang
} // namespace mlir

View File

@@ -26,7 +26,7 @@ public:
/// The given module MLIR operation would be modified and the constraints set.
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint,
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
return true;
});

View File

@@ -15,6 +15,8 @@ typedef struct V0Parameter {
size_t ksLevel;
size_t ksLogBase;
V0Parameter() {}
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),

View File

@@ -30,9 +30,10 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
mlir::StringRef funcName,
mlir::StringRef allocName,
uint64_t lweSize,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
allocName(allocName) {}
allocName(allocName), lweSize(lweSize) {}
mlir::LogicalResult static insertForwardDeclaration(
Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName,
@@ -92,17 +93,16 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
// Replace the operation with a call to the `funcName`
{
// Create the err value
auto err = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
// Add the call to the allocation
// TODO - 2018
auto lweSize = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(2048));
mlir::SmallVector<mlir::Value, 1> allocOperands{err, lweSize};
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(lweSize));
mlir::SmallVector<mlir::Value, 1> allocOperands{errOp, lweSizeOp};
auto alloc = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, allocName, op.getType(), allocOperands);
// Add err and allocated value to operands
mlir::SmallVector<mlir::Value, 4> newOperands{err, alloc.getResult(0)};
mlir::SmallVector<mlir::Value, 4> newOperands{errOp, alloc.getResult(0)};
for (auto operand : op->getOperands()) {
newOperands.push_back(operand);
}
@@ -115,21 +115,24 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
private:
std::string funcName;
std::string allocName;
uint64_t lweSize;
};
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// operators to the corresponding function call to the `Concrete C API`.
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns, uint64_t lweSize) {
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::zamalang::LowLFHE::AddLweCiphertextsOp>>(
patterns.getContext(), "add_lwe_ciphertexts_u64",
"allocate_lwe_ciphertext_u64");
"allocate_lwe_ciphertext_u64", lweSize);
}
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
LowLFHEToConcreteCAPIPass(uint64_t lweSize): lweSize(lweSize){};
void runOnOperation() final;
uint64_t lweSize;
};
} // namespace
@@ -142,7 +145,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
// Setup rewrite patterns
mlir::RewritePatternSet patterns(&getContext());
populateLowLFHEToConcreteCAPICall(patterns);
populateLowLFHEToConcreteCAPICall(patterns, lweSize);
// Apply the conversion
mlir::ModuleOp op = getOperation();
@@ -154,8 +157,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
namespace mlir {
namespace zamalang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass() {
return std::make_unique<LowLFHEToConcreteCAPIPass>();
createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) {
return std::make_unique<LowLFHEToConcreteCAPIPass>(lweSize);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -45,13 +45,13 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
createClientParametersForV0(V0Parameter &v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module) {
// Static client parameters from global parameters for v0
ClientParameters c{
.secretKeys{
{"small", {.size = v0Param->nSmall}},
{"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}},
{"small", {.size = v0Param.nSmall}},
{"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}},
},
.bootstrapKeys{
{
@@ -59,9 +59,9 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision,
{
.inputSecretKeyID = "small",
.outputSecretKeyID = "big",
.level = v0Param->brLevel,
.baseLog = v0Param->brLogBase,
.k = v0Param->k,
.level = v0Param.brLevel,
.baseLog = v0Param.brLogBase,
.k = v0Param.k,
// TODO - Compute variance, wait for security estimator
.variance = 0,
},
@@ -73,8 +73,8 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision,
{
.inputSecretKeyID = "big",
.outputSecretKeyID = "small",
.level = v0Param->ksLevel,
.baseLog = v0Param->ksLogBase,
.level = v0Param.ksLevel,
.baseLog = v0Param.ksLogBase,
// TODO - Compute variance, wait for security estimator
.variance = 0,
},

View File

@@ -32,18 +32,19 @@ void addFilteredPassToPassManager(
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint,
FHECircuitConstraint &constraint, V0Parameter &v0Parameter,
llvm::function_ref<bool(std::string)> enablePass) {
mlir::PassManager pm(&context);
constraint = defaultGlobalFHECircuitConstraint;
v0Parameter = *getV0Parameter(constraint.norm2, constraint.p);
// Add all passes to lower from HLFHE to LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass);
constraint = defaultGlobalFHECircuitConstraint;
pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(1 << v0Parameter.polynomialSize), enablePass);
// Run the passes
if (pm.run(module).failed()) {

View File

@@ -175,33 +175,30 @@ processInputBuffer(mlir::MLIRContext &context,
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::FHECircuitConstraint constraint;
mlir::zamalang::V0Parameter v0Parameter;
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, constraint, enablePass)
context, *module, constraint, v0Parameter, enablePass)
.failed()) {
return mlir::failure();
}
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
<< constraint.p << "}\n");
// Retreive the parameters for the v0 approach
mlir::zamalang::V0Parameter *fheParameter =
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
<< fheParameter->k
<< ", polynomialSize: " << fheParameter->polynomialSize
<< ", nSmall: " << fheParameter->nSmall
<< ", brLevel: " << fheParameter->brLevel
<< ", brLogBase: " << fheParameter->brLogBase
<< ", ksLevel: " << fheParameter->ksLevel
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
<< v0Parameter.k
<< ", polynomialSize: " << v0Parameter.polynomialSize
<< ", nSmall: " << v0Parameter.nSmall
<< ", brLevel: " << v0Parameter.brLevel
<< ", brLogBase: " << v0Parameter.brLogBase
<< ", ksLevel: " << v0Parameter.ksLevel
<< ", polynomialSize: " << v0Parameter.ksLogBase << "}\n");
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheParameter, constraint.p, cmdline::jitFuncname, *module);
v0Parameter, constraint.p, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
LOG_ERROR("cannot generate client parameters: " << err << "\n");
return mlir::failure();