mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: accept no evaluation keys
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
#include <map>
|
||||
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
@@ -19,14 +20,13 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::BIG_KEY;
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::Encoding;
|
||||
using ::concretelang::clientlib::EncryptionGate;
|
||||
using ::concretelang::clientlib::LweSecretKeyID;
|
||||
using ::concretelang::clientlib::Precision;
|
||||
using ::concretelang::clientlib::SMALL_KEY;
|
||||
using ::concretelang::clientlib::Variance;
|
||||
|
||||
const auto securityLevel = SECURITY_LEVEL_128;
|
||||
@@ -105,34 +105,43 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c;
|
||||
c.secretKeys = {
|
||||
{SMALL_KEY, {/*.size = */ v0Param.nSmall}},
|
||||
{BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
|
||||
};
|
||||
c.bootstrapKeys = {
|
||||
{
|
||||
"bsk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ BIG_KEY,
|
||||
/*.level = */ v0Param.brLevel,
|
||||
/*.baseLog = */ v0Param.brLogBase,
|
||||
/*.glweDimension = */ v0Param.glweDimension,
|
||||
/*.variance = */ encryptionVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
c.keyswitchKeys = {
|
||||
{
|
||||
"ksk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ BIG_KEY,
|
||||
/*.outputSecretKeyID = */ SMALL_KEY,
|
||||
/*.level = */ v0Param.ksLevel,
|
||||
/*.baseLog = */ v0Param.ksLogBase,
|
||||
/*.variance = */ keyswitchVariance,
|
||||
},
|
||||
},
|
||||
{clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
|
||||
};
|
||||
bool has_small_key = v0Param.nSmall != 0;
|
||||
bool has_bootstrap = v0Param.brLevel != 0;
|
||||
if (has_small_key) {
|
||||
c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}});
|
||||
}
|
||||
if (has_bootstrap) {
|
||||
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
|
||||
c.bootstrapKeys = {
|
||||
{
|
||||
clientlib::BOOTSTRAP_KEY,
|
||||
{
|
||||
/*.inputSecretKeyID = */ inputKey,
|
||||
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.level = */ v0Param.brLevel,
|
||||
/*.baseLog = */ v0Param.brLogBase,
|
||||
/*.glweDimension = */ v0Param.glweDimension,
|
||||
/*.variance = */ encryptionVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
if (has_small_key) {
|
||||
c.keyswitchKeys = {
|
||||
{
|
||||
clientlib::KEYSWITCH_KEY,
|
||||
{
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ v0Param.ksLevel,
|
||||
/*.baseLog = */ v0Param.ksLogBase,
|
||||
/*.variance = */ keyswitchVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
c.functionName = (std::string)functionName;
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
||||
@@ -155,16 +164,19 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
? false
|
||||
: inputs.back().isa<mlir::concretelang::Concrete::ContextType>();
|
||||
|
||||
auto gateFromType = [&](mlir::Type ty) {
|
||||
return gateFromMLIRType(clientlib::BIG_KEY, encryptionVariance, ty);
|
||||
};
|
||||
for (auto inType = funcType.getInputs().begin();
|
||||
inType < funcType.getInputs().end() - hasContext; inType++) {
|
||||
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, *inType);
|
||||
auto gate = gateFromType(*inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto outType : funcType.getResults()) {
|
||||
auto gate = gateFromMLIRType(BIG_KEY, encryptionVariance, outType);
|
||||
auto gate = gateFromType(outType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user