enhance(client/server): Don't decrypt directly from istream use a intermediate container to represent public result

This commit is contained in:
Quentin Bourgerie
2022-02-28 16:54:29 +01:00
parent 69037cd1fa
commit 73da7da81c
14 changed files with 181 additions and 203 deletions

View File

@@ -46,35 +46,15 @@ ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
seed_lsb);
}
outcome::checked<void, StringError>
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
std::ostream &ostream) {
return serverArguments.serialize(ostream);
}
outcome::checked<decrypted_scalar_t, StringError>
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
return v[0];
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
auto lweSize =
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
sizes.push_back(lweSize);
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
if (istream.fail()) {
return StringError("Encrypted scalars has not the right size");
}
auto len = encryptedValues.length();
decrypted_tensor_1_t decryptedValues(len / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
}
return decryptedValues;
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
return result.decryptVector(keySet, 0);
}
outcome::checked<void, StringError> errorResultRank(size_t expected,
@@ -128,7 +108,7 @@ decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
template <typename DecryptedTensor>
outcome::checked<DecryptedTensor, StringError>
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
ClientParameters &params, size_t expectedRank,
KeySet &keySet) {
auto shape = params.outputs[0].shape;
@@ -137,7 +117,7 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
return StringError("Function returns a tensor of rank ")
<< expectedRank << " which cannot be decrypted to rank " << rank;
}
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
llvm::SmallVector<size_t, 6> sizes;
for (size_t dim = 0; dim < rank; dim++) {
sizes.push_back(shape.dimensions[dim]);
@@ -146,27 +126,27 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
}
outcome::checked<decrypted_tensor_1_t, StringError>
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_1_t>(
istream, *this, this->clientParameters, 1, keySet);
result, *this, this->clientParameters, 1, keySet);
}
outcome::checked<decrypted_tensor_2_t, StringError>
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_2_t>(
istream, *this, this->clientParameters, 2, keySet);
result, *this, this->clientParameters, 2, keySet);
}
outcome::checked<decrypted_tensor_3_t, StringError>
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_3_t>(
istream, *this, this->clientParameters, 3, keySet);
result, *this, this->clientParameters, 3, keySet);
}
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
PublicResult &result) {
// compile time error if used
using COMPATIBLE_RESULT_TYPE = void;
return (Result)(COMPATIBLE_RESULT_TYPE)0;
@@ -175,32 +155,32 @@ topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedScalar(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedScalar(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor1(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor1(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor2(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor2(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_3_t, StringError>
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor3(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor3(keySet, result);
}
} // namespace clientlib