Files
concrete/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Transforms/TFHECircuitSolutionParametrization.cpp
Bourgerie Quentin b6da228dd4 fix(compiler): Add extra conversion keyswitch just after bootstrap
Rely on strong asumptions of the optimization (see comment)

close https://github.com/zama-ai/concrete-internal/issues/352
2023-06-21 08:53:36 +02:00

547 lines
22 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/Transforms/Transforms.h"
#include "concretelang/Support/Constants.h"
#include "concretelang/Support/logging.h"
namespace mlir {
namespace concretelang {
namespace {
#define DEBUG(MSG) \
if (llvm::DebugFlag) \
llvm::errs() << MSG << "\n";
#define VERBOSE(MSG) \
if (mlir::concretelang::isVerbose()) { \
llvm::errs() << MSG << "\n"; \
}
namespace TFHE = mlir::concretelang::TFHE;
/// Optimization pass that should choose more efficient ways of performing
/// crypto operations.
class TFHECircuitSolutionParametrizationPass
: public TFHECircuitSolutionParametrizationBase<
TFHECircuitSolutionParametrizationPass> {
public:
TFHECircuitSolutionParametrizationPass(
concrete_optimizer::dag::CircuitSolution solution)
: solution(solution){};
void runOnOperation() override {
mlir::Operation *op = getOperation();
op->walk([&](mlir::func::FuncOp func) {
DEBUG("apply solution: \n" << solution.dump().c_str());
DEBUG("process func: " << func);
// Process function arguments, change type of arguments according of the
// optimizer identifier stored in the "TFHE.OId" attribute.
for (size_t i = 0; i < func.getNumArguments(); i++) {
auto arg = func.getArgument(i);
auto attr = func.getArgAttrOfType<mlir::IntegerAttr>(i, "TFHE.OId");
if (attr != nullptr) {
DEBUG("process arg = " << arg)
arg.setType(getParametrizedType(arg.getType(), attr));
} else {
DEBUG("skip arg " << arg)
}
}
// Process operations, apply the instructions keys according of the
// optimizer identifier stored in the "TFHE.OId"
VERBOSE("\n### BEFORE Apply instruction keys " << func);
applyInstructionKeys(func);
// The keyswitch operator is an internal node of the optimizer tlu node,
// so it don't follow the same rule than the other operator on the type of
// outputs
VERBOSE("\n### BEFORE Fixup keyswitch \n" << func);
fixupKeyswitchOuputs(func);
// Propagate types on non parametrized operators
VERBOSE("\n### BEFORE Fixup non parametrized ops \n" << func);
fixupNonParametrizedOps(func);
// Fixup incompatible operators with extra conversion keys
VERBOSE("\n### BEFORE Fixup with extra conversion keys \n" << func);
fixupIncompatibleLeveledOpWithExtraConversionKeys(func);
// Fixup the function signature
VERBOSE("\n### BEFORE Fixup function signature \n" << func);
fixupFunctionSignature(func);
// Remove optimizer identifiers
VERBOSE("\n### BEFORE Remove optimizer identifiers \n" << func);
removeOptimizerIdentifiers(func);
});
}
static mlir::Type getParametrizedType(mlir::Type originalType,
TFHE::GLWECipherTextType newGlwe) {
if (auto oldGlwe = originalType.dyn_cast<TFHE::GLWECipherTextType>();
oldGlwe != nullptr) {
assert(oldGlwe.getKey().isNone());
return newGlwe;
} else if (auto oldTensor = originalType.dyn_cast<mlir::RankedTensorType>();
oldTensor != nullptr) {
auto oldGlwe =
oldTensor.getElementType().dyn_cast<TFHE::GLWECipherTextType>();
assert(oldGlwe != nullptr);
assert(oldGlwe.getKey().isNone());
return mlir::RankedTensorType::get(oldTensor.getShape(), newGlwe);
}
assert(false);
}
mlir::Type getParametrizedType(mlir::Type originalType,
mlir::IntegerAttr optimizerAttrID) {
auto context = originalType.getContext();
auto newGlwe =
getOutputLWECipherTextType(context, optimizerAttrID.getInt());
return getParametrizedType(originalType, newGlwe);
}
static TFHE::GLWECipherTextType getGlweTypeFromType(mlir::Type type) {
if (auto glwe = type.dyn_cast<TFHE::GLWECipherTextType>();
glwe != nullptr) {
return glwe;
} else if (auto tensor = type.dyn_cast<mlir::RankedTensorType>();
tensor != nullptr) {
auto glwe = tensor.getElementType().dyn_cast<TFHE::GLWECipherTextType>();
if (glwe == nullptr) {
return nullptr;
}
return glwe;
}
return nullptr;
}
// Return the
static TFHE::GLWECipherTextType
getParametrizedGlweTypeFromType(mlir::Type type) {
auto glwe = getGlweTypeFromType(type);
if (glwe != nullptr && glwe.getKey().isParameterized()) {
return glwe;
}
return nullptr;
}
// Returns true if the type is or contains a glwe type with a none key.
static bool isNoneGlweType(mlir::Type type) {
auto glwe = getGlweTypeFromType(type);
return glwe != nullptr && glwe.getKey().isNone();
}
void applyInstructionKeys(mlir::func::FuncOp func) {
auto context = func.getContext();
func.walk([&](mlir::Operation *op) {
auto attrOptimizerID = op->getAttrOfType<IntegerAttr>("TFHE.OId");
// Skip operation is no optimizer identifier
if (attrOptimizerID == nullptr) {
DEBUG("skip operation: " << op->getName())
return;
}
DEBUG("process operation: " << *op);
auto optimizerID = attrOptimizerID.getInt();
// Change the output type of the operation
for (auto result : op->getResults()) {
result.setType(getParametrizedType(result.getType(), attrOptimizerID));
}
// Set the keyswitch_key attribute
// TODO: Change ambiguous attribute name
auto attrKeyswitchKey =
op->getAttrOfType<TFHE::GLWEKeyswitchKeyAttr>("key");
if (attrKeyswitchKey == nullptr) {
DEBUG("no keyswitch key");
} else {
op->setAttr("key", getKeyswitchKeyAttr(context, optimizerID));
}
// Set boostrap_key attribute
// TODO: Change ambiguous attribute name
auto attrBootstrapKey =
op->getAttrOfType<TFHE::GLWEBootstrapKeyAttr>("key");
if (attrBootstrapKey == nullptr) {
DEBUG("no bootstrap key");
} else {
op->setAttr("key", getBootstrapKeyAttr(context, optimizerID));
// FIXME: For now we know that if there are an extra conversion key
// this result will only be used in another partition. This is a
// STRONG assumptions of how the optimization work, this is done like
// that to avoid a bug in type propagation, but the extra conversion
// key should be added at the use and not here.
auto instKeys = getInstructionKey(optimizerID);
if (instKeys.extra_conversion_keys.size() != 0) {
assert(instKeys.extra_conversion_keys.size() == 1);
auto convKSK =
solution.circuit_keys
.conversion_keyswitch_keys[instKeys.extra_conversion_keys[0]];
auto convKSKAttr = getExtraConversionKeyAttr(context, convKSK);
mlir::IRRewriter rewriter(context);
rewriter.setInsertionPointAfter(op);
auto outputKey = toLWESecretKey(convKSK.output_key);
auto resType = TFHE::GLWECipherTextType::get(context, outputKey);
auto extraKSK = rewriter.create<TFHE::KeySwitchGLWEOp>(
op->getLoc(), resType, op->getResult(0), convKSKAttr);
rewriter.replaceAllUsesExcept(op->getResult(0), extraKSK, extraKSK);
}
}
});
}
void fixupKeyswitchOuputs(mlir::func::FuncOp func) {
auto context = func.getContext();
func.walk([&](TFHE::KeySwitchGLWEOp op) {
DEBUG("process op: " << op)
auto attrKeyswitchKey =
op->getAttrOfType<TFHE::GLWEKeyswitchKeyAttr>("key");
assert(attrKeyswitchKey != nullptr);
auto outputKey = attrKeyswitchKey.getOutputKey();
outputKey = GLWESecretKeyAsLWE(outputKey);
op.getResult().setType(TFHE::GLWECipherTextType::get(context, outputKey));
DEBUG("fixed op: " << op)
});
// Fixup input of the boostrap operator
DEBUG("### Fixup input tlu of bootstrap")
func.walk([&](TFHE::BootstrapGLWEOp op) {
DEBUG("process op: " << op)
auto attrBootstrapKey =
op->getAttrOfType<TFHE::GLWEBootstrapKeyAttr>("key");
assert(attrBootstrapKey != nullptr);
auto polySize = attrBootstrapKey.getPolySize();
auto lutDefiningOp = op.getLookupTable().getDefiningOp();
// Dirty fixup of the lookup table as we known the operators that can
// define it
// TODO: Do something more robust, using the GLWE type?
mlir::Builder builder(op->getContext());
assert(lutDefiningOp != nullptr);
if (auto encodeOp = mlir::dyn_cast<TFHE::EncodeExpandLutForBootstrapOp>(
lutDefiningOp);
encodeOp != nullptr) {
encodeOp.setPolySize(polySize);
} else if (auto constantOp =
mlir::dyn_cast<arith::ConstantOp>(lutDefiningOp)) {
// Rounded PBS case
auto denseAttr =
constantOp.getValueAttr().dyn_cast<mlir::DenseIntElementsAttr>();
auto val = denseAttr.getValues<int64_t>()[0];
std::vector<int64_t> lut(polySize, val);
constantOp.setValueAttr(mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(lut.size(), builder.getIntegerType(64)),
lut));
}
op.getLookupTable().setType(mlir::RankedTensorType::get(
mlir::ArrayRef<int64_t>(polySize), builder.getI64Type()));
// Also fixup the bootstrap key as the TFHENormalization rely on
// GLWESecretKey structure and not on identifier
// TODO: FIXME
auto outputKey = attrBootstrapKey.getOutputKey().getParameterized();
auto newOutputKey = TFHE::GLWESecretKey::newParameterized(
outputKey->polySize * outputKey->dimension, 1, outputKey->identifier);
auto newAttrBootstrapKey = TFHE::GLWEBootstrapKeyAttr::get(
context, attrBootstrapKey.getInputKey(), newOutputKey,
attrBootstrapKey.getPolySize(), attrBootstrapKey.getGlweDim(),
attrBootstrapKey.getLevels(), attrBootstrapKey.getBaseLog(), -1);
op.setKeyAttr(newAttrBootstrapKey);
});
}
static void
fixupNonParametrizedOp(mlir::Operation *op,
TFHE::GLWECipherTextType parametrizedGlweType) {
DEBUG(" START Fixup {" << *op)
for (auto result : op->getResults()) {
if (isNoneGlweType(result.getType())) {
DEBUG(" -> Fixing result " << result)
result.setType(
getParametrizedType(result.getType(), parametrizedGlweType));
DEBUG(" -> Fixed result " << result)
// Recurse on all users of the fixed result
for (auto user : result.getUsers()) {
DEBUG(" -> Propagate on user " << *user)
fixupNonParametrizedOp(user, parametrizedGlweType);
}
}
}
// Recursively fixup producer of op operands
mlir::Block *parentBlock = nullptr;
for (auto operand : op->getOperands()) {
if (isNoneGlweType(operand.getType())) {
DEBUG(" -> Propagate on operand " << operand.getType())
if (auto opResult = operand.dyn_cast<mlir::OpResult>();
opResult != nullptr) {
fixupNonParametrizedOp(opResult.getOwner(), parametrizedGlweType);
continue;
}
if (auto blockArg = operand.dyn_cast<mlir::BlockArgument>();
blockArg != nullptr) {
DEBUG(" -> Fixing block arg " << blockArg)
blockArg.setType(
getParametrizedType(blockArg.getType(), parametrizedGlweType));
for (auto users : blockArg.getUsers()) {
fixupNonParametrizedOp(users, parametrizedGlweType);
}
auto blockOwner = blockArg.getOwner();
if (blockOwner->isEntryBlock()) {
DEBUG(" -> Will propagate on parent op "
<< blockOwner->getParentOp());
assert(parentBlock == blockOwner || parentBlock == nullptr);
parentBlock = blockOwner;
}
continue;
}
// An mlir::Value should always be an OpResult or a BlockArgument
assert(false);
}
}
DEBUG(" } END Fixup")
if (parentBlock != nullptr) {
fixupNonParametrizedOp(parentBlock->getParentOp(), parametrizedGlweType);
}
}
void
fixupIncompatibleLeveledOpWithExtraConversionKeys(mlir::func::FuncOp func) {
auto context = func.getContext();
func.walk([&](mlir::Operation *op) {
// Skip bootstrap/keyswitch
if (mlir::isa<TFHE::BootstrapGLWEOp>(op) ||
mlir::isa<TFHE::KeySwitchGLWEOp>(op)) {
return;
}
auto attrOptimizerID = op->getAttrOfType<IntegerAttr>("TFHE.OId");
// Skip operation with no optimizer identifier
if (attrOptimizerID == nullptr) {
return;
}
DEBUG(" -> process op: " << *op)
// TFHE operators have only one ciphertext result
assert(op->getNumResults() == 1);
auto resType =
op->getResult(0).getType().dyn_cast<TFHE::GLWECipherTextType>();
// For each ciphertext operands apply the extra keyswitch if found
for (const auto &p : llvm::enumerate(op->getOperands())) {
if (resType == nullptr) {
// We don't expect tensor operands to exist at this point of the
// pipeline for now, but if we happen to have some, this assert will
// break, and things will need to be changed to allow tensor ops to
// be parameterized.
// TODO: Actually this case could happens with tensor manipulation
// operators, so for now we just skip it and that should be fixed
// and tested. As the operand will not be fixed the validation of
// operators should not validate the operators.
continue;
}
auto operand = p.value();
auto operandIdx = p.index();
DEBUG(" -> processing operand " << operand);
auto operandType =
operand.getType().dyn_cast<TFHE::GLWECipherTextType>();
if (operandType == nullptr) {
DEBUG(" -> skip operand, no glwe");
continue;
}
if (operandType.getKey() == resType.getKey()) {
DEBUG(" -> skip operand, unnecessary conversion");
continue;
}
// Lookup for the extra conversion key
DEBUG(" -> get extra conversion key")
auto extraConvKey = getExtraConversionKeyAttr(
context, operandType.getKey(), resType.getKey());
if (extraConvKey == nullptr) {
DEBUG(" -> extra conversion key, not found")
assert(false);
}
mlir::IRRewriter rewriter(context);
rewriter.setInsertionPoint(op);
auto newKSK = rewriter.create<TFHE::KeySwitchGLWEOp>(
op->getLoc(), resType, operand, extraConvKey);
DEBUG("create extra conversion keyswitch: " << newKSK);
op->setOperand(operandIdx, newKSK);
}
});
}
static void fixupNonParametrizedOps(mlir::func::FuncOp func) {
// Lookup all operators that uses function arguments
for (const auto arg : func.getArguments()) {
auto parametrizedGlweType =
getParametrizedGlweTypeFromType(arg.getType());
if (parametrizedGlweType != nullptr) {
DEBUG(" -> Fixup uses of arg " << arg)
// The argument is glwe, so propagate the glwe parametrization to all
// operators which use it
for (auto userOp : arg.getUsers()) {
fixupNonParametrizedOp(userOp, parametrizedGlweType);
}
}
}
// Fixup all operators that take at least a parametrized glwe and produce an
// none glwe
func.walk([&](mlir::Operation *op) {
for (auto operand : op->getOperands()) {
auto parametrizedGlweType =
getParametrizedGlweTypeFromType(operand.getType());
if (parametrizedGlweType != nullptr) {
// An operand is a parametrized glwe
for (auto result : op->getResults()) {
if (isNoneGlweType(result.getType())) {
DEBUG(" -> Fixup illegal op " << *op)
fixupNonParametrizedOp(op, parametrizedGlweType);
return;
}
}
}
}
});
}
static void removeOptimizerIdentifiers(mlir::func::FuncOp func) {
for (size_t i = 0; i < func.getNumArguments(); i++) {
func.removeArgAttr(i, "TFHE.OId");
}
func.walk([&](mlir::Operation *op) { op->removeAttr("TFHE.OId"); });
}
static void fixupFunctionSignature(mlir::func::FuncOp func) {
mlir::SmallVector<mlir::Type> inputs;
mlir::SmallVector<mlir::Type> outputs;
// Set inputs by looking actual arguments types
for (auto arg : func.getArguments()) {
inputs.push_back(arg.getType());
}
// Look for return to set the outputs
func.walk([&](mlir::func::ReturnOp returnOp) {
// TODO: multiple return op
for (auto output : returnOp->getOperandTypes()) {
outputs.push_back(output);
}
});
auto funcType =
mlir::FunctionType::get(func->getContext(), inputs, outputs);
func.setFunctionType(funcType);
}
const concrete_optimizer::dag::InstructionKeys &
getInstructionKey(size_t optimizerID) {
DEBUG("lookup instruction key: #" << optimizerID);
return solution.instructions_keys[optimizerID];
}
const TFHE::GLWESecretKey GLWESecretKeyAsLWE(TFHE::GLWESecretKey key) {
auto keyP = key.getParameterized();
assert(keyP.has_value());
return TFHE::GLWESecretKey::newParameterized(
keyP->polySize * keyP->dimension, 1, keyP->identifier);
}
const TFHE::GLWESecretKey
toGLWESecretKey(concrete_optimizer::dag::SecretLweKey key) {
return TFHE::GLWESecretKey::newParameterized(
key.glwe_dimension, key.polynomial_size, key.identifier);
}
const TFHE::GLWESecretKey
toLWESecretKey(concrete_optimizer::dag::SecretLweKey key) {
return TFHE::GLWESecretKey::newParameterized(
key.glwe_dimension * key.polynomial_size, 1, key.identifier);
}
const TFHE::GLWESecretKey getLWESecretKey(size_t keyID) {
DEBUG("lookup secret key: #" << keyID);
auto key = solution.circuit_keys.secret_keys[keyID];
assert(keyID == key.identifier);
return toLWESecretKey(key);
}
const TFHE::GLWESecretKey getInputLWESecretKey(size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).input_key;
return getLWESecretKey(keyID);
}
const TFHE::GLWESecretKey getOutputLWESecretKey(size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).output_key;
return getLWESecretKey(keyID);
}
const TFHE::GLWEKeyswitchKeyAttr
getKeyswitchKeyAttr(mlir::MLIRContext *context, size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).tlu_keyswitch_key;
DEBUG("lookup keyswicth key: #" << keyID);
auto key = solution.circuit_keys.keyswitch_keys[keyID];
return TFHE::GLWEKeyswitchKeyAttr::get(
context, toLWESecretKey(key.input_key), toLWESecretKey(key.output_key),
key.ks_decomposition_parameter.level,
key.ks_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWEKeyswitchKeyAttr
getExtraConversionKeyAttr(mlir::MLIRContext *context,
TFHE::GLWESecretKey inputKey,
TFHE::GLWESecretKey ouputKey) {
auto convKSK = std::find_if(
solution.circuit_keys.conversion_keyswitch_keys.begin(),
solution.circuit_keys.conversion_keyswitch_keys.end(),
[&](concrete_optimizer::dag::ConversionKeySwitchKey &arg) {
assert(ouputKey.isParameterized() && inputKey.isParameterized());
return arg.input_key.identifier ==
inputKey.getParameterized()->identifier &&
arg.output_key.identifier ==
ouputKey.getParameterized()->identifier;
});
assert(convKSK != solution.circuit_keys.conversion_keyswitch_keys.end());
return TFHE::GLWEKeyswitchKeyAttr::get(
context, toLWESecretKey(convKSK->input_key),
toLWESecretKey(convKSK->output_key),
convKSK->ks_decomposition_parameter.level,
convKSK->ks_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWEKeyswitchKeyAttr getExtraConversionKeyAttr(
mlir::MLIRContext *context,
concrete_optimizer::dag::ConversionKeySwitchKey convKSK) {
return TFHE::GLWEKeyswitchKeyAttr::get(
context, toLWESecretKey(convKSK.input_key),
toLWESecretKey(convKSK.output_key),
convKSK.ks_decomposition_parameter.level,
convKSK.ks_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWEBootstrapKeyAttr
getBootstrapKeyAttr(mlir::MLIRContext *context, size_t optimizerID) {
auto keyID = getInstructionKey(optimizerID).tlu_bootstrap_key;
DEBUG("lookup bootstrap key: #" << keyID);
auto key = solution.circuit_keys.bootstrap_keys[keyID];
return TFHE::GLWEBootstrapKeyAttr::get(
context, toLWESecretKey(key.input_key), toGLWESecretKey(key.output_key),
key.output_key.polynomial_size, key.output_key.glwe_dimension,
key.br_decomposition_parameter.level,
key.br_decomposition_parameter.log2_base, -1);
}
const TFHE::GLWECipherTextType
getOutputLWECipherTextType(mlir::MLIRContext *context, size_t optimizerID) {
auto outputKey = getOutputLWESecretKey(optimizerID);
return TFHE::GLWECipherTextType::get(context, outputKey);
}
private:
concrete_optimizer::dag::CircuitSolution solution;
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<>>
createTFHECircuitSolutionParametrizationPass(
concrete_optimizer::dag::CircuitSolution solution) {
return std::make_unique<TFHECircuitSolutionParametrizationPass>(solution);
}
} // namespace concretelang
} // namespace mlir