Files
concrete/compilers/concrete-compiler/compiler/lib/ClientLib/KeySet.cpp
Andi Drebes c8c969773e Rebase onto llvm-project 465ee9bfb26d with local changes
This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:

  * 5d8669d669ee: Fix the element alignment (size) for memrefCopy
  * 4239163ea337: fix: Do not fold the memref.subview if the offset are
                  != 0 and strides != 1
  * 72c5decfcc21: remove github stuff from llvm
  * 8d0ce8f9eca1: Support arbitrary element types in named operations
                  via attributes
  * 94f64805c38c: Copy attributes of scf.for on bufferization and make
                  it an allocation hoisting barrier

Main upstream changes from llvm-project that required modification of
concretecompiler:

  * Switch to C++17
  * Various changes in the interfaces for linalg named operations
  * Transition from `llvm::Optional` to `std::optional`
  * Use of enums instead of string values for iterator types in linalg
  * Changed default naming convention of getter methods in
    ODS-generated operation classes from `some_value()` to
    `getSomeValue()`
  * Renaming of Arithmetic dialect to Arith
  * Refactoring of side effect interfaces (i.e., renaming from
    `NoSideEffect` to `Pure`)
  * Re-design of the data flow analysis framework
  * Refactoring of build targets for Python bindings
  * Refactoring of array attributes with integer values
  * Renaming of `linalg.init_tensor` to `tensor.empty`
  * Emission of `linalg.map` operations in bufferization of the Tensor
    dialect requiring another linalg conversion pass and registration
    of the bufferization op interfaces for linalg operations
  * Refactoring of the one-shot bufferizer
  * Necessity to run the expand-strided-metadata, affine-to-std and
    finalize-memref-to-llvm passes before converson to the LLVM
    dialect
  * Renaming of `BlockAndValueMapping` to `IRMapping`
  * Changes in the build function of `LLVM::CallOp`
  * Refactoring of the construction of `llvm::ArrayRef` and
    `llvm::MutableArrayRef` (direct invocation of constructor instead
    of builder functions for some cases)
  * New naming conventions for generated SSA values requiring rewrite
    of some check tests
  * Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
  * Interface changes in generated type parsers
  * New dependencies for to mlir_float16_utils and
    MLIRSparseTensorRuntime for the runtime
  * Overhaul of MLIR-c deleting `mlir-c/Registration.h`
  * Deletion of library MLIRLinalgToSPIRV
  * Deletion of library MLIRLinalgAnalysis
  * Deletion of library MLIRMemRefUtils
  * Deletion of library MLIRQuantTransforms
  * Deletion of library MLIRVectorToROCDL
2023-03-09 17:47:16 +01:00

301 lines
10 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Support/Error.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
namespace concretelang {
namespace clientlib {
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) {
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
OUTCOME_TRYV(keySet->generateKeysFromParams());
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
outcome::checked<std::unique_ptr<KeySet>, StringError> KeySet::fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng) {
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
keySet->secretKeys = secretKeys;
keySet->bootstrapKeys = bootstrapKeys;
keySet->keyswitchKeys = keyswitchKeys;
keySet->packingKeyswitchKeys = packingKeyswitchKeys;
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
EvaluationKeys KeySet::evaluationKeys() {
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
}
outcome::checked<KeySet::SecretKeyGateMapping, StringError>
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
SecretKeyGateMapping mapping;
for (auto gate : gates) {
if (gate.encryption.has_value()) {
assert(gate.encryption->secretKeyID < this->secretKeys.size());
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {gate, skIt};
mapping.push_back(input);
} else {
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {
gate, std::nullopt};
mapping.push_back(input);
}
}
return mapping;
}
outcome::checked<void, StringError> KeySet::setupEncryptionMaterial() {
OUTCOME_TRY(this->inputs,
mapCircuitGateLweSecretKey(_clientParameters.inputs));
OUTCOME_TRY(this->outputs,
mapCircuitGateLweSecretKey(_clientParameters.outputs));
return outcome::success();
}
outcome::checked<void, StringError> KeySet::generateKeysFromParams() {
// Generate LWE secret keys
for (auto secretKeyParam : _clientParameters.secretKeys) {
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam));
}
// Generate bootstrap keys
for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) {
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam));
}
// Generate keyswitch key
for (auto keyswitchParam : _clientParameters.keyswitchKeys) {
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam));
}
// Generate packing keyswitch key
for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) {
OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam));
}
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyParam param) {
// Init the lwe secret key
LweSecretKey sk(param, csprng);
// Store the lwe secret key
secretKeys.push_back(sk);
return outcome::success();
}
outcome::checked<LweSecretKey, StringError>
KeySet::findLweSecretKey(LweSecretKeyID keyID) {
assert(keyID < secretKeys.size());
auto secretKey = secretKeys[keyID];
return secretKey;
}
outcome::checked<void, StringError>
KeySet::generateBootstrapKey(BootstrapKeyParam param) {
// Finding input and output secretKeys
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
// Initialize the bootstrap key
LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng);
// Store the bootstrap key
bootstrapKeys.push_back(std::move(bootstrapKey));
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateKeyswitchKey(KeyswitchKeyParam param) {
// Finding input and output secretKeys
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
// Initialize the bootstrap key
LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng);
// Store the keyswitch key
keyswitchKeys.push_back(keyswitchKey);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) {
// Finding input secretKeys
assert(param.inputSecretKeyID < secretKeys.size());
auto inputSk = secretKeys[param.inputSecretKeyID];
assert(param.outputSecretKeyID < secretKeys.size());
auto outputSk = secretKeys[param.outputSecretKeyID];
PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng);
// Store the keyswitch key
packingKeyswitchKeys.push_back(packingKeyswitchKey);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
if (argPos >= inputs.size()) {
return StringError("allocate_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.has_value()) {
return StringError("allocate_lwe argument #")
<< argPos << "is not encypeted";
}
auto numBlocks =
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
assert(inputSk.second.has_value());
size = inputSk.second->parameters().lweSize();
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
return outcome::success();
}
bool KeySet::isInputEncrypted(size_t argPos) {
return argPos < inputs.size() &&
std::get<0>(inputs[argPos]).encryption.has_value();
}
bool KeySet::isOutputEncrypted(size_t argPos) {
return argPos < outputs.size() &&
std::get<0>(outputs[argPos]).encryption.has_value();
}
/// Return the number of bits to represents the given value
uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); }
outcome::checked<void, StringError>
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
if (argPos >= inputs.size()) {
return StringError("encrypt_lwe position of argument is too high");
}
const auto &inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.has_value()) {
return StringError("encrypt_lwe the positional argument is not encrypted");
}
auto encoding = encryption->encoding;
assert(inputSk.second.has_value());
auto lweSecretKey = *inputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
// CRT encoding - N blocks with crt encoding
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
// Put each decomposition into a new ciphertext
auto product = crt::productOfModuli(crt);
for (auto modulus : crt) {
auto plaintext = crt::encode(input, modulus, product);
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
return outcome::success();
}
// Simple TFHE integers - 1 blocks with one padding bits
// TODO we could check if the input value is in the right range
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
if (argPos >= outputs.size()) {
return StringError("decrypt_lwe: position of argument is too high");
}
auto outputSk = outputs[argPos];
assert(outputSk.second.has_value());
auto lweSecretKey = *outputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
auto encryption = std::get<0>(outputSk).encryption;
if (!encryption.has_value()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
}
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
// CRT encoded TFHE integers
// Decrypt and decode remainders
std::vector<int64_t> remainders;
for (auto modulus : crt) {
uint64_t decrypted = 0;
lweSecretKey.decrypt(ciphertext, decrypted);
auto plaintext = crt::decode(decrypted, modulus);
remainders.push_back(plaintext);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
// Compute the inverse crt
output = crt::iCrt(crt, remainders);
// Further decode signed integers
if (encryption->encoding.isSigned) {
uint64_t maxPos = 1;
for (auto prime : encryption->encoding.crt) {
maxPos *= prime;
}
maxPos /= 2;
if (output >= maxPos) {
output -= maxPos * 2;
}
}
} else {
// Native encoded TFHE integers - 1 blocks with one padding bits
uint64_t plaintext = 0;
lweSecretKey.decrypt(ciphertext, plaintext);
// Decode unsigned integer
uint64_t precision = encryption->encoding.precision;
output = plaintext >> (64 - precision - 2);
auto carry = output % 2;
uint64_t mod = (((uint64_t)1) << (precision + 1));
output = ((output >> 1) + carry) % mod;
// Further decode signed integers.
if (encryption->encoding.isSigned) {
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
if (output >= maxPos) { // The output is actually negative.
// Set the preceding bits to zero
output |= UINT64_MAX << precision;
// This makes sure when the value is cast to int64, it has the correct
// value
};
}
}
return outcome::success();
}
const std::vector<LweSecretKey> &KeySet::getSecretKeys() { return secretKeys; }
const std::vector<LweBootstrapKey> &KeySet::getBootstrapKeys() {
return bootstrapKeys;
}
const std::vector<LweKeyswitchKey> &KeySet::getKeyswitchKeys() {
return keyswitchKeys;
}
const std::vector<PackingKeyswitchKey> &KeySet::getPackingKeyswitchKeys() {
return packingKeyswitchKeys;
}
} // namespace clientlib
} // namespace concretelang