Files
concrete/compilers/concrete-compiler/compiler/lib/Support/TFHECircuitKeys.cpp
aPere3 cacffadbd2 feat(compiler): add support for multikey
This commit brings support for multiple secret keys in the TFHE
dialect. In particular, a parameterized `TFHE` circuit can now be
given as input, with any (semantically valid) combination of
ks/bs/woppbs mixing different secret keys, and compiled down to a
valid executable function, with server keys properly looked up.

Secret keys are now stateful objects which can be:
-> none/unparameterized (syntax `sk?`): The keys are in this state after
   the lowering from the `FHE` dialect.
-> parameterized (syntax `sk<identifier, polysize, dimension>`): The
   keys were parameterized, either by user or by the optimizer. The
   `identifier` field can be used to disambiguate two keys with same
   `polysize` and `dimension`.
-> normalized (syntax `sk[index]<polysize, dimension>`): The keys were
   attached to their index in the list of keys in the runtime context.

The _normalization_ of key indices also acts on the ksk, bsk and pksk,
which are given indices in the same spirit now.

Finally, in order to allow parameterized `TFHE` circuits to be given
as input and compiled down to executable functions, we added a way to
pass the encodings that are used to encode/decode the circuit
inputs/outputs. In the case of a compilation from the `FHE` dialect,
this information is automatically extracted from the higher-level
information available in this dialect.
2023-04-14 15:01:18 +02:00

157 lines
4.9 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Support/TFHECircuitKeys.h"
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "llvm/ADT/SmallVector.h"
#include <llvm/ADT/SmallVector.h>
#include <optional>
// Faster than using a full fledged hash-set for small sets, and the array can
// be recovered right away.
template <typename V> struct SmallSet {
llvm::SmallVector<V, 10> vector;
void insert(V val) {
for (auto vectorVal : vector) {
if (vectorVal == val) {
return;
}
}
vector.push_back(val);
}
};
template <typename V, unsigned N>
std::optional<size_t> vectorIndex(llvm::SmallVector<V, N> vector, V val) {
for (size_t i = 0; i < vector.size(); i++) {
auto potentialVal = vector[i];
if (potentialVal == val) {
return i;
}
}
return std::nullopt;
}
namespace mlir {
namespace concretelang {
namespace TFHE {
/// Prints the elements of `vect` between square brackets, each element being
/// followed by a comma (including the last one).
///
/// The vector is taken by const reference so that printing does not copy it.
template <typename V, unsigned int N>
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
                              const mlir::SmallVector<V, N> &vect) {
  OS << "[";
  for (auto v : vect) {
    OS << v << ",";
  }
  OS << "]";
  return OS;
}
/// Writes a multi-line dump of `cks` to `OS`: one line per kind of key
/// (secret, keyswitch, bootstrap, packing keyswitch), wrapped in
/// `TFHECircuitKeys{ ... }`.
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
                              const TFHECircuitKeys cks) {
  OS << "TFHECircuitKeys{\n";
  OS << " secretKeys:" << cks.secretKeys << "\n";
  OS << " keyswitchKeys:" << cks.keyswitchKeys << "\n";
  OS << " bootstrapKeys:" << cks.bootstrapKeys << "\n";
  OS << " packingKeyswitchKeys:" << cks.packingKeyswitchKeys << "\n";
  OS << "}";
  return OS;
}
/// Collects the keys referenced by the operations and functions of `moduleOp`.
///
/// Secret keys are gathered from GLWE ciphertext types (scalar, or element of
/// a ranked tensor) found on operation operands/results and on function
/// signatures, and from the input/output keys of every keyswitch, bootstrap
/// and packing keyswitch key attribute. Each collection is deduplicated and
/// keeps first-seen order.
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp) {
  SmallSet<TFHE::GLWESecretKey> secretKeys;
  SmallSet<TFHE::GLWEKeyswitchKeyAttr> keyswitchKeys;
  SmallSet<TFHE::GLWEBootstrapKeyAttr> bootstrapKeys;
  SmallSet<TFHE::GLWEPackingKeyswitchKeyAttr> packingKeyswitchKeys;

  // Records the secret key carried by `type`, if it is a GLWE ciphertext or a
  // ranked tensor of GLWE ciphertexts. Any other type is ignored.
  auto collectFromType = [&](mlir::Type type) {
    if (auto scalarCt = type.dyn_cast<TFHE::GLWECipherTextType>()) {
      secretKeys.insert(scalarCt.getKey());
      return;
    }
    auto rankedTensor = type.dyn_cast<mlir::RankedTensorType>();
    if (!rankedTensor) {
      return;
    }
    if (auto elementCt = rankedTensor.getElementType()
                             .dyn_cast<TFHE::GLWECipherTextType>()) {
      secretKeys.insert(elementCt.getKey());
    }
  };

  // Records the input key, then the output key, of a server key attribute.
  auto collectEndpoints = [&](auto serverKey) {
    secretKeys.insert(serverKey.getInputKey());
    secretKeys.insert(serverKey.getOutputKey());
  };

  // Secret keys appearing on the operands and results of any operation.
  moduleOp->walk([&](mlir::Operation *op) {
    for (mlir::Type operandType : op->getOperandTypes()) {
      collectFromType(operandType);
    }
    for (mlir::Type resultType : op->getResultTypes()) {
      collectFromType(resultType);
    }
  });

  // Secret keys appearing in function signatures.
  moduleOp->walk([&](mlir::func::FuncOp funcOp) {
    for (auto inputType : funcOp.getArgumentTypes()) {
      collectFromType(inputType);
    }
    for (auto outputType : funcOp.getResultTypes()) {
      collectFromType(outputType);
    }
  });

  // Keyswitch keys used by keyswitch operations.
  moduleOp->walk([&](TFHE::KeySwitchGLWEOp op) {
    auto ksk = op.getKeyAttr();
    keyswitchKeys.insert(ksk);
    collectEndpoints(ksk);
  });

  // Bootstrap keys used by bootstrap operations.
  moduleOp->walk([&](TFHE::BootstrapGLWEOp op) {
    auto bsk = op.getKeyAttr();
    bootstrapKeys.insert(bsk);
    collectEndpoints(bsk);
  });

  // WoP-PBS operations carry a keyswitch, a bootstrap and a packing keyswitch
  // key. Secret keys are visited in that order.
  moduleOp->walk([&](TFHE::WopPBSGLWEOp op) {
    auto ksk = op.getKskAttr();
    auto bsk = op.getBskAttr();
    auto pksk = op.getPkskAttr();
    keyswitchKeys.insert(ksk);
    bootstrapKeys.insert(bsk);
    packingKeyswitchKeys.insert(pksk);
    collectEndpoints(ksk);
    collectEndpoints(bsk);
    collectEndpoints(pksk);
  });

  return TFHECircuitKeys{secretKeys.vector, bootstrapKeys.vector,
                         keyswitchKeys.vector, packingKeyswitchKeys.vector};
}
/// Looks up `key` among the circuit secret keys and yields its position in
/// `secretKeys`, or `std::nullopt` when it is not present.
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getSecretKeyIndex(TFHE::GLWESecretKey key) {
  std::optional<uint64_t> position = vectorIndex(secretKeys, key);
  return position;
}
/// Looks up `key` among the circuit bootstrap keys and yields its position in
/// `bootstrapKeys`, or `std::nullopt` when it is not present.
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key) {
  std::optional<uint64_t> position = vectorIndex(bootstrapKeys, key);
  return position;
}
/// Looks up `key` among the circuit keyswitch keys and yields its position in
/// `keyswitchKeys`, or `std::nullopt` when it is not present.
std::optional<uint64_t>
TFHE::TFHECircuitKeys::getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key) {
  std::optional<uint64_t> position = vectorIndex(keyswitchKeys, key);
  return position;
}
/// Looks up `key` among the circuit packing keyswitch keys and yields its
/// position in `packingKeyswitchKeys`, or `std::nullopt` when it is not
/// present.
std::optional<uint64_t> TFHE::TFHECircuitKeys::getPackingKeyswitchKeyIndex(
    TFHE::GLWEPackingKeyswitchKeyAttr key) {
  std::optional<uint64_t> position = vectorIndex(packingKeyswitchKeys, key);
  return position;
}
} // namespace TFHE
} // namespace concretelang
} // namespace mlir