mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
enhance(client/server): Don't decrypt directly from istream use a intermediate container to represent public result
This commit is contained in:
@@ -46,35 +46,15 @@ ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
|
||||
seed_lsb);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
|
||||
std::ostream &ostream) {
|
||||
return serverArguments.serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
|
||||
return v[0];
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
|
||||
auto lweSize =
|
||||
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
|
||||
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
|
||||
sizes.push_back(lweSize);
|
||||
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
|
||||
if (istream.fail()) {
|
||||
return StringError("Encrypted scalars has not the right size");
|
||||
}
|
||||
auto len = encryptedValues.length();
|
||||
decrypted_tensor_1_t decryptedValues(len / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
|
||||
}
|
||||
return decryptedValues;
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
|
||||
return result.decryptVector(keySet, 0);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
@@ -128,7 +108,7 @@ decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
outcome::checked<DecryptedTensor, StringError>
|
||||
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
|
||||
ClientParameters ¶ms, size_t expectedRank,
|
||||
KeySet &keySet) {
|
||||
auto shape = params.outputs[0].shape;
|
||||
@@ -137,7 +117,7 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
return StringError("Function returns a tensor of rank ")
|
||||
<< expectedRank << " which cannot be decrypted to rank " << rank;
|
||||
}
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
|
||||
llvm::SmallVector<size_t, 6> sizes;
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
sizes.push_back(shape.dimensions[dim]);
|
||||
@@ -146,27 +126,27 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_1_t>(
|
||||
istream, *this, this->clientParameters, 1, keySet);
|
||||
result, *this, this->clientParameters, 1, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_2_t>(
|
||||
istream, *this, this->clientParameters, 2, keySet);
|
||||
result, *this, this->clientParameters, 2, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_3_t>(
|
||||
istream, *this, this->clientParameters, 3, keySet);
|
||||
result, *this, this->clientParameters, 3, keySet);
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
PublicResult &result) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(COMPATIBLE_RESULT_TYPE)0;
|
||||
@@ -175,32 +155,32 @@ topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedScalar(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedScalar(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor1(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor1(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor2(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor2(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor3(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor3(keySet, result);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
|
||||
Reference in New Issue
Block a user