// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include "concretelang/ClientLib/PublicArguments.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" namespace concretelang { namespace clientlib { template std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) { writeSize(ostream, (uint64_t)buffer.size()); ostream.write((const char *)buffer.buffer(), buffer.size() * sizeof(uint64_t)); assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, std::shared_ptr> &vec) { // TODO assertion on size? uint64_t size; readSize(istream, size); vec->resize(size); istream.read((char *)vec->data(), size * sizeof(uint64_t)); assert(istream.good()); return istream; } // LweSecretKey //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) { writeWord(ostream, param.dimension); return ostream; } std::istream &operator>>(std::istream &istream, LweSecretKeyParam ¶m) { readWord(istream, param.dimension); return istream; } // LweSecretKey ///////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweSecretKey readLweSecretKey(std::istream &istream) { LweSecretKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweSecretKey(buffer, param); } // KeyswitchKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.variance); return ostream; } std::istream &operator>>(std::istream &istream, KeyswitchKeyParam ¶m) { // TODO keys id param.outputSecretKeyID = 1234; param.inputSecretKeyID = 1234; readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.variance); return istream; } // LweKeyswitchKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) { KeyswitchKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweKeyswitchKey(buffer, param); } // BootstrapKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.glweDimension); writeWord(ostream, param.variance); writeWord(ostream, param.polynomialSize); writeWord(ostream, param.inputLweDimension); return ostream; } std::istream &operator>>(std::istream &istream, BootstrapKeyParam ¶m) { // TODO keys id readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.glweDimension); readWord(istream, param.variance); readWord(istream, param.polynomialSize); readWord(istream, param.inputLweDimension); return istream; } // LweBootstrapKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweBootstrapKey readLweBootstrapKey(std::istream &istream) { BootstrapKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweBootstrapKey(buffer, param); } // PackingKeyswitchKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const PackingKeyswitchKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.glweDimension); writeWord(ostream, param.polynomialSize); writeWord(ostream, param.inputLweDimension); writeWord(ostream, param.variance); return ostream; } std::istream &operator>>(std::istream &istream, PackingKeyswitchKeyParam ¶m) { // TODO keys id param.outputSecretKeyID = 1234; param.inputSecretKeyID = 1234; readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.glweDimension); readWord(istream, param.polynomialSize); readWord(istream, param.inputLweDimension); readWord(istream, param.variance); return istream; } // PackingKeyswitchKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const PackingKeyswitchKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) { PackingKeyswitchKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; auto b = PackingKeyswitchKey(buffer, param); return b; } // KeySet //////////////////////////////// std::unique_ptr readKeySet(std::istream &istream) { uint64_t nbKey; readSize(istream, nbKey); std::vector secretKeys; for (uint64_t i = 0; i < nbKey; i++) { secretKeys.push_back(readLweSecretKey(istream)); } readSize(istream, nbKey); std::vector bootstrapKeys; for (uint64_t i = 0; i < nbKey; i++) { bootstrapKeys.push_back(readLweBootstrapKey(istream)); } readSize(istream, nbKey); std::vector keyswitchKeys; for (uint64_t i = 0; i < nbKey; i++) { keyswitchKeys.push_back(readLweKeyswitchKey(istream)); } std::vector packingKeyswitchKeys; readSize(istream, nbKey); for (uint64_t i = 0; i < nbKey; i++) { packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream)); } std::string clientParametersString; istream >> clientParametersString; auto clientParameters = llvm::json::parse(clientParametersString); if (!clientParameters) { return std::unique_ptr(nullptr); } auto csprng = ConcreteCSPRNG(0); auto keySet = KeySet::fromKeys(clientParameters.get(), secretKeys, bootstrapKeys, keyswitchKeys, packingKeyswitchKeys, std::move(csprng)); return std::move(keySet.value()); } std::ostream &operator<<(std::ostream &ostream, const KeySet &keySet) { auto secretKeys = keySet.getSecretKeys(); writeSize(ostream, secretKeys.size()); for (auto sk : secretKeys) { ostream << sk; } auto bootstrapKeys = keySet.getBootstrapKeys(); writeSize(ostream, bootstrapKeys.size()); for (auto bsk : bootstrapKeys) { ostream << bsk; } auto keyswitchKeys = keySet.getKeyswitchKeys(); writeSize(ostream, keyswitchKeys.size()); for (auto ksk : keyswitchKeys) { ostream << ksk; } auto packingKeyswitchKeys = keySet.getPackingKeyswitchKeys(); writeSize(ostream, packingKeyswitchKeys.size()); for (auto pksk : packingKeyswitchKeys) { ostream << pksk; } auto clientParametersJson = llvm::json::Value(keySet.clientParameters()); std::string clientParametersString; llvm::raw_string_ostream clientParametersStringBuffer(clientParametersString); clientParametersStringBuffer << clientParametersJson; ostream << clientParametersString; assert(ostream.good()); return ostream; } // EvaluationKey //////////////////////////////// EvaluationKeys readEvaluationKeys(std::istream &istream) { uint64_t nbKey; readSize(istream, nbKey); std::vector bootstrapKeys; for (uint64_t i = 0; i < nbKey; i++) { bootstrapKeys.push_back(readLweBootstrapKey(istream)); } readSize(istream, nbKey); std::vector keyswitchKeys; for (uint64_t i = 0; i < nbKey; i++) { keyswitchKeys.push_back(readLweKeyswitchKey(istream)); } std::vector packingKeyswitchKeys; readSize(istream, nbKey); for (uint64_t i = 0; i < nbKey; i++) { packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream)); } return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys); } std::ostream &operator<<(std::ostream &ostream, const EvaluationKeys &evaluationKeys) { auto bootstrapKeys = evaluationKeys.getBootstrapKeys(); writeSize(ostream, bootstrapKeys.size()); for (auto bsk : bootstrapKeys) { ostream << bsk; } auto keyswitchKeys = evaluationKeys.getKeyswitchKeys(); writeSize(ostream, keyswitchKeys.size()); for (auto ksk : keyswitchKeys) { ostream << ksk; } auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys(); writeSize(ostream, packingKeyswitchKeys.size()); for (auto pksk : packingKeyswitchKeys) { ostream << pksk; } assert(ostream.good()); return ostream; } // TensorData /////////////////////////////////// template std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) { writeWord(ostream, sizeof(T) * 8); writeWord(ostream, std::is_signed()); writeWord(ostream, value); return ostream; } std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) { switch (sd.getType()) { case ElementType::u64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u8: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i8: return serializeScalarDataRaw(sd.getValue(), ostream); } return ostream; } template ScalarData unserializeScalarValue(std::istream &istream) { T value; readWord(istream, value); return ScalarData(value); } outcome::checked unserializeScalarData(std::istream &istream) { uint64_t scalarWidth; readWord(istream, scalarWidth); switch (scalarWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Scalar width must be either 64, 32, 16 or 8, but got ") << scalarWidth; } uint8_t scalarSignedness; readWord(istream, scalarSignedness); if (scalarSignedness != 0 && scalarSignedness != 1) { return StringError("Numerical value for scalar signedness must be either " "0 or 1, but got ") << scalarSignedness; } switch (scalarWidth) { case 64: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 32: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 16: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 8: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); } assert(false && "Unhandled scalar type"); } template static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes, std::istream &istream) { // getElementPointer is not valid if the tensor contains no data if (values_and_sizes.getNumElements() > 0) { readWords(istream, values_and_sizes.getElementPointer(0), values_and_sizes.getNumElements()); } return istream; } std::ostream &serializeTensorData(const TensorData &values_and_sizes, std::ostream &ostream) { switch (values_and_sizes.getElementType()) { case ElementType::u64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); } assert(false && "Unhandled element type"); } outcome::checked unserializeTensorData(std::istream &istream) { if (incorrectMode(istream)) { return StringError("Stream is in incorrect mode"); } uint64_t numDimensions; readWord(istream, numDimensions); std::vector dims; for (uint64_t i = 0; i < numDimensions; i++) { int64_t dimSize; readWord(istream, dimSize); dims.push_back(dimSize); } uint64_t elementWidth; readWord(istream, elementWidth); switch (elementWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Element width must be either 64, 32, 16 or 8, but got ") << elementWidth; } uint8_t elementSignedness; readWord(istream, elementSignedness); if (elementSignedness != 0 && elementSignedness != 1) { return StringError("Numerical value for element signedness must be either " "0 or 1, but got ") << elementSignedness; } TensorData result(dims, elementWidth, elementSignedness == 1); switch (result.getElementType()) { case ElementType::u64: unserializeTensorDataElements(result, istream); break; case ElementType::i64: unserializeTensorDataElements(result, istream); break; case ElementType::u32: unserializeTensorDataElements(result, istream); break; case ElementType::i32: unserializeTensorDataElements(result, istream); break; case ElementType::u16: unserializeTensorDataElements(result, istream); break; case ElementType::i16: unserializeTensorDataElements(result, istream); break; case ElementType::u8: unserializeTensorDataElements(result, istream); break; case ElementType::i8: unserializeTensorDataElements(result, istream); break; } return std::move(result); } std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, std::ostream &ostream) { writeWord(ostream, sotd.isTensor()); if (sotd.isTensor()) return serializeTensorData(sotd.getTensor(), ostream); else return serializeScalarData(sotd.getScalar(), ostream); } outcome::checked unserializeScalarOrTensorData(std::istream &istream) { uint8_t isTensor; readWord(istream, isTensor); if (isTensor != 0 && isTensor != 1) { return StringError("Numerical value indicating whether a data element is a " "tensor must be either 0 or 1, but got ") << isTensor; } if (isTensor) { auto tdOrErr = unserializeTensorData(istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } else { auto tdOrErr = unserializeScalarData(istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } } std::ostream &serializeVectorOfScalarOrTensorData( const std::vector &v, std::ostream &ostream) { writeSize(ostream, v.size()); for (auto &sotd : v) { serializeScalarOrTensorData(sotd.get(), ostream); if (!ostream.good()) { return ostream; } } return ostream; } outcome::checked, StringError> unserializeVectorOfScalarOrTensorData(std::istream &istream) { uint64_t nbElt; readSize(istream, nbElt); std::vector v; for (uint64_t i = 0; i < nbElt; i++) { OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream)); v.push_back(SharedScalarOrTensorData(std::move(elt))); } return v; } } // namespace clientlib } // namespace concretelang