Files
concrete/compilers/concrete-compiler/compiler/lib/Support/ClientParametersGeneration.cpp
aPere3 cacffadbd2 feat(compiler): add support for multikey
This commit brings support for multiple secret keys in the TFHE
dialect. In particular, a parameterized `TFHE` circuit can now be
given as input, with any (semantically valid) combination of
ks/bs/woppbs operations mixing different secret keys, and compiled
down to a valid executable function, with server keys properly
looked up.

Secret keys are now stateful objects which can be:
-> none/unparameterized (syntax `sk?`): The keys are in this state right
   after the lowering from the `FHE` dialect.
-> parameterized (syntax `sk<identifier, polysize, dimension>`): The
   keys were parameterized, either by user or by the optimizer. The
   `identifier` field can be used to disambiguate two keys with same
   `polysize` and `dimension`.
-> normalized (syntax `sk[index]<polysize, dimension>`): The keys were
   attached to their index in the list of keys in the runtime context.

The _normalization_ of key indices also acts on the ksk, bsk and pksk,
which are given indices in the same spirit now.

Finally, in order to allow parameterized `TFHE` circuits to be given
as input and compiled down to executable functions, we added a way to
pass the encodings that are used to encode/decode the circuit
inputs/outputs. In the case of a compilation from the `FHE` dialect,
this information is automatically extracted from the higher-level
information available in that dialect.
2023-04-14 15:01:18 +02:00

379 lines
14 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <llvm/ADT/SmallVector.h>
#include <map>
#include <optional>
#include <unordered_set>
#include <variant>
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallSet.h>
#include <llvm/Support/Error.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "concrete/curves.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/TFHECircuitKeys.h"
#include "concretelang/Support/Variants.h"
#include "llvm/Config/abi-breaking.h"
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
using ::concretelang::clientlib::Variance;
// Secret-key format passed to `concrete::getSecurityCurve` when looking up the
// security curve for the requested number of bits of security.
const auto keyFormat{concrete::BINARY};
llvm::Expected<CircuitGate>
generateGate(mlir::Type type, encodings::Encoding encoding,
concrete::SecurityCurve curve,
std::optional<CRTDecomposition> maybeCrt) {
auto scalarVisitor = overloaded{
[&](encodings::EncryptedIntegerScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
TFHE::GLWESecretKeyNormalized normKey;
if (type.isa<RankedTensorType>()) {
normKey = type.cast<RankedTensorType>()
.getElementType()
.cast<TFHE::GLWECipherTextType>()
.getKey()
.getNormalized()
.value();
} else {
normKey = type.cast<TFHE::GLWECipherTextType>()
.getKey()
.getNormalized()
.value();
}
size_t width = enc.width;
bool isSigned = enc.isSigned;
uint64_t size = 0;
std::vector<int64_t> dims{};
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ isSigned,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ isSigned,
},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::EncryptedChunkedIntegerScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
auto tensorType = type.cast<mlir::RankedTensorType>();
auto glweType =
tensorType.getElementType().cast<TFHE::GLWECipherTextType>();
auto normKey = glweType.getKey().getNormalized().value();
size_t width = enc.chunkSize;
assert(enc.width % enc.chunkWidth == 0);
uint64_t size = enc.width / enc.chunkWidth;
bool isSigned = enc.isSigned;
std::vector<int64_t> dims{
(int64_t)size,
};
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
CRTDecomposition crt = maybeCrt.value_or(std::vector<int64_t>());
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
/*.sign = */ isSigned,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ isSigned,
},
/*.chunkInfo = */
std::optional<ChunkInfo>(
{(unsigned int)enc.chunkSize, (unsigned int)enc.chunkWidth}),
};
},
[&](encodings::EncryptedBoolScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
auto glweType = type.cast<TFHE::GLWECipherTextType>();
auto normKey = glweType.getKey().getNormalized().value();
size_t width =
mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
LweSecretKeyID secretKeyID = normKey.index;
Variance variance = curve.getVariance(1, normKey.dimension, 64);
return CircuitGate{
/* .encryption = */ std::optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ std::vector<int64_t>(),
/* .sign = */ false,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ false,
},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::PlaintextScalarEncoding enc)
-> llvm::Expected<CircuitGate> {
size_t width = type.getIntOrFloatBitWidth();
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ std::nullopt,
};
},
[&](encodings::IndexScalarEncoding enc) -> llvm::Expected<CircuitGate> {
// TODO - The index type is dependant of the target architecture,
// so actually we assume we target only 64 bits, we need to have
// some the size of the word of the target system.
size_t width = 64;
bool sign = type.isSignedInteger();
return CircuitGate{
/*.encryption = */ std::nullopt,
/*.shape = */
{/*.width = */ width,
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
/*.chunkInfo = */ std::nullopt,
};
},
[&](auto enc) -> llvm::Expected<CircuitGate> {
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape there",
llvm::inconvertibleErrorCode());
}};
auto genericVisitor = overloaded{
[&](encodings::ScalarEncoding enc) -> llvm::Expected<CircuitGate> {
return std::visit(scalarVisitor, enc);
},
[&](encodings::TensorEncoding enc) -> llvm::Expected<CircuitGate> {
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
auto scalarGate = generateGate(tensor.getElementType(),
enc.scalarEncoding, curve, maybeCrt);
if (auto err = scalarGate.takeError()) {
return std::move(err);
}
if (maybeCrt.has_value()) {
// When using crt, the last dimension of the tensor is for the members
// of the decomposition. It should not be used.
scalarGate->shape.dimensions =
tensor.getShape().take_front(tensor.getShape().size() - 1).vec();
} else {
scalarGate->shape.dimensions = tensor.getShape().vec();
}
scalarGate->shape.size = 1;
for (auto dimSize : scalarGate->shape.dimensions) {
scalarGate->shape.size *= dimSize;
}
return scalarGate;
},
[&](auto enc) -> llvm::Expected<CircuitGate> {
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape here",
llvm::inconvertibleErrorCode());
}};
return std::visit(genericVisitor, encoding);
}
/// Ordering functor that compares two values through their `hash_value`.
///
/// NOTE(review): values whose hashes are equal compare equivalent under this
/// ordering, so an ordered container using it cannot tell apart two distinct
/// values whose hashes collide.
template <typename V> struct HashValComparator {
  bool operator()(const V &a, const V &b) const {
    const auto hashA = hash_value(a);
    const auto hashB = hash_value(b);
    return hashA < hashB;
  }
};
/// `llvm::SmallSet` of `V` ordered by `HashValComparator`, with a small-size
/// parameter of 10.
template <typename V>
using Set = llvm::SmallSet<V, 10, HashValComparator<V>>;
/// Appends to `output` the client-side parameters of every key used by the
/// circuit, in the order they are listed in `circuitKeys`: secret keys,
/// keyswitch keys, bootstrap keys, then packing keyswitch keys.
///
/// Every secret key referenced here must already be normalized, since its
/// `getNormalized()` result is dereferenced with `.value()`. Noise variances
/// are computed from `curve` with a 64-bit modulus.
void extractCircuitKeys(ClientParameters &output,
                        TFHE::TFHECircuitKeys circuitKeys,
                        concrete::SecurityCurve curve) {
  // Shared accessor for the normalized form of a secret key.
  auto normalized = [](auto key) { return key.getNormalized().value(); };

  // Secret keys: only their dimension is recorded.
  for (auto secretKey : circuitKeys.secretKeys) {
    clientlib::LweSecretKeyParam param;
    param.dimension = normalized(secretKey).dimension;
    output.secretKeys.push_back(param);
  }

  // Keyswitch keys: the variance is derived from the output key dimension.
  for (auto keyswitchKey : circuitKeys.keyswitchKeys) {
    const auto from = normalized(keyswitchKey.getInputKey());
    const auto to = normalized(keyswitchKey.getOutputKey());
    clientlib::KeyswitchKeyParam param;
    param.inputSecretKeyID = from.index;
    param.outputSecretKeyID = to.index;
    param.level = keyswitchKey.getLevels();
    param.baseLog = keyswitchKey.getBaseLog();
    param.variance = curve.getVariance(1, to.dimension, 64);
    output.keyswitchKeys.push_back(param);
  }

  // Bootstrap keys: the variance is derived from the key's own GLWE
  // dimension and polynomial size.
  for (auto bootstrapKey : circuitKeys.bootstrapKeys) {
    const auto from = normalized(bootstrapKey.getInputKey());
    const auto to = normalized(bootstrapKey.getOutputKey());
    const auto glweDim = bootstrapKey.getGlweDim();
    const auto polySize = bootstrapKey.getPolySize();
    clientlib::BootstrapKeyParam param;
    param.inputSecretKeyID = from.index;
    param.outputSecretKeyID = to.index;
    param.level = bootstrapKey.getLevels();
    param.baseLog = bootstrapKey.getBaseLog();
    param.glweDimension = glweDim;
    param.polynomialSize = polySize;
    param.variance = curve.getVariance(glweDim, polySize, 64);
    param.inputLweDimension = from.dimension;
    output.bootstrapKeys.push_back(param);
  }

  // Packing keyswitch keys: the variance is derived from the output key's
  // dimension and polynomial size.
  for (auto packingKey : circuitKeys.packingKeyswitchKeys) {
    const auto from = normalized(packingKey.getInputKey());
    const auto to = normalized(packingKey.getOutputKey());
    clientlib::PackingKeyswitchKeyParam param;
    param.inputSecretKeyID = from.index;
    param.outputSecretKeyID = to.index;
    param.level = packingKey.getLevels();
    param.baseLog = packingKey.getBaseLog();
    param.glweDimension = packingKey.getGlweDim();
    param.polynomialSize = packingKey.getOutputPolySize();
    param.inputLweDimension = from.dimension;
    param.variance = curve.getVariance(to.dimension, to.polySize, 64);
    output.packingKeyswitchKeys.push_back(param);
  }
}
/// Generates one `CircuitGate` per function argument and per function result,
/// appending them to `output.inputs` and `output.outputs` respectively.
///
/// Argument/result types are paired positionally with the entries of
/// `encodings.inputEncodings` / `encodings.outputEncodings`.
///
/// \returns `std::monostate` on success, or an error if the number of
///          encodings does not match the function signature or if a gate
///          cannot be generated. On error, `output` may already contain the
///          gates generated before the failure.
llvm::Expected<std::monostate>
extractCircuitGates(ClientParameters &output, mlir::func::FuncOp funcOp,
                    encodings::CircuitEncodings encodings,
                    concrete::SecurityCurve curve,
                    std::optional<CRTDecomposition> maybeCrt) {
  auto funcType = funcOp.getFunctionType();
  auto inputTypes = funcType.getInputs();
  auto resultTypes = funcType.getResults();

  // llvm::zip stops at the shorter of its ranges, so a count mismatch would
  // silently drop gates; reject it explicitly instead.
  if (inputTypes.size() != encodings.inputEncodings.size()) {
    return StreamStringError("the function has ")
           << inputTypes.size() << " inputs but "
           << encodings.inputEncodings.size() << " input encodings were given";
  }
  if (resultTypes.size() != encodings.outputEncodings.size()) {
    return StreamStringError("the function has ")
           << resultTypes.size() << " outputs but "
           << encodings.outputEncodings.size()
           << " output encodings were given";
  }

  // Generates the gates for `types` paired with `typeEncodings` and appends
  // them to `gates`, stopping at the first error.
  auto appendGates = [&](auto types, auto &typeEncodings,
                         auto &gates) -> llvm::Error {
    for (auto val : llvm::zip(types, typeEncodings)) {
      auto gate =
          generateGate(std::get<0>(val), std::get<1>(val), curve, maybeCrt);
      if (!gate) {
        return gate.takeError();
      }
      gates.push_back(std::move(*gate));
    }
    return llvm::Error::success();
  };

  if (auto err =
          appendGates(inputTypes, encodings.inputEncodings, output.inputs)) {
    return std::move(err);
  }
  if (auto err = appendGates(resultTypes, encodings.outputEncodings,
                             output.outputs)) {
    return std::move(err);
  }
  return std::monostate();
}
/// Builds the `ClientParameters` of the function `functionName` of a module
/// at the TFHE level whose secret keys are normalized.
///
/// \param module module containing the function.
/// \param functionName name of the `func.func` operation to describe.
/// \param bitsOfSecurity security level used to select the security curve.
/// \param encodings encodings of the function inputs and outputs.
/// \param maybeCrt optional CRT decomposition forwarded to gate generation.
/// \returns the client parameters, or an error if no security curve matches
///          `bitsOfSecurity`, if the function cannot be found, or if a gate
///          cannot be generated.
llvm::Expected<ClientParameters>
createClientParametersFromTFHE(mlir::ModuleOp module,
                               llvm::StringRef functionName, int bitsOfSecurity,
                               encodings::CircuitEncodings encodings,
                               std::optional<CRTDecomposition> maybeCrt) {
  // Check that security curves exist
  const auto curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
  if (curve == nullptr) {
    return StreamStringError("Cannot find security curves for ")
           << bitsOfSecurity << " bits";
  }
  // Check that the specified function can be found
  auto rangeOps = module.getOps<mlir::func::FuncOp>();
  auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
    return op.getName() == functionName;
  });
  if (funcOp == rangeOps.end()) {
    return StreamStringError(
               "cannot find the function for generate client parameters: ")
           << functionName;
  }
  // Create client parameters
  ClientParameters output;
  output.functionName = functionName.str();
  // We extract the keys of the circuit and describe each of them
  auto circuitKeys = TFHE::extractCircuitKeys(module);
  extractCircuitKeys(output, std::move(circuitKeys), *curve);
  // We generate the gates for the inputs and outputs
  if (auto err = extractCircuitGates(output, *funcOp, std::move(encodings),
                                     *curve, std::move(maybeCrt))
                     .takeError()) {
    return std::move(err);
  }
  return output;
}
} // namespace concretelang
} // namespace mlir