feat(compiler/multi-parameters): Create a pass to apply the multi-parameter circuit solution of the optimize

This commit is contained in:
Quentin Bourgerie
2023-03-28 10:38:31 +02:00
parent f0ca5aa427
commit 7d1c43bc47
14 changed files with 802 additions and 24 deletions

View File

@@ -14,8 +14,7 @@ namespace mlir {
namespace concretelang {
/// Create a pass to inject fhe parameters to the TFHE types and operators.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTFHEGlobalParametrizationPass(
mlir::concretelang::V0FHEContext &fheContext);
createConvertTFHEGlobalParametrizationPass(const V0Parameter parameter);
} // namespace concretelang
} // namespace mlir

View File

@@ -1,4 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Optimization.td)
mlir_tablegen(Optimization.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangTFHEOptimizationPassIncGen)
add_dependencies(mlir-headers ConcretelangTFHEOptimizationPassIncGen)
set(LLVM_TARGET_DEFINITIONS Transforms.td)
mlir_tablegen(Transforms.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangTFHETransformsPassIncGen)
add_dependencies(mlir-headers ConcretelangTFHETransformsPassIncGen)

View File

@@ -1,13 +0,0 @@
#ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS
#define CONCRETELANG_TFHE_OPTIMIZATION_PASS
include "mlir/Pass/PassBase.td"
def TFHEOptimization : Pass<"tfhe-optimization"> {
let summary = "Optimize TFHE operations";
let constructor = "mlir::concretelang::createTFHEOptimizationPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ];
}
#endif

View File

@@ -6,15 +6,19 @@
#ifndef CONCRETELANG_TFHE_OPTIMIZATION_PASS_H
#define CONCRETELANG_TFHE_OPTIMIZATION_PASS_H
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <mlir/Pass/Pass.h>
#include "concrete-optimizer.hpp"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "mlir/Pass/Pass.h"
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/TFHE/Transforms/Optimization.h.inc>
#include "concretelang/Dialect/TFHE/Transforms/Transforms.h.inc"
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>> createTFHEOptimizationPass();
std::unique_ptr<mlir::OperationPass<>>
createTFHECircuitSolutionParametrizationPass(
concrete_optimizer::dag::CircuitSolution);
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,20 @@
#ifndef CONCRETELANG_TFHE_TRANSFORMS_PASS
#define CONCRETELANG_TFHE_TRANSFORMS_PASS
include "mlir/Pass/PassBase.td"
def TFHEOptimization : Pass<"tfhe-optimization"> {
let summary = "Optimize TFHE operations";
let constructor = "mlir::concretelang::createTFHEOptimizationPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ];
}
def TFHECircuitSolutionParametrization : Pass<"tfhe-circuit-solution-parametrization"> {
let summary = "Parametrize TFHE with a circuit solution given by the optimizer";
let constructor = "mlir::concretelang::createTFHECircuitSolutionParametrizationPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::TFHE::TFHEDialect" ];
}
#endif

View File

@@ -1,6 +1,7 @@
add_mlir_library(
TFHEDialectTransforms
Optimization.cpp
TFHECircuitSolutionParametrization.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/TFHE
DEPENDS

View File

@@ -8,7 +8,7 @@
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <concretelang/Dialect/TFHE/IR/TFHEOps.h>
#include <concretelang/Dialect/TFHE/Transforms/Optimization.h>
#include <concretelang/Dialect/TFHE/Transforms/Transforms.h>
#include <concretelang/Support/Constants.h>
namespace mlir {

View File

@@ -0,0 +1,509 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/Transforms/Transforms.h"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/logging.h"
namespace mlir {
namespace concretelang {
namespace {
#define DEBUG(MSG) \
if (mlir::concretelang::isVerbose()) { \
llvm::errs() << MSG << "\n"; \
}
namespace TFHE = mlir::concretelang::TFHE;
/// Optimization pass that should choose more efficient ways of performing
/// crypto operations.
class TFHECircuitSolutionParametrizationPass
: public TFHECircuitSolutionParametrizationBase<
TFHECircuitSolutionParametrizationPass> {
public:
TFHECircuitSolutionParametrizationPass(
concrete_optimizer::dag::CircuitSolution solution)
: solution(solution){};
void runOnOperation() override {
mlir::Operation *op = getOperation();
op->walk([&](mlir::func::FuncOp func) {
DEBUG("apply solution: \n" << solution.dump().c_str());
DEBUG("process func: " << func);
auto context = op->getContext();
// Process function arguments, change type of arguments according of the
// optimizer identifier stored in the "TFHE.OId" attribute.
for (size_t i = 0; i < func.getNumArguments(); i++) {
auto arg = func.getArgument(i);
auto attr = func.getArgAttrOfType<mlir::IntegerAttr>(i, "TFHE.OId");
if (attr != nullptr) {
DEBUG("process arg = " << arg)
arg.setType(getParametrizedType(arg.getType(), attr));
} else {
DEBUG("skip arg " << arg)
}
}
// Process operations, apply the instructions keys according of the
// optimizer identifier stored in the "TFHE.OId"
DEBUG("### Apply instruction keys");
func.walk([&](mlir::Operation *op) {
auto attrOptimizerID = op->getAttrOfType<IntegerAttr>("TFHE.OId");
// Skip operation is no optimizer identifier
if (attrOptimizerID == nullptr) {
DEBUG("skip operation: " << op->getName())
return;
}
DEBUG("process operation: " << *op);
auto optimizerID = attrOptimizerID.getInt();
// Change the output type of the operation
for (auto result : op->getResults()) {
result.setType(
getParametrizedType(result.getType(), attrOptimizerID));
}
// Set the keyswitch_key attribute
// TODO: Change ambiguous attribute name
auto attrKeyswitchKey =
op->getAttrOfType<TFHE::GLWEKeyswitchKeyAttr>("key");
if (attrKeyswitchKey == nullptr) {
DEBUG("no keyswitch key");
} else {
op->setAttr("key", getKeyswitchKeyAttr(context, optimizerID));
}
// Set boostrap_key attribute
// TODO: Change ambiguous attribute name
auto attrBootstrapKey =
op->getAttrOfType<TFHE::GLWEBootstrapKeyAttr>("key");
if (attrBootstrapKey == nullptr) {
DEBUG("no bootstrap key");
} else {
op->setAttr("key", getBootstrapKeyAttr(context, optimizerID));
}
});
// The keyswitch operator is an internal node of the optimizer tlu node,
// so it don't follow the same rule than the other operator on the type of
// outputs
DEBUG("### Fixup output of keyswitch")
func.walk([&](TFHE::KeySwitchGLWEOp op) {
DEBUG("process op: " << op)
auto attrKeyswitchKey =
op->getAttrOfType<TFHE::GLWEKeyswitchKeyAttr>("key");
assert(attrKeyswitchKey != nullptr);
auto outputKey = attrKeyswitchKey.getOutputKey();
outputKey = GLWESecretKeyAsLWE(outputKey);
op.getResult().setType(
TFHE::GLWECipherTextType::get(context, outputKey));
DEBUG("fixed op: " << op)
});
// Fixup input of the boostrap operator
DEBUG("### Fixup input tlu of bootstrap")
func.walk([&](TFHE::BootstrapGLWEOp op) {
DEBUG("process op: " << op)
auto attrBootstrapKey =
op->getAttrOfType<TFHE::GLWEBootstrapKeyAttr>("key");
assert(attrBootstrapKey != nullptr);
auto polySize = attrBootstrapKey.getPolySize();
auto lutDefiningOp = op.getLookupTable().getDefiningOp();
// Dirty fixup of the lookup table as we known the operators that can
// define it
// TODO: Do something more robust, using the GLWE type?
mlir::Builder builder(op->getContext());
assert(lutDefiningOp != nullptr);
if (auto encodeOp = mlir::dyn_cast<TFHE::EncodeExpandLutForBootstrapOp>(
lutDefiningOp);
encodeOp != nullptr) {
encodeOp.setPolySize(polySize);
} else if (auto constantOp =
mlir::dyn_cast<arith::ConstantOp>(lutDefiningOp)) {
// Rounded PBS case
auto denseAttr =
constantOp.getValueAttr().dyn_cast<mlir::DenseIntElementsAttr>();
auto val = denseAttr.getValues<int64_t>()[0];
std::vector<int64_t> lut(polySize, val);
constantOp.setValueAttr(mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(lut.size(),
builder.getIntegerType(64)),
lut));
}
op.getLookupTable().setType(mlir::RankedTensorType::get(
mlir::ArrayRef<int64_t>(polySize), builder.getI64Type()));
// Also fixup the bootstrap key as the TFHENormalization rely on
// GLWESecretKey structure and not on identifier
// TODO: FIXME
auto outputKey = attrBootstrapKey.getOutputKey().getParameterized();
auto newOutputKey = TFHE::GLWESecretKey::newParameterized(
outputKey->polySize * outputKey->dimension, 1,
outputKey->identifier);
auto newAttrBootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
context, attrBootstrapKey.getInputKey(), newOutputKey,
attrBootstrapKey.getPolySize(), attrBootstrapKey.getGlweDim(),
attrBootstrapKey.getLevels(), attrBootstrapKey.getBaseLog(), -1);
op.setKeyAttr(newAttrBootstrapKey);
});
// Fixup incompatible operators with extra conversion keys
DEBUG("### Fixup with extra conversion keys")
func.walk([&](mlir::Operation *op) {
// Skip bootstrap/keyswitch
if (mlir::isa<TFHE::BootstrapGLWEOp>(op) ||
mlir::isa<TFHE::KeySwitchGLWEOp>(op)) {
return;
}
auto attrOptimizerID = op->getAttrOfType<IntegerAttr>("TFHE.OId");
// Skip operation with no optimizer identifier
if (attrOptimizerID == nullptr) {
return;
}
DEBUG(" -> process op: " << *op)
// TFHE operators have only one ciphertext result
assert(op->getNumResults() == 1);
auto resType =
op->getResult(0).getType().dyn_cast<TFHE::GLWECipherTextType>();
// For each ciphertext operands apply the extra keyswitch if found
for (const auto &p : llvm::enumerate(op->getOperands())) {
if (resType == nullptr) {
// We don't expect tensor operands to exist at this point of the
// pipeline for now, but if we happen to have some, this assert will
// break, and things will need to be changed to allow tensor ops to
// be parameterized.
// TODO: Actually this case could happens with tensor manipulation
// operators, so for now we just skip it and that should be fixed
// and tested. As the operand will not be fixed the validation of
// operators should not validate the operators.
continue;
}
auto operand = p.value();
auto operandIdx = p.index();
DEBUG(" -> processing operand " << operand);
auto operandType =
operand.getType().dyn_cast<TFHE::GLWECipherTextType>();
if (operandType == nullptr) {
DEBUG(" -> skip operand, no glwe");
continue;
}
if (operandType.getKey() == resType.getKey()) {
DEBUG(" -> skip operand, unnecessary conversion");
continue;
}
// Lookup for the extra conversion key
auto definingOp = operand.getDefiningOp();
if (definingOp == nullptr) {
DEBUG(" -> skip operand, no defining operator");
continue;
}
auto definingOpOptimizerIDAttr =
definingOp->getAttrOfType<mlir::IntegerAttr>("TFHE.OId");
if (definingOpOptimizerIDAttr == nullptr) {
DEBUG(" -> cannot find optimizer id of the defining op");
continue;
}
DEBUG(" -> get extra conversion key")
auto extraConvKey = getExtraConversionKeyAttr(
context, definingOpOptimizerIDAttr.getInt(), operandType);
if (extraConvKey == nullptr) {
DEBUG(" -> extra conversion key, not found")
assert(false);
}
mlir::IRRewriter rewriter(context);
rewriter.setInsertionPoint(op);
auto newKSK = rewriter.create<TFHE::KeySwitchGLWEOp>(
definingOp->getLoc(), resType, operand, extraConvKey);
DEBUG("create extra conversion keyswitch: " << newKSK);
op->setOperand(operandIdx, newKSK);
}
});
// Propagate types on non parametrized operators
fixupNonParametrizedOps(func);
// Fixup the function signature
fixupFunctionSignature(func);
// Remove optimizer identifiers
removeOptimizerIdentifiers(func);
});
}
static mlir::Type getParametrizedType(mlir::Type originalType,
TFHE::GLWECipherTextType newGlwe) {
if (auto oldGlwe = originalType.dyn_cast<TFHE::GLWECipherTextType>();
oldGlwe != nullptr) {
assert(oldGlwe.getKey().isNone());
return newGlwe;
} else if (auto oldTensor = originalType.dyn_cast<mlir::RankedTensorType>();
oldTensor != nullptr) {
auto oldGlwe =
oldTensor.getElementType().dyn_cast<TFHE::GLWECipherTextType>();
assert(oldGlwe != nullptr);
assert(oldGlwe.getKey().isNone());
return mlir::RankedTensorType::get(oldTensor.getShape(), newGlwe);
}
assert(false);
}
mlir::Type getParametrizedType(mlir::Type originalType,
mlir::IntegerAttr optimizerAttrID) {
auto context = originalType.getContext();
auto newGlwe =
getOutputLWECipherTextType(context, optimizerAttrID.getInt());
return getParametrizedType(originalType, newGlwe);
}
static TFHE::GLWECipherTextType getGlweTypeFromType(mlir::Type type) {
if (auto glwe = type.dyn_cast<TFHE::GLWECipherTextType>();
glwe != nullptr) {
return glwe;
} else if (auto tensor = type.dyn_cast<mlir::RankedTensorType>();
tensor != nullptr) {
auto glwe = tensor.getElementType().dyn_cast<TFHE::GLWECipherTextType>();
if (glwe == nullptr) {
return nullptr;
}
return glwe;
}
return nullptr;
}
// Return the
static TFHE::GLWECipherTextType
getParametrizedGlweTypeFromType(mlir::Type type) {
auto glwe = getGlweTypeFromType(type);
if (glwe != nullptr && glwe.getKey().isParameterized()) {
return glwe;
}
return nullptr;
}
// Returns true if the type is or contains a glwe type with a none key.
static bool isNoneGlweType(mlir::Type type) {
auto glwe = getGlweTypeFromType(type);
return glwe != nullptr && glwe.getKey().isNone();
}
static void
fixupNonParametrizedOp(mlir::Operation *op,
TFHE::GLWECipherTextType parametrizedGlweType) {
DEBUG(" START Fixup {" << *op)
for (auto result : op->getResults()) {
if (isNoneGlweType(result.getType())) {
DEBUG(" -> Fixing result " << result)
result.setType(
getParametrizedType(result.getType(), parametrizedGlweType));
DEBUG(" -> Fixed result " << result)
// Recurse on all users of the fixed result
for (auto user : result.getUsers()) {
DEBUG(" -> Propagate on user " << *user)
fixupNonParametrizedOp(user, parametrizedGlweType);
}
}
}
// Recursively fixup producer of op operands
mlir::Block *parentBlock = nullptr;
for (auto operand : op->getOperands()) {
if (isNoneGlweType(operand.getType())) {
DEBUG(" -> Propagate on operand " << operand.getType())
if (auto opResult = operand.dyn_cast<mlir::OpResult>();
opResult != nullptr) {
fixupNonParametrizedOp(opResult.getOwner(), parametrizedGlweType);
continue;
}
if (auto blockArg = operand.dyn_cast<mlir::BlockArgument>();
blockArg != nullptr) {
DEBUG(" -> Fixing block arg " << blockArg)
blockArg.setType(
getParametrizedType(blockArg.getType(), parametrizedGlweType));
for (auto users : blockArg.getUsers()) {
fixupNonParametrizedOp(users, parametrizedGlweType);
}
auto blockOwner = blockArg.getOwner();
if (blockOwner->isEntryBlock()) {
DEBUG(" -> Will propagate on parent op "
<< blockOwner->getParentOp());
assert(parentBlock == blockOwner || parentBlock == nullptr);
parentBlock = blockOwner;
}
continue;
}
// An mlir::Value should always be an OpResult or a BlockArgument
assert(false);
}
}
DEBUG(" } END Fixup")
if (parentBlock != nullptr) {
fixupNonParametrizedOp(parentBlock->getParentOp(), parametrizedGlweType);
}
}
static void fixupNonParametrizedOps(mlir::func::FuncOp func) {
DEBUG("### Fixup non parametrized ops of function " << func.getSymName())
// Lookup all operators that uses function arguments
for (const auto arg : func.getArguments()) {
auto parametrizedGlweType =
getParametrizedGlweTypeFromType(arg.getType());
if (parametrizedGlweType != nullptr) {
DEBUG(" -> Fixup uses of arg " << arg)
// The argument is glwe, so propagate the glwe parametrization to all
// operators which use it
for (auto userOp : arg.getUsers()) {
fixupNonParametrizedOp(userOp, parametrizedGlweType);
}
}
}
// Fixup all operators that take at least a parametrized glwe and produce an
// none glwe
func.walk([&](mlir::Operation *op) {
for (auto operand : op->getOperands()) {
auto parametrizedGlweType =
getParametrizedGlweTypeFromType(operand.getType());
if (parametrizedGlweType != nullptr) {
// An operand is a parametrized glwe
for (auto result : op->getResults()) {
if (isNoneGlweType(result.getType())) {
DEBUG(" -> Fixup illegal op " << *op)
fixupNonParametrizedOp(op, parametrizedGlweType);
return;
}
}
}
}
});
}
static void removeOptimizerIdentifiers(mlir::func::FuncOp func) {
for (size_t i = 0; i < func.getNumArguments(); i++) {
func.removeArgAttr(i, "TFHE.OId");
}
func.walk([&](mlir::Operation *op) { op->removeAttr("TFHE.OId"); });
}
static void fixupFunctionSignature(mlir::func::FuncOp func) {
mlir::SmallVector<mlir::Type> inputs;
mlir::SmallVector<mlir::Type> outputs;
// Set inputs by looking actual arguments types
for (auto arg : func.getArguments()) {
inputs.push_back(arg.getType());
}
// Look for return to set the outputs
func.walk([&](mlir::func::ReturnOp returnOp) {
// TODO: multiple return op
for (auto output : returnOp->getOperandTypes()) {
outputs.push_back(output);
}
});
auto funcType =
mlir::FunctionType::get(func->getContext(), inputs, outputs);
func.setFunctionType(funcType);
}
const concrete_optimizer::dag::InstructionKeys &
getInstructionKey(size_t optimizerID) {
DEBUG("lookup instruction key: #" << optimizerID);
return solution.instructions_keys[optimizerID];
}
const TFHE::GLWESecretKey GLWESecretKeyAsLWE(TFHE::GLWESecretKey key) {
auto keyP = key.getParameterized();
assert(keyP.has_value());
return TFHE::GLWESecretKey::newParameterized(
keyP->polySize * keyP->dimension, 1, keyP->identifier);
}
const TFHE::GLWESecretKey
toGLWESecretKey(concrete_optimizer::dag::SecretLweKey key) {
return TFHE::GLWESecretKey::newParameterized(
key.glwe_dimension, key.polynomial_size, key.identifier);
}
const TFHE::GLWESecretKey
toLWESecretKey(concrete_optimizer::dag::SecretLweKey key) {
return TFHE::GLWESecretKey::newParameterized(
key.glwe_dimension * key.polynomial_size, 1, key.identifier);
}
const TFHE::GLWESecretKey getLWESecretKey(size_t keyID) {
DEBUG("lookup secret key: #" << keyID);
auto key = solution.circuit_keys.secret_keys[keyID];
assert(keyID == key.identifier);
return toLWESecretKey(key);
}
const TFHE::GLWESecretKey getInputLWESecretKey(size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).input_key;
return getLWESecretKey(keyID);
}
const TFHE::GLWESecretKey getOutputLWESecretKey(size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).output_key;
return getLWESecretKey(keyID);
}
const TFHE::GLWEKeyswitchKeyAttr
getKeyswitchKeyAttr(mlir::MLIRContext *context, size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).tlu_keyswitch_key;
DEBUG("lookup keyswicth key: #" << keyID);
auto key = solution.circuit_keys.keyswitch_keys[keyID];
return TFHE::GLWEKeyswitchKeyAttr::get(
context, toLWESecretKey(key.input_key), toLWESecretKey(key.output_key),
key.ks_decomposition_parameter.level,
key.ks_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWEKeyswitchKeyAttr
getExtraConversionKeyAttr(mlir::MLIRContext *context, size_t optimizerID,
TFHE::GLWECipherTextType operandType) {
DEBUG("get extra conversion key for " << operandType);
auto instructionKey = getInstructionKey(optimizerID);
for (const auto &convKSKID : instructionKey.extra_conversion_keys) {
DEBUG("try extra conversion keyswitch #" << convKSKID)
auto convKSK = solution.circuit_keys.conversion_keyswitch_keys[convKSKID];
auto key = operandType.getKey();
assert(key.isParameterized());
if (operandType.getKey().getParameterized().value().identifier ==
convKSK.input_key.identifier) {
return TFHE::GLWEKeyswitchKeyAttr::get(
context, toLWESecretKey(convKSK.input_key),
toLWESecretKey(convKSK.output_key),
convKSK.ks_decomposition_parameter.level,
convKSK.ks_decomposition_parameter.log2_base, -1);
}
}
DEBUG("!!! extra conversion key not found");
return nullptr;
}
const TFHE::GLWEBootstrapKeyAttr
getBootstrapKeyAttr(mlir::MLIRContext *context, size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).tlu_bootstrap_key;
DEBUG("lookup bootstrap key: #" << keyID);
auto key = solution.circuit_keys.bootstrap_keys[keyID];
return TFHE::GLWEBootstrapKeyAttr::get(
context, toLWESecretKey(key.input_key), toGLWESecretKey(key.output_key),
key.output_key.polynomial_size, key.output_key.glwe_dimension,
key.br_decomposition_parameter.level,
key.br_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWECipherTextType
getOutputLWECipherTextType(mlir::MLIRContext *context, size_t optimizerID) {
auto outputKey = getOutputLWESecretKey(optimizerID);
return TFHE::GLWECipherTextType::get(context, outputKey);
}
private:
concrete_optimizer::dag::CircuitSolution solution;
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<>>
createTFHECircuitSolutionParametrizationPass(
concrete_optimizer::dag::CircuitSolution solution) {
return std::make_unique<TFHECircuitSolutionParametrizationPass>(solution);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -41,7 +41,7 @@
#include <concretelang/Dialect/FHE/Transforms/Max/Max.h>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h>
#include <concretelang/Dialect/RT/Analysis/Autopar.h>
#include <concretelang/Dialect/TFHE/Transforms/Optimization.h>
#include <concretelang/Dialect/TFHE/Transforms/Transforms.h>
#include <concretelang/Support/Pipeline.h>
#include <concretelang/Support/logging.h>
#include <concretelang/Support/math.h>

View File

@@ -4,3 +4,4 @@ add_subdirectory(ClientLib)
add_subdirectory(SDFG)
add_subdirectory(TestLib)
add_subdirectory(Encodings)
add_subdirectory(Dialect)

View File

@@ -0,0 +1 @@
add_subdirectory(TFHE)

View File

@@ -0,0 +1 @@
add_subdirectory(Transforms)

View File

@@ -0,0 +1,9 @@
add_custom_target(ConcretelangTFHETransformsTests)
add_dependencies(ConcretelangUnitTests ConcretelangTFHETransformsTests)
add_unittest(ConcretelangTFHETransformsTests unit_tests_concretelang_tfhe_transforms Parametrization.cpp)
target_link_libraries(
unit_tests_concretelang_tfhe_transforms PRIVATE TFHEDialectTransforms MLIRParser MLIRExecutionEngine
TFHEGlobalParametrization ConcretelangSupport)

View File

@@ -0,0 +1,246 @@
#include <gtest/gtest.h>
#include "mlir/Parser/Parser.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/Transforms/Transforms.h"
std::string transform(std::string source,
concrete_optimizer::dag::CircuitSolution solution) {
// Register dialect
mlir::DialectRegistry registry;
registry
.insert<mlir::concretelang::TFHE::TFHEDialect, mlir::func::FuncDialect>();
mlir::MLIRContext mlirContext;
mlirContext.appendDialectRegistry(registry);
// Parse from string
auto memoryBuffer = llvm::MemoryBuffer::getMemBuffer(source);
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
mlir::OwningOpRef<mlir::ModuleOp> mlirModuleRef =
mlir::parseSourceFile<mlir::ModuleOp>(sm, &mlirContext);
// Apply the parametrization pass
mlir::PassManager pm(&mlirContext);
pm.addPass(mlir::concretelang::createTFHECircuitSolutionParametrizationPass(
solution));
assert(pm.run(mlirModuleRef->getOperation()).succeeded() &&
"pass manager fail");
// mlirModuleRef->dump();
std::string moduleString;
llvm::raw_string_ostream s(moduleString);
mlirModuleRef->getOperation()->print(s);
mlirModuleRef->dump();
return s.str();
}
// Returns the secret key id
int addSecretKey(concrete_optimizer::dag::CircuitSolution &solution,
int glwe_dimension, int polynomial_size) {
concrete_optimizer::dag::SecretLweKey secretLweKey;
secretLweKey.description = "single_key";
secretLweKey.identifier = 0;
secretLweKey.glwe_dimension = glwe_dimension;
secretLweKey.polynomial_size = polynomial_size;
secretLweKey.identifier = solution.circuit_keys.secret_keys.size();
solution.circuit_keys.secret_keys.push_back(secretLweKey);
return secretLweKey.identifier;
}
// Returns the keyswitch key id
int addKeyswitchKey(concrete_optimizer::dag::CircuitSolution &solution,
int input_sk, int output_sk, int level, int base_log) {
concrete_optimizer::dag::KeySwitchKey keySwitchKey;
keySwitchKey.input_key = solution.circuit_keys.secret_keys[input_sk];
keySwitchKey.output_key = solution.circuit_keys.secret_keys[output_sk];
keySwitchKey.ks_decomposition_parameter.level = level;
keySwitchKey.ks_decomposition_parameter.log2_base = base_log;
keySwitchKey.identifier = solution.circuit_keys.keyswitch_keys.size();
solution.circuit_keys.keyswitch_keys.push_back(keySwitchKey);
return keySwitchKey.identifier;
}
int addExtraKeyswitchKey(concrete_optimizer::dag::CircuitSolution &solution,
int input_sk, int output_sk, int level, int base_log) {
concrete_optimizer::dag::ConversionKeySwitchKey keySwitchKey;
keySwitchKey.input_key = solution.circuit_keys.secret_keys[input_sk];
keySwitchKey.output_key = solution.circuit_keys.secret_keys[output_sk];
keySwitchKey.ks_decomposition_parameter.level = level;
keySwitchKey.ks_decomposition_parameter.log2_base = base_log;
keySwitchKey.identifier = solution.circuit_keys.keyswitch_keys.size();
solution.circuit_keys.conversion_keyswitch_keys.push_back(keySwitchKey);
return keySwitchKey.identifier;
}
// Returns the bootstrap key id
int addBootstrapKey(concrete_optimizer::dag::CircuitSolution &solution,
int input_sk, int output_sk, int level, int base_log) {
concrete_optimizer::dag::BootstrapKey bootstrapKey;
// TODO: Interface design identifier or key
bootstrapKey.input_key = solution.circuit_keys.secret_keys[input_sk];
bootstrapKey.output_key = solution.circuit_keys.secret_keys[output_sk];
bootstrapKey.br_decomposition_parameter.level = level;
bootstrapKey.br_decomposition_parameter.log2_base = base_log;
bootstrapKey.identifier = solution.circuit_keys.bootstrap_keys.size();
solution.circuit_keys.bootstrap_keys.push_back(bootstrapKey);
return bootstrapKey.identifier;
}
void addInstructionKey(concrete_optimizer::dag::CircuitSolution &solution,
int input_key, int output_key, int ksk = -1,
int bsk = -1,
std::vector<uint64_t> extra_conversion_keys = {}) {
concrete_optimizer::dag::InstructionKeys instrKey;
instrKey.input_key = input_key;
instrKey.output_key = output_key;
instrKey.tlu_bootstrap_key = bsk;
instrKey.tlu_keyswitch_key = ksk;
for (const auto &item : extra_conversion_keys) {
instrKey.extra_conversion_keys.push_back(item);
}
solution.instructions_keys.push_back(instrKey);
}
TEST(TFHECircuitParametrization, single_sk) {
std::string source = R"(
func.func @main(%arg0: !TFHE.glwe<sk?> {TFHE.OId = 0 : i32}, %arg1: !TFHE.glwe<sk?> {TFHE.OId = 1 : i32}, %arg2: i64) -> !TFHE.glwe<sk?> {
%0 = "TFHE.add_glwe_int"(%arg0, %arg2) {TFHE.OId = 2 : i32} : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
%1 = "TFHE.add_glwe"(%0, %arg1) {TFHE.OId = 3 : i32} : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
return %1 : !TFHE.glwe<sk?>
}
)";
std::string expected = R"(module {
func.func @main(%arg0: !TFHE.glwe<sk<0,1,1024>>, %arg1: !TFHE.glwe<sk<0,1,1024>>, %arg2: i64) -> !TFHE.glwe<sk<0,1,1024>> {
%0 = "TFHE.add_glwe_int"(%arg0, %arg2) : (!TFHE.glwe<sk<0,1,1024>>, i64) -> !TFHE.glwe<sk<0,1,1024>>
%1 = "TFHE.add_glwe"(%0, %arg1) : (!TFHE.glwe<sk<0,1,1024>>, !TFHE.glwe<sk<0,1,1024>>) -> !TFHE.glwe<sk<0,1,1024>>
return %1 : !TFHE.glwe<sk<0,1,1024>>
}
}
)";
// TODO: concrete_optimizer::dag::CircuitSolution
concrete_optimizer::dag::CircuitSolution solution;
auto keyId = addSecretKey(solution, 1, 1024);
concrete_optimizer::dag::InstructionKeys instr0;
// %arg0
addInstructionKey(solution, keyId, keyId);
// %arg1
addInstructionKey(solution, keyId, keyId);
// %0
addInstructionKey(solution, keyId, keyId);
// %1
addInstructionKey(solution, keyId, keyId);
std::string output = transform(source, solution);
ASSERT_EQ(output, expected);
}
TEST(TFHECircuitParametrization, keyswitch) {
std::string source = R"(
func.func @main(%arg0: !TFHE.glwe<sk?> {TFHE.OId = 0 : i32}) -> !TFHE.glwe<sk?> {
%0 = "TFHE.keyswitch_glwe"(%arg0) {TFHE.OId = 1 : i32, key = #TFHE.ksk<sk?, sk?, -1, -1>} : (!TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
return %0 : !TFHE.glwe<sk?>
}
)";
std::string expected = R"(module {
func.func @main(%arg0: !TFHE.glwe<sk<0,1,1024>>) -> !TFHE.glwe<sk<1,1,1701>> {
%0 = "TFHE.keyswitch_glwe"(%arg0) {key = #TFHE.ksk<sk<0,1,1024>, sk<1,1,1701>, 2, 12>} : (!TFHE.glwe<sk<0,1,1024>>) -> !TFHE.glwe<sk<1,1,1701>>
return %0 : !TFHE.glwe<sk<1,1,1701>>
}
}
)";
concrete_optimizer::dag::CircuitSolution solution;
// Add a first secret key
auto sk0 = addSecretKey(solution, 1, 1024);
auto sk1 = addSecretKey(solution, 3, 567);
auto ksk = addKeyswitchKey(solution, sk0, sk1, 2, 12);
// %arg0
addInstructionKey(solution, sk0, sk0);
// %0
addInstructionKey(solution, sk0, sk1, ksk);
std::string output = transform(source, solution);
ASSERT_EQ(output, expected);
}
TEST(TFHECircuitParametrization, boostrap) {
std::string source = R"(
func.func @main(%arg0: !TFHE.glwe<sk?> {TFHE.OId = 0 : i32}) -> !TFHE.glwe<sk?> {
%lut = arith.constant dense<-1152921504606846976> : tensor<42xi64>
%0 = "TFHE.bootstrap_glwe"(%arg0, %lut) {TFHE.OId = 1 : i32, key = #TFHE.bsk<sk?, sk?, -1, -1, -1, -1>} : (!TFHE.glwe<sk?>, tensor<42xi64>) -> !TFHE.glwe<sk?>
return %0 : !TFHE.glwe<sk?>
}
)";
std::string expected = R"(module {
func.func @main(%arg0: !TFHE.glwe<sk<0,1,1701>>) -> !TFHE.glwe<sk<1,1,1024>> {
%cst = arith.constant dense<-1152921504606846976> : tensor<1024xi64>
%0 = "TFHE.bootstrap_glwe"(%arg0, %cst) {key = #TFHE.bsk<sk<0,1,1701>, sk<1,1,1024>, 1024, 1, 2, 12>} : (!TFHE.glwe<sk<0,1,1701>>, tensor<1024xi64>) -> !TFHE.glwe<sk<1,1,1024>>
return %0 : !TFHE.glwe<sk<1,1,1024>>
}
}
)";
concrete_optimizer::dag::CircuitSolution solution;
// Add a first secret key
auto sk0 = addSecretKey(solution, 3, 567);
auto sk1 = addSecretKey(solution, 1, 1024);
auto bsk = addBootstrapKey(solution, sk0, sk1, 2, 12);
// %arg0
addInstructionKey(solution, sk0, sk0);
// %0
addInstructionKey(solution, sk0, sk1, -1, bsk);
std::string output = transform(source, solution);
ASSERT_EQ(output, expected);
}
// Test the extra conversion keys used to switch between two partitions without
// boostrap
// TODO: Will be a fastKS
TEST(TFHECircuitParametrization, extra_conversion_key) {
std::string source = R"(
func.func @main(%arg0: !TFHE.glwe<sk?> {TFHE.OId = 0 : i32}, %arg1: !TFHE.glwe<sk?> {TFHE.OId = 1 : i32}, %arg2: i64) -> !TFHE.glwe<sk?> {
// Partition 0
%0 = "TFHE.add_glwe_int"(%arg0, %arg2) {TFHE.OId = 2 : i32} : (!TFHE.glwe<sk?>, i64) -> !TFHE.glwe<sk?>
// Partition 1
%1 = "TFHE.add_glwe"(%0, %arg1) {TFHE.OId = 3 : i32} : (!TFHE.glwe<sk?>, !TFHE.glwe<sk?>) -> !TFHE.glwe<sk?>
return %1 : !TFHE.glwe<sk?>
}
)";
std::string expected = R"(module {
func.func @main(%arg0: !TFHE.glwe<sk<0,1,6144>>, %arg1: !TFHE.glwe<sk<1,1,1024>>, %arg2: i64) -> !TFHE.glwe<sk<1,1,1024>> {
%0 = "TFHE.add_glwe_int"(%arg0, %arg2) : (!TFHE.glwe<sk<0,1,6144>>, i64) -> !TFHE.glwe<sk<0,1,6144>>
%1 = "TFHE.keyswitch_glwe"(%0) {key = #TFHE.ksk<sk<0,1,6144>, sk<1,1,1024>, 2, 12>} : (!TFHE.glwe<sk<0,1,6144>>) -> !TFHE.glwe<sk<1,1,1024>>
%2 = "TFHE.add_glwe"(%1, %arg1) : (!TFHE.glwe<sk<1,1,1024>>, !TFHE.glwe<sk<1,1,1024>>) -> !TFHE.glwe<sk<1,1,1024>>
return %2 : !TFHE.glwe<sk<1,1,1024>>
}
}
)";
concrete_optimizer::dag::CircuitSolution solution;
// Add secret key for partition 0
auto sk0 = addSecretKey(solution, 3, 2048);
// Add secret key for partition 1
auto sk1 = addSecretKey(solution, 1, 1024);
// Extra conversion key
auto ksk = addExtraKeyswitchKey(solution, sk0, sk1, 2, 12);
std::vector<uint64_t> extra_conversion_keys{(uint64_t)ksk};
// Add instruction keys
// #0: %arg0 - partition 0
addInstructionKey(solution, sk0, sk0);
// #1: %arg1 - partition 1
addInstructionKey(solution, sk1, sk1);
// #2: %0 - partition 0 with conversion to partition 1
addInstructionKey(solution, sk0, sk0, -1, -1, extra_conversion_keys);
// #3: %1 - partition 1
addInstructionKey(solution, sk1, sk1);
std::string output = transform(source, solution);
ASSERT_EQ(output, expected);
}