mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): template result of typed vector from PublicResult
This commit is contained in:
@@ -90,8 +90,33 @@ struct PublicResult {
|
||||
|
||||
/// Get the result at `pos` as a vector, if the result is a scalar returns a
|
||||
/// vector of size 1. Decryption happens if the result is encrypted.
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
asClearTextVector(KeySet &keySet, size_t pos);
|
||||
// outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
// asClearTextVector(KeySet &keySet, size_t pos);
|
||||
|
||||
template <typename T>
|
||||
outcome::checked<std::vector<T>, StringError>
|
||||
asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
if (!gate.isEncrypted()) {
|
||||
std::vector<T> result;
|
||||
result.reserve(buffers[pos].values.size());
|
||||
std::copy(buffers[pos].values.begin(), buffers[pos].values.end(),
|
||||
std::back_inserter(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
auto buffer = buffers[pos];
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
|
||||
std::vector<T> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = &buffer.values[i * lweSize];
|
||||
uint64_t decrypted;
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted));
|
||||
decryptedValues[i] = decrypted;
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
|
||||
@@ -37,7 +37,7 @@ llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
auto clearResult = result.asClearTextVector<uint64_t>(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
@@ -52,7 +52,7 @@ inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
auto clearResult = result.asClearTextVector<T>(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedVectorResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
@@ -68,21 +68,21 @@ typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
|
||||
// to ambiguity with scalar template
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint8_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint16_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint16_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint32_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint32_t>(keySet, result);
|
||||
// }
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint8_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint8_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint16_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint16_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint32_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint32_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
@@ -113,7 +113,7 @@ typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto gate = keySet.outputGate(0);
|
||||
// scalar case
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
auto clearResult = result.asClearTextVector<uint64_t>(keySet, 0);
|
||||
if (clearResult.has_error()) {
|
||||
return StreamStringError("typedResult: ") << clearResult.error().mesg;
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
|
||||
return result.asClearTextVector(keySet, 0);
|
||||
return result.asClearTextVector<decrypted_scalar_t>(keySet, 0);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
|
||||
@@ -136,24 +136,6 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
return sArguments;
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<uint64_t>, StringError>
|
||||
PublicResult::asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
if (!gate.isEncrypted()) {
|
||||
return buffers[pos].values;
|
||||
}
|
||||
|
||||
auto buffer = buffers[pos];
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
|
||||
std::vector<uint64_t> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = &buffer.values[i * lweSize];
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
|
||||
// increase multi dim index
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
|
||||
Reference in New Issue
Block a user