// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include "concrete-core-ffi.h" #include "concretelang/ClientLib/PublicArguments.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" namespace concretelang { namespace clientlib { template Result read_deser(std::istream &istream, int (*deser)(Engine *, BufferView, Result *), Engine *engine) { size_t length; readSize(istream, length); // buffer is too big to be allocated on stack // vector ensures everything is deallocated w.r.t. new std::vector buffer(length); istream.read((char *)buffer.data(), length); assert(istream.good()); Result result; CAPI_ASSERT_ERROR(deser(engine, {buffer.data(), length}, &result)); return result; } template std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) { writeSize(ostream, buffer.length); ostream.write((const char *)buffer.pointer, buffer.length); assert(ostream.good()); return ostream; } std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) { DefaultSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); Buffer b; CAPI_ASSERT_ERROR( default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key, &b)); writeBufferLike(ostream, b); free((void *)b.pointer); b.pointer = nullptr; return ostream; } std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) { DefaultSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); Buffer b; CAPI_ASSERT_ERROR( default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, &b)) writeBufferLike(ostream, b); free((void *)b.pointer); b.pointer = nullptr; return ostream; } std::ostream &operator<<(std::ostream &ostream, const FftFourierLweBootstrapKey64 *key) { FftSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine)); Buffer b; CAPI_ASSERT_ERROR( fft_serialization_engine_serialize_fft_fourier_lwe_bootstrap_key_u64( engine, key, &b)) writeBufferLike(ostream, b); free((void *)b.pointer); b.pointer = nullptr; return ostream; } std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) { DefaultSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); key = read_deser( istream, default_serialization_engine_deserialize_lwe_keyswitch_key_u64, engine); return istream; } std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) { DefaultSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); key = read_deser( istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64, engine); return istream; } std::istream &operator>>(std::istream &istream, FftFourierLweBootstrapKey64 *&key) { FftSerializationEngine *engine; // No Freeing as it doesn't allocate anything. CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine)); key = read_deser( istream, fft_serialization_engine_deserialize_fft_fourier_lwe_bootstrap_key_u64, engine); return istream; } std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext) { istream >> runtimeContext.evaluationKeys; assert(istream.good()); return istream; } std::ostream &operator<<(std::ostream &ostream, const RuntimeContext &runtimeContext) { ostream << runtimeContext.evaluationKeys; assert(ostream.good()); return ostream; } template std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) { writeWord(ostream, sizeof(T) * 8); writeWord(ostream, std::is_signed()); writeWord(ostream, value); return ostream; } std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) { switch (sd.getType()) { case ElementType::u64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i64: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i32: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i16: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::u8: return serializeScalarDataRaw(sd.getValue(), ostream); case ElementType::i8: return serializeScalarDataRaw(sd.getValue(), ostream); } return ostream; } template ScalarData unserializeScalarValue(std::istream &istream) { T value; readWord(istream, value); return ScalarData(value); } outcome::checked unserializeScalarData(std::istream &istream) { uint64_t scalarWidth; readWord(istream, scalarWidth); switch (scalarWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Scalar width must be either 64, 32, 16 or 8, but got ") << scalarWidth; } uint8_t scalarSignedness; readWord(istream, scalarSignedness); if (scalarSignedness != 0 && scalarSignedness != 1) { return StringError("Numerical value for scalar signedness must be either " "0 or 1, but got ") << scalarSignedness; } switch (scalarWidth) { case 64: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 32: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 16: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); case 8: return (scalarSignedness) ? unserializeScalarValue(istream) : unserializeScalarValue(istream); } assert(false && "Unhandled scalar type"); } template static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes, std::istream &istream) { readWords(istream, values_and_sizes.getElementPointer(0), values_and_sizes.getNumElements()); return istream; } std::ostream &serializeTensorData(const TensorData &values_and_sizes, std::ostream &ostream) { switch (values_and_sizes.getElementType()) { case ElementType::u64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i64: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i32: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i16: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::u8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); case ElementType::i8: return serializeTensorDataRaw( values_and_sizes.getDimensions(), values_and_sizes.getElements(), ostream); } assert(false && "Unhandled element type"); } outcome::checked unserializeTensorData( const std::vector &expectedSizes, // includes lweSize, unsigned to // accomodate non static sizes std::istream &istream) { if (incorrectMode(istream)) { return StringError("Stream is in incorrect mode"); } uint64_t numDimensions; readWord(istream, numDimensions); std::vector dims; for (uint64_t i = 0; i < numDimensions; i++) { int64_t dimSize; readWord(istream, dimSize); if (dimSize != expectedSizes[i]) { istream.setstate(std::ios::badbit); return StringError("Number of dimensions did not match the number of " "expected dimensions"); } dims.push_back(dimSize); } uint64_t elementWidth; readWord(istream, elementWidth); switch (elementWidth) { case 64: case 32: case 16: case 8: break; default: return StringError("Element width must be either 64, 32, 16 or 8, but got ") << elementWidth; } uint8_t elementSignedness; readWord(istream, elementSignedness); if (elementSignedness != 0 && elementSignedness != 1) { return StringError("Numerical value for element signedness must be either " "0 or 1, but got ") << elementSignedness; } TensorData result(dims, elementWidth, elementSignedness == 1); switch (result.getElementType()) { case ElementType::u64: unserializeTensorDataElements(result, istream); break; case ElementType::i64: unserializeTensorDataElements(result, istream); break; case ElementType::u32: unserializeTensorDataElements(result, istream); break; case ElementType::i32: unserializeTensorDataElements(result, istream); break; case ElementType::u16: unserializeTensorDataElements(result, istream); break; case ElementType::i16: unserializeTensorDataElements(result, istream); break; case ElementType::u8: unserializeTensorDataElements(result, istream); break; case ElementType::i8: unserializeTensorDataElements(result, istream); break; } return std::move(result); } std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd, std::ostream &ostream) { writeWord(ostream, sotd.isTensor()); if (sotd.isTensor()) return serializeTensorData(sotd.getTensor(), ostream); else return serializeScalarData(sotd.getScalar(), ostream); } outcome::checked unserializeScalarOrTensorData(const std::vector &expectedSizes, std::istream &istream) { uint8_t isTensor; readWord(istream, isTensor); if (isTensor != 0 && isTensor != 1) { return StringError("Numerical value indicating whether a data element is a " "tensor must be either 0 or 1, but got ") << isTensor; } if (isTensor) { auto tdOrErr = unserializeTensorData(expectedSizes, istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } else { auto tdOrErr = unserializeScalarData(istream); if (tdOrErr.has_error()) return std::move(tdOrErr.error()); else return ScalarOrTensorData(std::move(tdOrErr.value())); } } std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &wrappedKsk) { ostream << wrappedKsk.ksk; assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) { istream >> wrappedKsk.ksk; assert(istream.good()); return istream; } std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &wrappedBsk) { ostream << wrappedBsk.bsk; assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) { istream >> wrappedBsk.bsk; assert(istream.good()); return istream; } std::ostream &operator<<(std::ostream &ostream, const EvaluationKeys &evaluationKeys) { bool has_ksk = (bool)evaluationKeys.sharedKsk; writeWord(ostream, has_ksk); if (has_ksk) { ostream << *evaluationKeys.sharedKsk; } bool has_bsk = (bool)evaluationKeys.sharedBsk; writeWord(ostream, has_bsk); if (has_bsk) { ostream << *evaluationKeys.sharedBsk; } assert(ostream.good()); return ostream; } std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys) { bool has_ksk; readWord(istream, has_ksk); if (has_ksk) { auto sharedKsk = LweKeyswitchKey(nullptr); istream >> sharedKsk; evaluationKeys.sharedKsk = std::make_shared(std::move(sharedKsk)); } bool has_bsk; readWord(istream, has_bsk); if (has_bsk) { auto sharedBsk = LweBootstrapKey(nullptr); istream >> sharedBsk; evaluationKeys.sharedBsk = std::make_shared(std::move(sharedBsk)); } assert(istream.good()); return istream; } } // namespace clientlib } // namespace concretelang