enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs

This commit is contained in:
Quentin Bourgerie
2021-08-24 15:02:45 +02:00
parent dba76a1e1b
commit af0789f128
11 changed files with 701 additions and 73 deletions

View File

@@ -42,7 +42,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
auto e = keySet->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
if (e) {
return e;
return std::move(e);
}
}
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
@@ -60,7 +60,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
bootstrapKeyParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
return std::move(e);
}
}
for (auto keyswitchParam : params.keyswitchKeys) {
@@ -68,7 +68,7 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
keyswitchParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
return std::move(e);
}
}
}
@@ -112,9 +112,8 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
CAPI_ERR_TO_LLVM_ERROR(
sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}),
"cannot allocate secret key");
CAPI_ERR_TO_LLVM_ERROR(sk = allocate_lwe_secret_key_u64(&err, {param.size}),
"cannot allocate secret key");
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
"cannot fill secret key with random generator")
secretKeys[id] = {param, sk};
@@ -250,6 +249,7 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t &output) {
if (argPos >= outputs.size()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: position of argument is too high",
@@ -262,13 +262,14 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
llvm::inconvertibleErrorCode());
}
// Decrypt
Plaintext_u64 plaintext;
Plaintext_u64 plaintext = {0};
CAPI_ERR_TO_LLVM_ERROR(
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
"cannot decrypt");
// Decode
output = plaintext._0 >>
(64 - (std::get<0>(outputSk).encryption->encoding.precision + 1));
return llvm::Error::success();
}