From 0d376bc5590d9a72fd41b3be5de438536f22219b Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 22 Mar 2022 15:00:52 +0100 Subject: [PATCH] enhance(compiler): template result of typed vector from PublicResult --- .../concretelang/ClientLib/PublicArguments.h | 29 +++++++++++++-- .../concretelang/Support/LambdaSupport.h | 36 +++++++++---------- compiler/lib/ClientLib/ClientLambda.cpp | 2 +- compiler/lib/ClientLib/PublicArguments.cpp | 18 ---------- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 8485ccbb3..490987f3b 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -90,8 +90,33 @@ struct PublicResult { /// Get the result at `pos` as a vector, if the result is a scalar returns a /// vector of size 1. Decryption happens if the result is encrypted. - outcome::checked, StringError> - asClearTextVector(KeySet &keySet, size_t pos); + // outcome::checked, StringError> + // asClearTextVector(KeySet &keySet, size_t pos); + + template + outcome::checked, StringError> + asClearTextVector(KeySet &keySet, size_t pos) { + OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); + if (!gate.isEncrypted()) { + std::vector result; + result.reserve(buffers[pos].values.size()); + std::copy(buffers[pos].values.begin(), buffers[pos].values.end(), + std::back_inserter(result)); + return result; + } + + auto buffer = buffers[pos]; + auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); + + std::vector decryptedValues(buffer.length() / lweSize); + for (size_t i = 0; i < decryptedValues.size(); i++) { + auto ciphertext = &buffer.values[i * lweSize]; + uint64_t decrypted; + OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted)); + decryptedValues[i] = decrypted; + } + return decryptedValues; + } // private: TODO tmp friend class ::concretelang::serverlib::ServerLambda; diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index a6012dce7..2ff7705dd 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -37,7 +37,7 @@ llvm::Expected typedResult(clientlib::KeySet &keySet, template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { - auto clearResult = result.asClearTextVector(keySet, 0); + auto clearResult = result.asClearTextVector(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedResult cannot get clear text vector") << clearResult.error().mesg; @@ -52,7 +52,7 @@ inline llvm::Expected typedResult(clientlib::KeySet &keySet, template inline llvm::Expected> typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { - auto clearResult = result.asClearTextVector(keySet, 0); + auto clearResult = result.asClearTextVector(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedVectorResult cannot get clear text vector") << clearResult.error().mesg; @@ -68,21 +68,21 @@ typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { // llvm::Expected> // typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due // to ambiguity with scalar template -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } +template <> +inline llvm::Expected> +typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + return typedVectorResult(keySet, result); +} +template <> +inline llvm::Expected> +typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + return typedVectorResult(keySet, result); +} +template <> +inline llvm::Expected> +typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + return typedVectorResult(keySet, result); +} template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { @@ -113,7 +113,7 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto gate = keySet.outputGate(0); // scalar case if (gate.shape.dimensions.empty()) { - auto clearResult = result.asClearTextVector(keySet, 0); + auto clearResult = result.asClearTextVector(keySet, 0); if (clearResult.has_error()) { return StreamStringError("typedResult: ") << clearResult.error().mesg; } diff --git a/compiler/lib/ClientLib/ClientLambda.cpp b/compiler/lib/ClientLib/ClientLambda.cpp index 5907546f5..67a371e6b 100644 --- a/compiler/lib/ClientLib/ClientLambda.cpp +++ b/compiler/lib/ClientLib/ClientLambda.cpp @@ -52,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) { outcome::checked, StringError> ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) { - return result.asClearTextVector(keySet, 0); + return result.asClearTextVector(keySet, 0); } outcome::checked errorResultRank(size_t expected, diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index 57dcf8ff7..cfd0a7d17 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -136,24 +136,6 @@ PublicArguments::unserialize(ClientParameters &clientParameters, return sArguments; } -outcome::checked, StringError> -PublicResult::asClearTextVector(KeySet &keySet, size_t pos) { - OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); - if (!gate.isEncrypted()) { - return buffers[pos].values; - } - - auto buffer = buffers[pos]; - auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); - - std::vector decryptedValues(buffer.length() / lweSize); - for (size_t i = 0; i < decryptedValues.size(); i++) { - auto ciphertext = &buffer.values[i * lweSize]; - OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i])); - } - return decryptedValues; -} - void next_coord_index(size_t index[], size_t sizes[], size_t rank) { // increase multi dim index for (int r = rank - 1; r >= 0; r--) {