// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include "concretelang/ClientLib/PublicArguments.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" namespace concretelang { namespace clientlib { template std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) { writeSize(ostream, (uint64_t)buffer.size()); ostream.write((const char *)buffer.buffer(), buffer.size() * sizeof(uint64_t)); assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, std::shared_ptr> &vec) { // TODO assertion on size? uint64_t size; readSize(istream, size); vec->resize(size); istream.read((char *)vec->data(), size * sizeof(uint64_t)); assert(istream.good()); return istream; } // LweSecretKey //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) { writeWord(ostream, param.dimension); return ostream; } std::istream &operator>>(std::istream &istream, LweSecretKeyParam ¶m) { readWord(istream, param.dimension); return istream; } // LweSecretKey ///////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweSecretKey readLweSecretKey(std::istream &istream) { LweSecretKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweSecretKey(buffer, param); } // KeyswitchKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.variance); return ostream; } std::istream &operator>>(std::istream &istream, KeyswitchKeyParam ¶m) { // TODO keys id param.outputSecretKeyID = 1234; param.inputSecretKeyID = 1234; readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.variance); return istream; } // LweKeyswitchKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) { KeyswitchKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweKeyswitchKey(buffer, param); } // BootstrapKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.glweDimension); writeWord(ostream, param.variance); writeWord(ostream, param.polynomialSize); writeWord(ostream, param.inputLweDimension); return ostream; } std::istream &operator>>(std::istream &istream, BootstrapKeyParam ¶m) { // TODO keys id readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.glweDimension); readWord(istream, param.variance); readWord(istream, param.polynomialSize); readWord(istream, param.inputLweDimension); return istream; } // LweBootstrapKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } LweBootstrapKey readLweBootstrapKey(std::istream &istream) { BootstrapKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; return LweBootstrapKey(buffer, param); } // PackingKeyswitchKeyParam //////////////////////////// std::ostream &operator<<(std::ostream &ostream, const PackingKeyswitchKeyParam param) { // TODO keys id writeWord(ostream, param.level); writeWord(ostream, param.baseLog); writeWord(ostream, param.glweDimension); writeWord(ostream, param.polynomialSize); writeWord(ostream, param.inputLweDimension); writeWord(ostream, param.variance); return ostream; } std::istream &operator>>(std::istream &istream, PackingKeyswitchKeyParam ¶m) { // TODO keys id param.outputSecretKeyID = 1234; param.inputSecretKeyID = 1234; readWord(istream, param.level); readWord(istream, param.baseLog); readWord(istream, param.glweDimension); readWord(istream, param.polynomialSize); readWord(istream, param.inputLweDimension); readWord(istream, param.variance); return istream; } // PackingKeyswitchKey ////////////////////////////// std::ostream &operator<<(std::ostream &ostream, const PackingKeyswitchKey &key) { ostream << key.parameters(); writeUInt64KeyBuffer(ostream, key); return ostream; } PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) { PackingKeyswitchKeyParam param; istream >> param; auto buffer = std::make_shared>(); istream >> buffer; auto b = PackingKeyswitchKey(buffer, param); return b; } // EvaluationKey //////////////////////////////// EvaluationKeys readEvaluationKeys(std::istream &istream) { uint64_t nbKey; readSize(istream, nbKey); std::vector bootstrapKeys; for (uint64_t i = 0; i < nbKey; i++) { bootstrapKeys.push_back(readLweBootstrapKey(istream)); } readSize(istream, nbKey); std::vector keyswitchKeys; for (uint64_t i = 0; i < nbKey; i++) { keyswitchKeys.push_back(readLweKeyswitchKey(istream)); } std::vector packingKeyswitchKeys; readSize(istream, nbKey); for (uint64_t i = 0; i < nbKey; i++) { packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream)); } return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys); } std::ostream &operator<<(std::ostream &ostream, const EvaluationKeys &evaluationKeys) { auto bootstrapKeys = evaluationKeys.getBootstrapKeys(); writeSize(ostream, bootstrapKeys.size()); for (auto bsk : bootstrapKeys) { ostream << bsk; } auto keyswitchKeys = evaluationKeys.getKeyswitchKeys(); writeSize(ostream, keyswitchKeys.size()); for (auto ksk : keyswitchKeys) { ostream << ksk; } auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys(); writeSize(ostream, packingKeyswitchKeys.size()); for (auto pksk : packingKeyswitchKeys) { ostream << pksk; } assert(ostream.good()); return ostream; } // TensorData /////////////////////////////////// template std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) { writeWord(ostream, sizeof(T) * 8); writeWord(ostream, std::is_signed()); writeWord(ostream, value); return ostream; } std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) { switch (sd.getType()) { case ElementType::u64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u8: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i8: return serializeScalarDataRaw(sd.getValue(), ostream); } return ostream; } template ScalarData unserializeScalarValue(std::istream &istream) { T value; readWord(istream, value); return ScalarData(value); } outcome::checked unserializeScalarData(std::istream &istream) { uint64_t scalarWidth; readWord(istream, scalarWidth); switch (scalarWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Scalar width must be either 64, 32, 16 or 8, but got ") << scalarWidth; } uint8_t scalarSignedness; readWord(istream, scalarSignedness); if (scalarSignedness != 0 && scalarSignedness != 1) { return StringError("Numerical value for scalar signedness must be either " "0 or 1, but got ") << scalarSignedness; } switch (scalarWidth) { case 64: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 32: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 16: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 8: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); } assert(false && "Unhandled scalar type"); } template static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes, std::istream &istream) { readWords(istream, values_and_sizes.getElementPointer(0), values_and_sizes.getNumElements()); return istream; } std::ostream &serializeTensorData(const TensorData &values_and_sizes, std::ostream &ostream) { switch (values_and_sizes.getElementType()) { case ElementType::u64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); } assert(false && "Unhandled element type"); } outcome::checked unserializeTensorData( const std::vector &expectedSizes, // includes lweSize, unsigned to // accomodate non static sizes std::istream &istream) { if (incorrectMode(istream)) { return StringError("Stream is in incorrect mode"); } uint64_t numDimensions; readWord(istream, numDimensions); std::vector dims; for (uint64_t i = 0; i < numDimensions; i++) { int64_t dimSize; readWord(istream, dimSize); if (dimSize != expectedSizes[i]) { istream.setstate(std::ios::badbit); return StringError("Number of dimensions did not match the number of " "expected dimensions"); } dims.push_back(dimSize); } uint64_t elementWidth; readWord(istream, elementWidth); switch (elementWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Element width must be either 64, 32, 16 or 8, but got ") << elementWidth; } uint8_t elementSignedness; readWord(istream, elementSignedness); if (elementSignedness != 0 && elementSignedness != 1) { return StringError("Numerical value for element signedness must be either " "0 or 1, but got ") << elementSignedness; } TensorData result(dims, elementWidth, elementSignedness == 1); switch (result.getElementType()) { case ElementType::u64: unserializeTensorDataElements(result, istream); break; case ElementType::i64: unserializeTensorDataElements(result, istream); break; case ElementType::u32: unserializeTensorDataElements(result, istream); break; case ElementType::i32: unserializeTensorDataElements(result, istream); break; case ElementType::u16: unserializeTensorDataElements(result, istream); break; case ElementType::i16: unserializeTensorDataElements(result, istream); break; case ElementType::u8: unserializeTensorDataElements(result, istream); break; case ElementType::i8: unserializeTensorDataElements(result, istream); break; } return std::move(result); } std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, std::ostream &ostream) { writeWord(ostream, sotd.isTensor()); if (sotd.isTensor()) return serializeTensorData(sotd.getTensor(), ostream); else return serializeScalarData(sotd.getScalar(), ostream); } outcome::checked unserializeScalarOrTensorData(const std::vector &expectedSizes, std::istream &istream) { uint8_t isTensor; readWord(istream, isTensor); if (isTensor != 0 && isTensor != 1) { return StringError("Numerical value indicating whether a data element is a " "tensor must be either 0 or 1, but got ") << isTensor; } if (isTensor) { auto tdOrErr = unserializeTensorData(expectedSizes, istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } else { auto tdOrErr = unserializeScalarData(istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } } } // namespace clientlib } // namespace concretelang