From 7d1c43bc47b0d9a1e4d4f0f819f2ca6c8aab805d Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 28 Mar 2023 10:38:31 +0200 Subject: [PATCH] feat(compiler/multi-parameters): Create a pass to apply the multi-parameter circuit solution of the optimize --- .../TFHEGlobalParametrization/Pass.h | 3 +- .../Dialect/TFHE/Transforms/CMakeLists.txt | 8 +- .../Dialect/TFHE/Transforms/Optimization.td | 13 - .../{Optimization.h => Transforms.h} | 10 +- .../Dialect/TFHE/Transforms/Transforms.td | 20 + .../Dialect/TFHE/Transforms/CMakeLists.txt | 1 + .../Dialect/TFHE/Transforms/Optimization.cpp | 2 +- .../TFHECircuitSolutionParametrization.cpp | 509 ++++++++++++++++++ .../compiler/lib/Support/Pipeline.cpp | 2 +- .../unit_tests/concretelang/CMakeLists.txt | 1 + .../concretelang/Dialect/CMakeLists.txt | 1 + .../concretelang/Dialect/TFHE/CMakeLists.txt | 1 + .../Dialect/TFHE/Transforms/CMakeLists.txt | 9 + .../TFHE/Transforms/Parametrization.cpp | 246 +++++++++ 14 files changed, 802 insertions(+), 24 deletions(-) delete mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td rename compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/{Optimization.h => Transforms.h} (60%) create mode 100644 compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td create mode 100644 compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp create mode 100644 compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt create mode 100644 compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/Parametrization.cpp diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/TFHEGlobalParametrization/Pass.h b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/TFHEGlobalParametrization/Pass.h index 249caf4f1..36b7110d6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Conversion/TFHEGlobalParametrization/Pass.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Conversion/TFHEGlobalParametrization/Pass.h @@ -14,8 +14,7 @@ namespace mlir { namespace concretelang { /// Create a pass to inject fhe parameters to the TFHE types and operators. std::unique_ptr> -createConvertTFHEGlobalParametrizationPass( - mlir::concretelang::V0FHEContext &fheContext); +createConvertTFHEGlobalParametrizationPass(const V0Parameter parameter); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt index f19200ffa..451f0567b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS Optimization.td) -mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms) -add_public_tablegen_target(ConcretelangTFHEOptimizationPassIncGen) -add_dependencies(mlir-headers ConcretelangTFHEOptimizationPassIncGen) +set(LLVM_TARGET_DEFINITIONS Transforms.td) +mlir_tablegen(Transforms.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangTFHETransformsPassIncGen) +add_dependencies(mlir-headers ConcretelangTFHETransformsPassIncGen) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td deleted file mode 100644 index 9e48d2081..000000000 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.td +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS -#define CONCRETELANG_TFHE_OPTIMIZATION_PASS - -include "mlir/Pass/PassBase.td" - -def TFHEOptimization : Pass<"tfhe-optimization"> { - let summary = "Optimize TFHE operations"; - let constructor = "mlir::concretelang::createTFHEOptimizationPass()"; - let options = []; - let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ]; -} - -#endif diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h similarity index 60% rename from compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h rename to compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h index 031acf528..d53adfb4f 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Optimization.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.h @@ -6,15 +6,19 @@ #ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS_H #define CONCRETELANG_TFHE_OPTIMIZATION_PASS_H -#include -#include +#include "concrete-optimizer.hpp" +#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" +#include "mlir/Pass/Pass.h" #define GEN_PASS_CLASSES -#include +#include "concretelang/Dialect/TFHE/Transforms/Transforms.h.inc" namespace mlir { namespace concretelang { std::unique_ptr> createTFHEOptimizationPass(); +std::unique_ptr> + createTFHECircuitSolutionParametrizationPass( + concrete_optimizer::dag::CircuitSolution); } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td new file mode 100644 index 000000000..2609e1cce --- /dev/null +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/Transforms/Transforms.td @@ -0,0 +1,20 @@ +#ifndef CONCRETELANG_TFHE_TRANSFORMS_PASS +#define CONCRETELANG_TFHE_TRANSFORMS_PASS + +include "mlir/Pass/PassBase.td" + +def TFHEOptimization : Pass<"tfhe-optimization"> { + let summary = "Optimize TFHE operations"; + let constructor = "mlir::concretelang::createTFHEOptimizationPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ]; +} + +def TFHECircuitSolutionParametrization : Pass<"tfhe-circuit-solution-parametrization"> { + let summary = "Parametrize TFHE with a circuit solution given by the optimizer"; + let constructor = "mlir::concretelang::createTFHECircuitSolutionParametrizationPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ]; +} + +#endif diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt index eada6b25a..fc47e62a3 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library( TFHEDialectTransforms Optimization.cpp + TFHECircuitSolutionParametrization.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE DEPENDS diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp index ba8456c5f..1280d3b6f 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/Optimization.cpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include namespace mlir { diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp new file mode 100644 index 000000000..8e31bdddf --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp @@ -0,0 +1,509 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "concretelang/Dialect/TFHE/IR/TFHEOps.h" +#include "concretelang/Dialect/TFHE/Transforms/Transforms.h" +#include "concretelang/Support/Constants.h" +#include "concretelang/Support/logging.h" + +namespace mlir { +namespace concretelang { + +namespace { + +#define DEBUG(MSG) \ + if (mlir::concretelang::isVerbose()) { \ + llvm::errs() << MSG << "\n"; \ + } + +namespace TFHE = mlir::concretelang::TFHE; + +/// Optimization pass that should choose more efficient ways of performing +/// crypto operations. +class TFHECircuitSolutionParametrizationPass + : public TFHECircuitSolutionParametrizationBase< + TFHECircuitSolutionParametrizationPass> { +public: + TFHECircuitSolutionParametrizationPass( + concrete_optimizer::dag::CircuitSolution solution) + : solution(solution){}; + + void runOnOperation() override { + mlir::Operation *op = getOperation(); + op->walk([&](mlir::func::FuncOp func) { + DEBUG("apply solution: \n" << solution.dump().c_str()); + DEBUG("process func: " << func); + auto context = op->getContext(); + // Process function arguments, change type of arguments according of the + // optimizer identifier stored in the "TFHE.OId" attribute. + for (size_t i = 0; i < func.getNumArguments(); i++) { + auto arg = func.getArgument(i); + auto attr = func.getArgAttrOfType(i, "TFHE.OId"); + if (attr != nullptr) { + DEBUG("process arg = " << arg) + arg.setType(getParametrizedType(arg.getType(), attr)); + } else { + DEBUG("skip arg " << arg) + } + } + // Process operations, apply the instructions keys according of the + // optimizer identifier stored in the "TFHE.OId" + DEBUG("### Apply instruction keys"); + func.walk([&](mlir::Operation *op) { + auto attrOptimizerID = op->getAttrOfType("TFHE.OId"); + // Skip operation is no optimizer identifier + if (attrOptimizerID == nullptr) { + DEBUG("skip operation: " << op->getName()) + return; + } + DEBUG("process operation: " << *op); + auto optimizerID = attrOptimizerID.getInt(); + // Change the output type of the operation + for (auto result : op->getResults()) { + result.setType( + getParametrizedType(result.getType(), attrOptimizerID)); + } + // Set the keyswitch_key attribute + // TODO: Change ambiguous attribute name + auto attrKeyswitchKey = + op->getAttrOfType("key"); + if (attrKeyswitchKey == nullptr) { + DEBUG("no keyswitch key"); + } else { + op->setAttr("key", getKeyswitchKeyAttr(context, optimizerID)); + } + // Set boostrap_key attribute + // TODO: Change ambiguous attribute name + auto attrBootstrapKey = + op->getAttrOfType("key"); + if (attrBootstrapKey == nullptr) { + DEBUG("no bootstrap key"); + } else { + op->setAttr("key", getBootstrapKeyAttr(context, optimizerID)); + } + }); + // The keyswitch operator is an internal node of the optimizer tlu node, + // so it don't follow the same rule than the other operator on the type of + // outputs + DEBUG("### Fixup output of keyswitch") + func.walk([&](TFHE::KeySwitchGLWEOp op) { + DEBUG("process op: " << op) + auto attrKeyswitchKey = + op->getAttrOfType("key"); + assert(attrKeyswitchKey != nullptr); + auto outputKey = attrKeyswitchKey.getOutputKey(); + outputKey = GLWESecretKeyAsLWE(outputKey); + op.getResult().setType( + TFHE::GLWECipherTextType::get(context, outputKey)); + DEBUG("fixed op: " << op) + }); + // Fixup input of the boostrap operator + DEBUG("### Fixup input tlu of bootstrap") + func.walk([&](TFHE::BootstrapGLWEOp op) { + DEBUG("process op: " << op) + auto attrBootstrapKey = + op->getAttrOfType("key"); + assert(attrBootstrapKey != nullptr); + auto polySize = attrBootstrapKey.getPolySize(); + auto lutDefiningOp = op.getLookupTable().getDefiningOp(); + // Dirty fixup of the lookup table as we known the operators that can + // define it + // TODO: Do something more robust, using the GLWE type? + mlir::Builder builder(op->getContext()); + assert(lutDefiningOp != nullptr); + if (auto encodeOp = mlir::dyn_cast( + lutDefiningOp); + encodeOp != nullptr) { + encodeOp.setPolySize(polySize); + } else if (auto constantOp = + mlir::dyn_cast(lutDefiningOp)) { + // Rounded PBS case + auto denseAttr = + constantOp.getValueAttr().dyn_cast(); + auto val = denseAttr.getValues()[0]; + std::vector lut(polySize, val); + constantOp.setValueAttr(mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(lut.size(), + builder.getIntegerType(64)), + lut)); + } + op.getLookupTable().setType(mlir::RankedTensorType::get( + mlir::ArrayRef(polySize), builder.getI64Type())); + // Also fixup the bootstrap key as the TFHENormalization rely on + // GLWESecretKey structure and not on identifier + // TODO: FIXME + auto outputKey = attrBootstrapKey.getOutputKey().getParameterized(); + auto newOutputKey = TFHE::GLWESecretKey::newParameterized( + outputKey->polySize * outputKey->dimension, 1, + outputKey->identifier); + auto newAttrBootstrapKey = TFHE::GLWEBootstrapKeyAttr::get( + context, attrBootstrapKey.getInputKey(), newOutputKey, + attrBootstrapKey.getPolySize(), attrBootstrapKey.getGlweDim(), + attrBootstrapKey.getLevels(), attrBootstrapKey.getBaseLog(), -1); + op.setKeyAttr(newAttrBootstrapKey); + }); + // Fixup incompatible operators with extra conversion keys + DEBUG("### Fixup with extra conversion keys") + func.walk([&](mlir::Operation *op) { + // Skip bootstrap/keyswitch + if (mlir::isa(op) || + mlir::isa(op)) { + return; + } + auto attrOptimizerID = op->getAttrOfType("TFHE.OId"); + // Skip operation with no optimizer identifier + if (attrOptimizerID == nullptr) { + return; + } + DEBUG(" -> process op: " << *op) + // TFHE operators have only one ciphertext result + assert(op->getNumResults() == 1); + auto resType = + op->getResult(0).getType().dyn_cast(); + // For each ciphertext operands apply the extra keyswitch if found + for (const auto &p : llvm::enumerate(op->getOperands())) { + if (resType == nullptr) { + // We don't expect tensor operands to exist at this point of the + // pipeline for now, but if we happen to have some, this assert will + // break, and things will need to be changed to allow tensor ops to + // be parameterized. + // TODO: Actually this case could happens with tensor manipulation + // operators, so for now we just skip it and that should be fixed + // and tested. As the operand will not be fixed the validation of + // operators should not validate the operators. + continue; + } + auto operand = p.value(); + auto operandIdx = p.index(); + DEBUG(" -> processing operand " << operand); + auto operandType = + operand.getType().dyn_cast(); + if (operandType == nullptr) { + DEBUG(" -> skip operand, no glwe"); + continue; + } + if (operandType.getKey() == resType.getKey()) { + DEBUG(" -> skip operand, unnecessary conversion"); + continue; + } + // Lookup for the extra conversion key + auto definingOp = operand.getDefiningOp(); + if (definingOp == nullptr) { + DEBUG(" -> skip operand, no defining operator"); + continue; + } + auto definingOpOptimizerIDAttr = + definingOp->getAttrOfType("TFHE.OId"); + if (definingOpOptimizerIDAttr == nullptr) { + DEBUG(" -> cannot find optimizer id of the defining op"); + continue; + } + DEBUG(" -> get extra conversion key") + auto extraConvKey = getExtraConversionKeyAttr( + context, definingOpOptimizerIDAttr.getInt(), operandType); + if (extraConvKey == nullptr) { + DEBUG(" -> extra conversion key, not found") + assert(false); + } + mlir::IRRewriter rewriter(context); + rewriter.setInsertionPoint(op); + auto newKSK = rewriter.create( + definingOp->getLoc(), resType, operand, extraConvKey); + DEBUG("create extra conversion keyswitch: " << newKSK); + op->setOperand(operandIdx, newKSK); + } + }); + // Propagate types on non parametrized operators + fixupNonParametrizedOps(func); + // Fixup the function signature + fixupFunctionSignature(func); + // Remove optimizer identifiers + removeOptimizerIdentifiers(func); + }); + } + + static mlir::Type getParametrizedType(mlir::Type originalType, + TFHE::GLWECipherTextType newGlwe) { + if (auto oldGlwe = originalType.dyn_cast(); + oldGlwe != nullptr) { + assert(oldGlwe.getKey().isNone()); + return newGlwe; + } else if (auto oldTensor = originalType.dyn_cast(); + oldTensor != nullptr) { + auto oldGlwe = + oldTensor.getElementType().dyn_cast(); + assert(oldGlwe != nullptr); + assert(oldGlwe.getKey().isNone()); + return mlir::RankedTensorType::get(oldTensor.getShape(), newGlwe); + } + assert(false); + } + + mlir::Type getParametrizedType(mlir::Type originalType, + mlir::IntegerAttr optimizerAttrID) { + auto context = originalType.getContext(); + auto newGlwe = + getOutputLWECipherTextType(context, optimizerAttrID.getInt()); + return getParametrizedType(originalType, newGlwe); + } + + static TFHE::GLWECipherTextType getGlweTypeFromType(mlir::Type type) { + if (auto glwe = type.dyn_cast(); + glwe != nullptr) { + return glwe; + } else if (auto tensor = type.dyn_cast(); + tensor != nullptr) { + auto glwe = tensor.getElementType().dyn_cast(); + if (glwe == nullptr) { + return nullptr; + } + return glwe; + } + return nullptr; + } + + // Return the + static TFHE::GLWECipherTextType + getParametrizedGlweTypeFromType(mlir::Type type) { + auto glwe = getGlweTypeFromType(type); + if (glwe != nullptr && glwe.getKey().isParameterized()) { + return glwe; + } + return nullptr; + } + + // Returns true if the type is or contains a glwe type with a none key. + static bool isNoneGlweType(mlir::Type type) { + auto glwe = getGlweTypeFromType(type); + return glwe != nullptr && glwe.getKey().isNone(); + } + + static void + fixupNonParametrizedOp(mlir::Operation *op, + TFHE::GLWECipherTextType parametrizedGlweType) { + DEBUG(" START Fixup {" << *op) + for (auto result : op->getResults()) { + if (isNoneGlweType(result.getType())) { + DEBUG(" -> Fixing result " << result) + result.setType( + getParametrizedType(result.getType(), parametrizedGlweType)); + DEBUG(" -> Fixed result " << result) + // Recurse on all users of the fixed result + for (auto user : result.getUsers()) { + DEBUG(" -> Propagate on user " << *user) + fixupNonParametrizedOp(user, parametrizedGlweType); + } + } + } + // Recursively fixup producer of op operands + mlir::Block *parentBlock = nullptr; + for (auto operand : op->getOperands()) { + if (isNoneGlweType(operand.getType())) { + DEBUG(" -> Propagate on operand " << operand.getType()) + if (auto opResult = operand.dyn_cast(); + opResult != nullptr) { + fixupNonParametrizedOp(opResult.getOwner(), parametrizedGlweType); + continue; + } + if (auto blockArg = operand.dyn_cast(); + blockArg != nullptr) { + DEBUG(" -> Fixing block arg " << blockArg) + blockArg.setType( + getParametrizedType(blockArg.getType(), parametrizedGlweType)); + for (auto users : blockArg.getUsers()) { + fixupNonParametrizedOp(users, parametrizedGlweType); + } + auto blockOwner = blockArg.getOwner(); + if (blockOwner->isEntryBlock()) { + DEBUG(" -> Will propagate on parent op " + << blockOwner->getParentOp()); + assert(parentBlock == blockOwner || parentBlock == nullptr); + parentBlock = blockOwner; + } + continue; + } + // An mlir::Value should always be an OpResult or a BlockArgument + assert(false); + } + } + DEBUG(" } END Fixup") + if (parentBlock != nullptr) { + fixupNonParametrizedOp(parentBlock->getParentOp(), parametrizedGlweType); + } + } + + static void fixupNonParametrizedOps(mlir::func::FuncOp func) { + DEBUG("### Fixup non parametrized ops of function " << func.getSymName()) + // Lookup all operators that uses function arguments + for (const auto arg : func.getArguments()) { + auto parametrizedGlweType = + getParametrizedGlweTypeFromType(arg.getType()); + if (parametrizedGlweType != nullptr) { + DEBUG(" -> Fixup uses of arg " << arg) + // The argument is glwe, so propagate the glwe parametrization to all + // operators which use it + for (auto userOp : arg.getUsers()) { + fixupNonParametrizedOp(userOp, parametrizedGlweType); + } + } + } + // Fixup all operators that take at least a parametrized glwe and produce an + // none glwe + func.walk([&](mlir::Operation *op) { + for (auto operand : op->getOperands()) { + auto parametrizedGlweType = + getParametrizedGlweTypeFromType(operand.getType()); + if (parametrizedGlweType != nullptr) { + // An operand is a parametrized glwe + for (auto result : op->getResults()) { + if (isNoneGlweType(result.getType())) { + DEBUG(" -> Fixup illegal op " << *op) + fixupNonParametrizedOp(op, parametrizedGlweType); + return; + } + } + } + } + }); + } + + static void removeOptimizerIdentifiers(mlir::func::FuncOp func) { + for (size_t i = 0; i < func.getNumArguments(); i++) { + func.removeArgAttr(i, "TFHE.OId"); + } + func.walk([&](mlir::Operation *op) { op->removeAttr("TFHE.OId"); }); + } + + static void fixupFunctionSignature(mlir::func::FuncOp func) { + mlir::SmallVector inputs; + mlir::SmallVector outputs; + // Set inputs by looking actual arguments types + for (auto arg : func.getArguments()) { + inputs.push_back(arg.getType()); + } + // Look for return to set the outputs + func.walk([&](mlir::func::ReturnOp returnOp) { + // TODO: multiple return op + for (auto output : returnOp->getOperandTypes()) { + outputs.push_back(output); + } + }); + auto funcType = + mlir::FunctionType::get(func->getContext(), inputs, outputs); + func.setFunctionType(funcType); + } + + const concrete_optimizer::dag::InstructionKeys & + getInstructionKey(size_t optimizerID) { + DEBUG("lookup instruction key: #" << optimizerID); + return solution.instructions_keys[optimizerID]; + } + + const TFHE::GLWESecretKey GLWESecretKeyAsLWE(TFHE::GLWESecretKey key) { + auto keyP = key.getParameterized(); + assert(keyP.has_value()); + return TFHE::GLWESecretKey::newParameterized( + keyP->polySize * keyP->dimension, 1, keyP->identifier); + } + + const TFHE::GLWESecretKey + toGLWESecretKey(concrete_optimizer::dag::SecretLweKey key) { + return TFHE::GLWESecretKey::newParameterized( + key.glwe_dimension, key.polynomial_size, key.identifier); + } + + const TFHE::GLWESecretKey + toLWESecretKey(concrete_optimizer::dag::SecretLweKey key) { + return TFHE::GLWESecretKey::newParameterized( + key.glwe_dimension * key.polynomial_size, 1, key.identifier); + } + + const TFHE::GLWESecretKey getLWESecretKey(size_t keyID) { + DEBUG("lookup secret key: #" << keyID); + auto key = solution.circuit_keys.secret_keys[keyID]; + assert(keyID == key.identifier); + return toLWESecretKey(key); + } + + const TFHE::GLWESecretKey getInputLWESecretKey(size_t optimizerID) { + auto keyID = getInstructionKey(optimizerID).input_key; + return getLWESecretKey(keyID); + } + + const TFHE::GLWESecretKey getOutputLWESecretKey(size_t optimizerID) { + auto keyID = getInstructionKey(optimizerID).output_key; + return getLWESecretKey(keyID); + } + + const TFHE::GLWEKeyswitchKeyAttr + getKeyswitchKeyAttr(mlir::MLIRContext *context, size_t optimizerID) { + auto keyID = getInstructionKey(optimizerID).tlu_keyswitch_key; + DEBUG("lookup keyswicth key: #" << keyID); + auto key = solution.circuit_keys.keyswitch_keys[keyID]; + return TFHE::GLWEKeyswitchKeyAttr::get( + context, toLWESecretKey(key.input_key), toLWESecretKey(key.output_key), + key.ks_decomposition_parameter.level, + key.ks_decomposition_parameter.log2_base, -1); + } + + const TFHE::GLWEKeyswitchKeyAttr + getExtraConversionKeyAttr(mlir::MLIRContext *context, size_t optimizerID, + TFHE::GLWECipherTextType operandType) { + DEBUG("get extra conversion key for " << operandType); + auto instructionKey = getInstructionKey(optimizerID); + for (const auto &convKSKID : instructionKey.extra_conversion_keys) { + DEBUG("try extra conversion keyswitch #" << convKSKID) + auto convKSK = solution.circuit_keys.conversion_keyswitch_keys[convKSKID]; + auto key = operandType.getKey(); + assert(key.isParameterized()); + if (operandType.getKey().getParameterized().value().identifier == + convKSK.input_key.identifier) { + return TFHE::GLWEKeyswitchKeyAttr::get( + context, toLWESecretKey(convKSK.input_key), + toLWESecretKey(convKSK.output_key), + convKSK.ks_decomposition_parameter.level, + convKSK.ks_decomposition_parameter.log2_base, -1); + } + } + DEBUG("!!! extra conversion key not found"); + return nullptr; + } + + const TFHE::GLWEBootstrapKeyAttr + getBootstrapKeyAttr(mlir::MLIRContext *context, size_t optimizerID) { + auto keyID = getInstructionKey(optimizerID).tlu_bootstrap_key; + DEBUG("lookup bootstrap key: #" << keyID); + auto key = solution.circuit_keys.bootstrap_keys[keyID]; + return TFHE::GLWEBootstrapKeyAttr::get( + context, toLWESecretKey(key.input_key), toGLWESecretKey(key.output_key), + key.output_key.polynomial_size, key.output_key.glwe_dimension, + key.br_decomposition_parameter.level, + key.br_decomposition_parameter.log2_base, -1); + } + + const TFHE::GLWECipherTextType + getOutputLWECipherTextType(mlir::MLIRContext *context, size_t optimizerID) { + auto outputKey = getOutputLWESecretKey(optimizerID); + return TFHE::GLWECipherTextType::get(context, outputKey); + } + +private: + concrete_optimizer::dag::CircuitSolution solution; +}; + +} // end anonymous namespace + +std::unique_ptr> +createTFHECircuitSolutionParametrizationPass( + concrete_optimizer::dag::CircuitSolution solution) { + return std::make_unique(solution); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index bafb9ea82..aaa24a06c 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -41,7 +41,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/CMakeLists.txt b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/CMakeLists.txt index 56f2880f9..0487eb329 100644 --- a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(ClientLib) add_subdirectory(SDFG) add_subdirectory(TestLib) add_subdirectory(Encodings) +add_subdirectory(Dialect) diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/CMakeLists.txt b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/CMakeLists.txt new file mode 100644 index 000000000..43ba85bbc --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TFHE) diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/CMakeLists.txt b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/CMakeLists.txt new file mode 100644 index 000000000..e31af3266 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt new file mode 100644 index 000000000..f9865b2e3 --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(ConcretelangTFHETransformsTests) + +add_dependencies(ConcretelangUnitTests ConcretelangTFHETransformsTests) + +add_unittest(ConcretelangTFHETransformsTests unit_tests_concretelang_tfhe_transforms Parametrization.cpp) + +target_link_libraries( + unit_tests_concretelang_tfhe_transforms PRIVATE TFHEDialectTransforms MLIRParser MLIRExecutionEngine + TFHEGlobalParametrization ConcretelangSupport) diff --git a/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/Parametrization.cpp b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/Parametrization.cpp new file mode 100644 index 000000000..56c3a9f4d --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/unit_tests/concretelang/Dialect/TFHE/Transforms/Parametrization.cpp @@ -0,0 +1,246 @@ +#include + +#include "mlir/Parser/Parser.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" + +#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" +#include "concretelang/Dialect/TFHE/Transforms/Transforms.h" + +std::string transform(std::string source, + concrete_optimizer::dag::CircuitSolution solution) { + // Register dialect + mlir::DialectRegistry registry; + registry + .insert(); + mlir::MLIRContext mlirContext; + mlirContext.appendDialectRegistry(registry); + + // Parse from string + auto memoryBuffer = llvm::MemoryBuffer::getMemBuffer(source); + llvm::SourceMgr sm; + sm.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); + mlir::OwningOpRef mlirModuleRef = + mlir::parseSourceFile(sm, &mlirContext); + + // Apply the parametrization pass + mlir::PassManager pm(&mlirContext); + + pm.addPass(mlir::concretelang::createTFHECircuitSolutionParametrizationPass( + solution)); + + assert(pm.run(mlirModuleRef->getOperation()).succeeded() && + "pass manager fail"); + // mlirModuleRef->dump(); + std::string moduleString; + llvm::raw_string_ostream s(moduleString); + + mlirModuleRef->getOperation()->print(s); + mlirModuleRef->dump(); + return s.str(); +} + +// Returns the secret key id +int addSecretKey(concrete_optimizer::dag::CircuitSolution &solution, + int glwe_dimension, int polynomial_size) { + concrete_optimizer::dag::SecretLweKey secretLweKey; + secretLweKey.description = "single_key"; + secretLweKey.identifier = 0; + secretLweKey.glwe_dimension = glwe_dimension; + secretLweKey.polynomial_size = polynomial_size; + secretLweKey.identifier = solution.circuit_keys.secret_keys.size(); + solution.circuit_keys.secret_keys.push_back(secretLweKey); + return secretLweKey.identifier; +} + +// Returns the keyswitch key id +int addKeyswitchKey(concrete_optimizer::dag::CircuitSolution &solution, + int input_sk, int output_sk, int level, int base_log) { + concrete_optimizer::dag::KeySwitchKey keySwitchKey; + keySwitchKey.input_key = solution.circuit_keys.secret_keys[input_sk]; + keySwitchKey.output_key = solution.circuit_keys.secret_keys[output_sk]; + + keySwitchKey.ks_decomposition_parameter.level = level; + keySwitchKey.ks_decomposition_parameter.log2_base = base_log; + keySwitchKey.identifier = solution.circuit_keys.keyswitch_keys.size(); + solution.circuit_keys.keyswitch_keys.push_back(keySwitchKey); + return keySwitchKey.identifier; +} + +int addExtraKeyswitchKey(concrete_optimizer::dag::CircuitSolution &solution, + int input_sk, int output_sk, int level, int base_log) { + concrete_optimizer::dag::ConversionKeySwitchKey keySwitchKey; + keySwitchKey.input_key = solution.circuit_keys.secret_keys[input_sk]; + keySwitchKey.output_key = solution.circuit_keys.secret_keys[output_sk]; + + keySwitchKey.ks_decomposition_parameter.level = level; + keySwitchKey.ks_decomposition_parameter.log2_base = base_log; + keySwitchKey.identifier = solution.circuit_keys.keyswitch_keys.size(); + solution.circuit_keys.conversion_keyswitch_keys.push_back(keySwitchKey); + return keySwitchKey.identifier; +} + +// Returns the bootstrap key id +int addBootstrapKey(concrete_optimizer::dag::CircuitSolution &solution, + int input_sk, int output_sk, int level, int base_log) { + concrete_optimizer::dag::BootstrapKey bootstrapKey; + // TODO: Interface design identifier or key + bootstrapKey.input_key = solution.circuit_keys.secret_keys[input_sk]; + bootstrapKey.output_key = solution.circuit_keys.secret_keys[output_sk]; + bootstrapKey.br_decomposition_parameter.level = level; + bootstrapKey.br_decomposition_parameter.log2_base = base_log; + bootstrapKey.identifier = solution.circuit_keys.bootstrap_keys.size(); + solution.circuit_keys.bootstrap_keys.push_back(bootstrapKey); + return bootstrapKey.identifier; +} + +void addInstructionKey(concrete_optimizer::dag::CircuitSolution &solution, + int input_key, int output_key, int ksk = -1, + int bsk = -1, + std::vector extra_conversion_keys = {}) { + concrete_optimizer::dag::InstructionKeys instrKey; + instrKey.input_key = input_key; + instrKey.output_key = output_key; + instrKey.tlu_bootstrap_key = bsk; + instrKey.tlu_keyswitch_key = ksk; + for (const auto &item : extra_conversion_keys) { + instrKey.extra_conversion_keys.push_back(item); + } + solution.instructions_keys.push_back(instrKey); +} + +TEST(TFHECircuitParametrization, single_sk) { + std::string source = R"( + func.func @main(%arg0: !TFHE.glwe {TFHE.OId = 0 : i32}, %arg1: !TFHE.glwe {TFHE.OId = 1 : i32}, %arg2: i64) -> !TFHE.glwe { + %0 = "TFHE.add_glwe_int"(%arg0, %arg2) {TFHE.OId = 2 : i32} : (!TFHE.glwe, i64) -> !TFHE.glwe + %1 = "TFHE.add_glwe"(%0, %arg1) {TFHE.OId = 3 : i32} : (!TFHE.glwe, !TFHE.glwe) -> !TFHE.glwe + return %1 : !TFHE.glwe + } +)"; + std::string expected = R"(module { + func.func @main(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>, %arg2: i64) -> !TFHE.glwe> { + %0 = "TFHE.add_glwe_int"(%arg0, %arg2) : (!TFHE.glwe>, i64) -> !TFHE.glwe> + %1 = "TFHE.add_glwe"(%0, %arg1) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> + return %1 : !TFHE.glwe> + } +} +)"; + // TODO: concrete_optimizer::dag::CircuitSolution + concrete_optimizer::dag::CircuitSolution solution; + auto keyId = addSecretKey(solution, 1, 1024); + concrete_optimizer::dag::InstructionKeys instr0; + // %arg0 + addInstructionKey(solution, keyId, keyId); + // %arg1 + addInstructionKey(solution, keyId, keyId); + // %0 + addInstructionKey(solution, keyId, keyId); + // %1 + addInstructionKey(solution, keyId, keyId); + std::string output = transform(source, solution); + ASSERT_EQ(output, expected); +} + +TEST(TFHECircuitParametrization, keyswitch) { + std::string source = R"( + func.func @main(%arg0: !TFHE.glwe {TFHE.OId = 0 : i32}) -> !TFHE.glwe { + %0 = "TFHE.keyswitch_glwe"(%arg0) {TFHE.OId = 1 : i32, key = #TFHE.ksk} : (!TFHE.glwe) -> !TFHE.glwe + return %0 : !TFHE.glwe + } +)"; + std::string expected = R"(module { + func.func @main(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { + %0 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk, sk<1,1,1701>, 2, 12>} : (!TFHE.glwe>) -> !TFHE.glwe> + return %0 : !TFHE.glwe> + } +} +)"; + concrete_optimizer::dag::CircuitSolution solution; + // Add a first secret key + auto sk0 = addSecretKey(solution, 1, 1024); + auto sk1 = addSecretKey(solution, 3, 567); + auto ksk = addKeyswitchKey(solution, sk0, sk1, 2, 12); + // %arg0 + addInstructionKey(solution, sk0, sk0); + // %0 + addInstructionKey(solution, sk0, sk1, ksk); + std::string output = transform(source, solution); + ASSERT_EQ(output, expected); +} + +TEST(TFHECircuitParametrization, boostrap) { + std::string source = R"( + func.func @main(%arg0: !TFHE.glwe {TFHE.OId = 0 : i32}) -> !TFHE.glwe { + %lut = arith.constant dense<-1152921504606846976> : tensor<42xi64> + %0 = "TFHE.bootstrap_glwe"(%arg0, %lut) {TFHE.OId = 1 : i32, key = #TFHE.bsk} : (!TFHE.glwe, tensor<42xi64>) -> !TFHE.glwe + return %0 : !TFHE.glwe + } +)"; + std::string expected = R"(module { + func.func @main(%arg0: !TFHE.glwe>) -> !TFHE.glwe> { + %cst = arith.constant dense<-1152921504606846976> : tensor<1024xi64> + %0 = "TFHE.bootstrap_glwe"(%arg0, %cst) {key = #TFHE.bsk, sk<1,1,1024>, 1024, 1, 2, 12>} : (!TFHE.glwe>, tensor<1024xi64>) -> !TFHE.glwe> + return %0 : !TFHE.glwe> + } +} +)"; + concrete_optimizer::dag::CircuitSolution solution; + // Add a first secret key + auto sk0 = addSecretKey(solution, 3, 567); + auto sk1 = addSecretKey(solution, 1, 1024); + auto bsk = addBootstrapKey(solution, sk0, sk1, 2, 12); + // %arg0 + addInstructionKey(solution, sk0, sk0); + // %0 + addInstructionKey(solution, sk0, sk1, -1, bsk); + std::string output = transform(source, solution); + ASSERT_EQ(output, expected); +} + +// Test the extra conversion keys used to switch between two partitions without +// boostrap +// TODO: Will be a fastKS +TEST(TFHECircuitParametrization, extra_conversion_key) { + std::string source = R"( + func.func @main(%arg0: !TFHE.glwe {TFHE.OId = 0 : i32}, %arg1: !TFHE.glwe {TFHE.OId = 1 : i32}, %arg2: i64) -> !TFHE.glwe { + // Partition 0 + %0 = "TFHE.add_glwe_int"(%arg0, %arg2) {TFHE.OId = 2 : i32} : (!TFHE.glwe, i64) -> !TFHE.glwe + // Partition 1 + %1 = "TFHE.add_glwe"(%0, %arg1) {TFHE.OId = 3 : i32} : (!TFHE.glwe, !TFHE.glwe) -> !TFHE.glwe + return %1 : !TFHE.glwe + } +)"; + std::string expected = R"(module { + func.func @main(%arg0: !TFHE.glwe>, %arg1: !TFHE.glwe>, %arg2: i64) -> !TFHE.glwe> { + %0 = "TFHE.add_glwe_int"(%arg0, %arg2) : (!TFHE.glwe>, i64) -> !TFHE.glwe> + %1 = "TFHE.keyswitch_glwe"(%0) {key = #TFHE.ksk, sk<1,1,1024>, 2, 12>} : (!TFHE.glwe>) -> !TFHE.glwe> + %2 = "TFHE.add_glwe"(%1, %arg1) : (!TFHE.glwe>, !TFHE.glwe>) -> !TFHE.glwe> + return %2 : !TFHE.glwe> + } +} +)"; + concrete_optimizer::dag::CircuitSolution solution; + // Add secret key for partition 0 + auto sk0 = addSecretKey(solution, 3, 2048); + // Add secret key for partition 1 + auto sk1 = addSecretKey(solution, 1, 1024); + // Extra conversion key + auto ksk = addExtraKeyswitchKey(solution, sk0, sk1, 2, 12); + std::vector extra_conversion_keys{(uint64_t)ksk}; + // Add instruction keys + // #0: %arg0 - partition 0 + addInstructionKey(solution, sk0, sk0); + // #1: %arg1 - partition 1 + addInstructionKey(solution, sk1, sk1); + // #2: %0 - partition 0 with conversion to partition 1 + addInstructionKey(solution, sk0, sk0, -1, -1, extra_conversion_keys); + // #3: %1 - partition 1 + addInstructionKey(solution, sk1, sk1); + std::string output = transform(source, solution); + ASSERT_EQ(output, expected); +}