mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
This commit adds support for multiple secret keys in the TFHE dialect. In particular, a parameterized `TFHE` circuit can now be given as input, with any (semantically valid) combination of ks/bs/woppbs operations mixing different secret keys, and compiled down to a valid executable function in which the server keys are properly looked up. Secret keys are now stateful objects, which can be: -> none/unparameterized (syntax `sk?`): the keys are in this state after lowering from the `FHE` dialect. -> parameterized (syntax `sk<identifier, polysize, dimension>`): the keys were parameterized, either by the user or by the optimizer. The `identifier` field can be used to disambiguate two keys with the same `polysize` and `dimension`. -> normalized (syntax `sk[index]<polysize, dimension>`): the keys were attached to their index in the list of keys in the runtime context. The _normalization_ of key indices also applies to the ksk, bsk and pksk, which are now given indices in the same way. Finally, to allow a parameterized `TFHE` circuit to be given as input and compiled down to an executable function, we added a way to pass the encodings used to encode/decode the circuit inputs/outputs. When compiling from the `FHE` dialect, this information is automatically extracted from the higher-level information available in that dialect.
157 lines
4.9 KiB
C++
157 lines
4.9 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Support/TFHECircuitKeys.h"
|
|
#include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h"
|
|
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
|
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
|
#include "concretelang/Dialect/TFHE/IR/TFHEParameters.h"
|
|
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include <llvm/ADT/SmallVector.h>
|
|
#include <optional>
|
|
|
|
// Faster than using a full fledged hash-set for small sets, and the array can
|
|
// be recovered right away.
|
|
template <typename V> struct SmallSet {
|
|
llvm::SmallVector<V, 10> vector;
|
|
|
|
void insert(V val) {
|
|
for (auto vectorVal : vector) {
|
|
if (vectorVal == val) {
|
|
return;
|
|
}
|
|
}
|
|
vector.push_back(val);
|
|
}
|
|
};
|
|
|
|
// Returns the position of the first element of `vector` that compares equal
// to `val`, or `std::nullopt` when no element matches.
//
// The container is taken by const reference so that a lookup does not copy
// it (callers in this file pass the key vectors of a TFHECircuitKeys on every
// index query). Any container providing `size()` and `operator[]` is accepted.
template <typename Container, typename V>
std::optional<size_t> vectorIndex(const Container &vector, V val) {
  for (size_t i = 0, e = vector.size(); i < e; ++i) {
    // Compare on a local copy, exactly as the element was compared before.
    auto candidate = vector[i];
    if (candidate == val) {
      return i;
    }
  }
  return std::nullopt;
}
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
namespace TFHE {
|
|
|
|
template <typename V, unsigned int N>
|
|
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
|
const mlir::SmallVector<V, N> vect) {
|
|
OS << "[";
|
|
for (auto v : vect) {
|
|
OS << v << ",";
|
|
}
|
|
OS << "]";
|
|
return OS;
|
|
}
|
|
|
|
// Writes a multi-line, human-readable dump of the circuit keys to `OS`: one
// line per key category (secret, keyswitch, bootstrap, packing keyswitch),
// wrapped in `TFHECircuitKeys{` ... `}`.
//
// NOTE(review): `cks` is received by value, which copies all four key
// vectors; switching to a const reference would also require updating the
// declaration in TFHECircuitKeys.h, so it is left unchanged here.
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
                              const TFHECircuitKeys cks) {
  OS << "TFHECircuitKeys{\n";
  OS << " secretKeys:" << cks.secretKeys << "\n";
  OS << " keyswitchKeys:" << cks.keyswitchKeys << "\n";
  OS << " bootstrapKeys:" << cks.bootstrapKeys << "\n";
  OS << " packingKeyswitchKeys:" << cks.packingKeyswitchKeys << "\n";
  OS << "}";
  return OS;
}
|
|
|
|
// Collects every secret key, keyswitch key, bootstrap key and packing
// keyswitch key referenced by `moduleOp`.
//
// Each category is deduplicated and kept in first-seen order. The position of
// a key in the vectors of the returned TFHECircuitKeys is the index reported
// by the `TFHECircuitKeys::get*KeyIndex` accessors, so the order of the walks
// and insertions below is significant and must not be rearranged.
TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp) {
  SmallSet<TFHE::GLWESecretKey> secretKeys;
  SmallSet<TFHE::GLWEKeyswitchKeyAttr> keyswitchKeys;
  SmallSet<TFHE::GLWEBootstrapKeyAttr> bootstrapKeys;
  SmallSet<TFHE::GLWEPackingKeyswitchKeyAttr> packingKeyswitchKeys;

  // Records the secret key of a GLWE ciphertext type, given either directly
  // or as the element type of a ranked tensor. Any other type is ignored.
  auto collectKeyOf = [&](mlir::Type type) {
    auto glweType = type.dyn_cast<TFHE::GLWECipherTextType>();
    if (!glweType) {
      if (auto tensorType = type.dyn_cast<mlir::RankedTensorType>())
        glweType =
            tensorType.getElementType().dyn_cast<TFHE::GLWECipherTextType>();
    }
    if (glweType)
      secretKeys.insert(glweType.getKey());
  };

  // Records an evaluation key attribute together with the input and output
  // secret keys it refers to.
  auto collectEvalKey = [&](auto &keySet, auto keyAttr) {
    keySet.insert(keyAttr);
    secretKeys.insert(keyAttr.getInputKey());
    secretKeys.insert(keyAttr.getOutputKey());
  };

  // Secret keys carried by the operand and result types of every operation.
  moduleOp->walk([&](mlir::Operation *op) {
    for (mlir::Type type : op->getOperandTypes())
      collectKeyOf(type);
    for (mlir::Type type : op->getResultTypes())
      collectKeyOf(type);
  });

  // Secret keys carried by function signatures, which also catches signature
  // types that never show up as an operation's operand or result.
  moduleOp->walk([&](mlir::func::FuncOp funcOp) {
    for (mlir::Type type : funcOp.getArgumentTypes())
      collectKeyOf(type);
    for (mlir::Type type : funcOp.getResultTypes())
      collectKeyOf(type);
  });

  // Keyswitch keys.
  moduleOp->walk([&](TFHE::KeySwitchGLWEOp op) {
    collectEvalKey(keyswitchKeys, op.getKeyAttr());
  });

  // Bootstrap keys.
  moduleOp->walk([&](TFHE::BootstrapGLWEOp op) {
    collectEvalKey(bootstrapKeys, op.getKeyAttr());
  });

  // WoP-PBS uses one key of each kind, including the packing keyswitch key.
  moduleOp->walk([&](TFHE::WopPBSGLWEOp op) {
    collectEvalKey(keyswitchKeys, op.getKskAttr());
    collectEvalKey(bootstrapKeys, op.getBskAttr());
    collectEvalKey(packingKeyswitchKeys, op.getPkskAttr());
  });

  return TFHECircuitKeys{secretKeys.vector, bootstrapKeys.vector,
                         keyswitchKeys.vector, packingKeyswitchKeys.vector};
}
|
|
|
|
std::optional<uint64_t>
|
|
TFHE::TFHECircuitKeys::getSecretKeyIndex(TFHE::GLWESecretKey key) {
|
|
return vectorIndex(this->secretKeys, key);
|
|
}
|
|
|
|
std::optional<uint64_t>
|
|
TFHE::TFHECircuitKeys::getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key) {
|
|
return vectorIndex(this->bootstrapKeys, key);
|
|
}
|
|
|
|
std::optional<uint64_t>
|
|
TFHE::TFHECircuitKeys::getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key) {
|
|
return vectorIndex(this->keyswitchKeys, key);
|
|
}
|
|
|
|
// Returns the index of `key` among this circuit's packing keyswitch keys, or
// `std::nullopt` when the key was not collected for the circuit.
std::optional<uint64_t> TFHE::TFHECircuitKeys::getPackingKeyswitchKeyIndex(
    TFHE::GLWEPackingKeyswitchKeyAttr key) {
  if (auto position = vectorIndex(packingKeyswitchKeys, key))
    return static_cast<uint64_t>(*position);
  return std::nullopt;
}
|
|
|
|
} // namespace TFHE
|
|
} // namespace concretelang
|
|
} // namespace mlir
|