// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "concretelang/Support/TFHECircuitKeys.h" #include "concretelang/Dialect/TFHE/IR/TFHEAttrs.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHEParameters.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" #include "llvm/ADT/SmallVector.h" #include #include // Faster than using a full fledged hash-set for small sets, and the array can // be recovered right away. template struct SmallSet { llvm::SmallVector vector; void insert(V val) { for (auto vectorVal : vector) { if (vectorVal == val) { return; } } vector.push_back(val); } }; template std::optional vectorIndex(llvm::SmallVector vector, V val) { for (size_t i = 0; i < vector.size(); i++) { auto potentialVal = vector[i]; if (potentialVal == val) { return i; } } return std::nullopt; } namespace mlir { namespace concretelang { namespace TFHE { template llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const mlir::SmallVector vect) { OS << "["; for (auto v : vect) { OS << v << ","; } OS << "]"; return OS; } llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const TFHECircuitKeys cks) { OS << "TFHECircuitKeys{\n" << " secretKeys:" << cks.secretKeys << "\n" << " keyswitchKeys:" << cks.keyswitchKeys << "\n" << " bootstrapKeys:" << cks.bootstrapKeys << "\n" << " packingKeyswitchKeys:" << cks.packingKeyswitchKeys << "\n" "}"; return OS; } TFHECircuitKeys extractCircuitKeys(mlir::ModuleOp moduleOp) { // Gathering circuit secret keys SmallSet secretKeys; auto tryInsert = [&](mlir::Type type) { if (auto glweType = type.dyn_cast()) { secretKeys.insert(glweType.getKey()); } else if (auto tensorType = type.dyn_cast()) { if (auto elementType = tensorType.getElementType() .dyn_cast()) { secretKeys.insert(elementType.getKey()); } } }; moduleOp->walk([&](mlir::Operation *op) { for (auto operand : op->getOperands()) { tryInsert(operand.getType()); } for (auto result : op->getResults()) { tryInsert(result.getType()); } }); moduleOp->walk([&](mlir::func::FuncOp op) { for (auto argType : op.getArgumentTypes()) { tryInsert(argType); } for (auto resultType : op.getResultTypes()) { tryInsert(resultType); } }); // Gathering circuit keyswitch keys SmallSet keyswitchKeys; moduleOp->walk([&](TFHE::KeySwitchGLWEOp op) { keyswitchKeys.insert(op.getKeyAttr()); secretKeys.insert(op.getKeyAttr().getInputKey()); secretKeys.insert(op.getKeyAttr().getOutputKey()); }); // Gathering circuit bootstrap keys SmallSet bootstrapKeys; moduleOp->walk([&](TFHE::BootstrapGLWEOp op) { bootstrapKeys.insert(op.getKeyAttr()); secretKeys.insert(op.getKeyAttr().getInputKey()); secretKeys.insert(op.getKeyAttr().getOutputKey()); }); // Gathering circuit packing keyswitch keys SmallSet packingKeyswitchKeys; moduleOp->walk([&](TFHE::WopPBSGLWEOp op) { keyswitchKeys.insert(op.getKskAttr()); secretKeys.insert(op.getKskAttr().getInputKey()); secretKeys.insert(op.getKskAttr().getOutputKey()); bootstrapKeys.insert(op.getBskAttr()); secretKeys.insert(op.getBskAttr().getInputKey()); secretKeys.insert(op.getBskAttr().getOutputKey()); packingKeyswitchKeys.insert(op.getPkskAttr()); secretKeys.insert(op.getPkskAttr().getInputKey()); secretKeys.insert(op.getPkskAttr().getOutputKey()); }); return TFHECircuitKeys{secretKeys.vector, bootstrapKeys.vector, keyswitchKeys.vector, packingKeyswitchKeys.vector}; } std::optional TFHE::TFHECircuitKeys::getSecretKeyIndex(TFHE::GLWESecretKey key) { return vectorIndex(this->secretKeys, key); } std::optional TFHE::TFHECircuitKeys::getBootstrapKeyIndex(TFHE::GLWEBootstrapKeyAttr key) { return vectorIndex(this->bootstrapKeys, key); } std::optional TFHE::TFHECircuitKeys::getKeyswitchKeyIndex(TFHE::GLWEKeyswitchKeyAttr key) { return vectorIndex(this->keyswitchKeys, key); } std::optional TFHE::TFHECircuitKeys::getPackingKeyswitchKeyIndex( TFHE::GLWEPackingKeyswitchKeyAttr key) { return vectorIndex(this->packingKeyswitchKeys, key); } } // namespace TFHE } // namespace concretelang } // namespace mlir