Files
concrete/compiler/lib/ClientLib/Serializers.cpp
2022-11-16 11:44:46 +01:00

469 lines
14 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iosfwd>
#include <iostream>
#include <stdlib.h>
#include "concrete-core-ffi.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
template <typename Engine, typename Result>
Result read_deser(std::istream &istream,
int (*deser)(Engine *, BufferView, Result *),
Engine *engine) {
size_t length;
readSize(istream, length);
// buffer is too big to be allocated on stack
// vector ensures everything is deallocated w.r.t. new
std::vector<uint8_t> buffer(length);
istream.read((char *)buffer.data(), length);
assert(istream.good());
Result result;
CAPI_ASSERT_ERROR(deser(engine, {buffer.data(), length}, &result));
return result;
}
template <typename BufferLike>
std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) {
writeSize(ostream, buffer.length);
ostream.write((const char *)buffer.pointer, buffer.length);
assert(ostream.good());
return ostream;
}
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
&b));
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&b))
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream,
const FftFourierLweBootstrapKey64 *key) {
FftSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
fft_serialization_engine_serialize_fft_fourier_lwe_bootstrap_key_u64(
engine, key, &b))
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
key = read_deser(
istream, default_serialization_engine_deserialize_lwe_keyswitch_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
key = read_deser(
istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream,
FftFourierLweBootstrapKey64 *&key) {
FftSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
key = read_deser(
istream,
fft_serialization_engine_deserialize_fft_fourier_lwe_bootstrap_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream,
RuntimeContext &runtimeContext) {
istream >> runtimeContext.evaluationKeys;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext) {
ostream << runtimeContext.evaluationKeys;
assert(ostream.good());
return ostream;
}
template <typename T>
std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) {
writeWord<uint64_t>(ostream, sizeof(T) * 8);
writeWord<uint8_t>(ostream, std::is_signed<T>());
writeWord<T>(ostream, value);
return ostream;
}
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) {
switch (sd.getType()) {
case ElementType::u64:
return serializeScalarDataRaw<uint64_t>(sd.getValue<uint64_t>(), ostream);
case ElementType::i64:
return serializeScalarDataRaw<int64_t>(sd.getValue<int64_t>(), ostream);
case ElementType::u32:
return serializeScalarDataRaw<uint32_t>(sd.getValue<uint32_t>(), ostream);
case ElementType::i32:
return serializeScalarDataRaw<int32_t>(sd.getValue<int32_t>(), ostream);
case ElementType::u16:
return serializeScalarDataRaw<uint16_t>(sd.getValue<uint16_t>(), ostream);
case ElementType::i16:
return serializeScalarDataRaw<int16_t>(sd.getValue<int16_t>(), ostream);
case ElementType::u8:
return serializeScalarDataRaw<uint8_t>(sd.getValue<uint8_t>(), ostream);
case ElementType::i8:
return serializeScalarDataRaw<int8_t>(sd.getValue<int8_t>(), ostream);
}
return ostream;
}
template <typename T> ScalarData unserializeScalarValue(std::istream &istream) {
T value;
readWord(istream, value);
return ScalarData(value);
}
outcome::checked<ScalarData, StringError>
unserializeScalarData(std::istream &istream) {
uint64_t scalarWidth;
readWord(istream, scalarWidth);
switch (scalarWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Scalar width must be either 64, 32, 16 or 8, but got ")
<< scalarWidth;
}
uint8_t scalarSignedness;
readWord(istream, scalarSignedness);
if (scalarSignedness != 0 && scalarSignedness != 1) {
return StringError("Numerical value for scalar signedness must be either "
"0 or 1, but got ")
<< scalarSignedness;
}
switch (scalarWidth) {
case 64:
return (scalarSignedness) ? unserializeScalarValue<int64_t>(istream)
: unserializeScalarValue<uint64_t>(istream);
case 32:
return (scalarSignedness) ? unserializeScalarValue<int32_t>(istream)
: unserializeScalarValue<uint32_t>(istream);
case 16:
return (scalarSignedness) ? unserializeScalarValue<int16_t>(istream)
: unserializeScalarValue<uint16_t>(istream);
case 8:
return (scalarSignedness) ? unserializeScalarValue<int8_t>(istream)
: unserializeScalarValue<uint8_t>(istream);
}
assert(false && "Unhandled scalar type");
}
template <typename T>
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
std::istream &istream) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());
return istream;
}
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream) {
switch (values_and_sizes.getElementType()) {
case ElementType::u64:
return serializeTensorDataRaw<uint64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint64_t>(), ostream);
case ElementType::i64:
return serializeTensorDataRaw<int64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int64_t>(), ostream);
case ElementType::u32:
return serializeTensorDataRaw<uint32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint32_t>(), ostream);
case ElementType::i32:
return serializeTensorDataRaw<int32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int32_t>(), ostream);
case ElementType::u16:
return serializeTensorDataRaw<uint16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint16_t>(), ostream);
case ElementType::i16:
return serializeTensorDataRaw<int16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int16_t>(), ostream);
case ElementType::u8:
return serializeTensorDataRaw<uint8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint8_t>(), ostream);
case ElementType::i8:
return serializeTensorDataRaw<int8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int8_t>(), ostream);
}
assert(false && "Unhandled element type");
}
outcome::checked<TensorData, StringError> unserializeTensorData(
const std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
if (incorrectMode(istream)) {
return StringError("Stream is in incorrect mode");
}
uint64_t numDimensions;
readWord(istream, numDimensions);
std::vector<size_t> dims;
for (uint64_t i = 0; i < numDimensions; i++) {
int64_t dimSize;
readWord(istream, dimSize);
if (dimSize != expectedSizes[i]) {
istream.setstate(std::ios::badbit);
return StringError("Number of dimensions did not match the number of "
"expected dimensions");
}
dims.push_back(dimSize);
}
uint64_t elementWidth;
readWord(istream, elementWidth);
switch (elementWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
<< elementWidth;
}
uint8_t elementSignedness;
readWord(istream, elementSignedness);
if (elementSignedness != 0 && elementSignedness != 1) {
return StringError("Numerical value for element signedness must be either "
"0 or 1, but got ")
<< elementSignedness;
}
TensorData result(dims, elementWidth, elementSignedness == 1);
switch (result.getElementType()) {
case ElementType::u64:
unserializeTensorDataElements<uint64_t>(result, istream);
break;
case ElementType::i64:
unserializeTensorDataElements<int64_t>(result, istream);
break;
case ElementType::u32:
unserializeTensorDataElements<uint32_t>(result, istream);
break;
case ElementType::i32:
unserializeTensorDataElements<int32_t>(result, istream);
break;
case ElementType::u16:
unserializeTensorDataElements<uint16_t>(result, istream);
break;
case ElementType::i16:
unserializeTensorDataElements<int16_t>(result, istream);
break;
case ElementType::u8:
unserializeTensorDataElements<uint8_t>(result, istream);
break;
case ElementType::i8:
unserializeTensorDataElements<int8_t>(result, istream);
break;
}
return std::move(result);
}
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
std::ostream &ostream) {
writeWord<uint8_t>(ostream, sotd.isTensor());
if (sotd.isTensor())
return serializeTensorData(sotd.getTensor(), ostream);
else
return serializeScalarData(sotd.getScalar(), ostream);
}
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
std::istream &istream) {
uint8_t isTensor;
readWord(istream, isTensor);
if (isTensor != 0 && isTensor != 1) {
return StringError("Numerical value indicating whether a data element is a "
"tensor must be either 0 or 1, but got ")
<< isTensor;
}
if (isTensor) {
auto tdOrErr = unserializeTensorData(expectedSizes, istream);
if (tdOrErr.has_error())
return std::move(tdOrErr.error());
else
return ScalarOrTensorData(std::move(tdOrErr.value()));
} else {
auto tdOrErr = unserializeScalarData(istream);
if (tdOrErr.has_error())
return std::move(tdOrErr.error());
else
return ScalarOrTensorData(std::move(tdOrErr.value()));
}
}
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk) {
ostream << wrappedKsk.ksk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) {
istream >> wrappedKsk.ksk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk) {
ostream << wrappedBsk.bsk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
istream >> wrappedBsk.bsk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
bool has_ksk = (bool)evaluationKeys.sharedKsk;
writeWord(ostream, has_ksk);
if (has_ksk) {
ostream << *evaluationKeys.sharedKsk;
}
bool has_bsk = (bool)evaluationKeys.sharedBsk;
writeWord(ostream, has_bsk);
if (has_bsk) {
ostream << *evaluationKeys.sharedBsk;
}
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys) {
bool has_ksk;
readWord(istream, has_ksk);
if (has_ksk) {
auto sharedKsk = LweKeyswitchKey(nullptr);
istream >> sharedKsk;
evaluationKeys.sharedKsk =
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
}
bool has_bsk;
readWord(istream, has_bsk);
if (has_bsk) {
auto sharedBsk = LweBootstrapKey(nullptr);
istream >> sharedBsk;
evaluationKeys.sharedBsk =
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
}
assert(istream.good());
return istream;
}
} // namespace clientlib
} // namespace concretelang