enhance(compiler): template result of typed vector from PublicResult

This commit is contained in:
Quentin Bourgerie
2022-03-22 15:00:52 +01:00
parent 52aa18a848
commit 0d376bc559
4 changed files with 46 additions and 39 deletions

View File

@@ -90,8 +90,33 @@ struct PublicResult {
/// Get the result at `pos` as a vector, if the result is a scalar returns a
/// vector of size 1. Decryption happens if the result is encrypted.
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
asClearTextVector(KeySet &keySet, size_t pos);
// outcome::checked<std::vector<decrypted_scalar_t>, StringError>
// asClearTextVector(KeySet &keySet, size_t pos);
template <typename T>
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted()) {
std::vector<T> result;
result.reserve(buffers[pos].values.size());
std::copy(buffers[pos].values.begin(), buffers[pos].values.end(),
std::back_inserter(result));
return result;
}
auto buffer = buffers[pos];
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
uint64_t decrypted;
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted));
decryptedValues[i] = decrypted;
}
return decryptedValues;
}
// private: TODO tmp
friend class ::concretelang::serverlib::ServerLambda;

View File

@@ -37,7 +37,7 @@ llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
template <>
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
auto clearResult = result.asClearTextVector<uint64_t>(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedResult cannot get clear text vector")
<< clearResult.error().mesg;
@@ -52,7 +52,7 @@ inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
auto clearResult = result.asClearTextVector<T>(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedVectorResult cannot get clear text vector")
<< clearResult.error().mesg;
@@ -68,21 +68,21 @@ typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
// to ambiguity with scalar template
// template <>
// inline llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint8_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint16_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint16_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint32_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint32_t>(keySet, result);
// }
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint8_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint16_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint32_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
@@ -113,7 +113,7 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto gate = keySet.outputGate(0);
// scalar case
if (gate.shape.dimensions.empty()) {
auto clearResult = result.asClearTextVector(keySet, 0);
auto clearResult = result.asClearTextVector<uint64_t>(keySet, 0);
if (clearResult.has_error()) {
return StreamStringError("typedResult: ") << clearResult.error().mesg;
}

View File

@@ -52,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
return result.asClearTextVector(keySet, 0);
return result.asClearTextVector<decrypted_scalar_t>(keySet, 0);
}
outcome::checked<void, StringError> errorResultRank(size_t expected,

View File

@@ -136,24 +136,6 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
return sArguments;
}
outcome::checked<std::vector<uint64_t>, StringError>
PublicResult::asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted()) {
return buffers[pos].values;
}
auto buffer = buffers[pos];
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<uint64_t> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
}
return decryptedValues;
}
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
// increase multi dim index
for (int r = rank - 1; r >= 0; r--) {