Merge pull request #6 from brave-experiments/fix-compilation

Fix compilation
This commit is contained in:
Sofía Celi
2023-09-29 17:07:46 +02:00
committed by GitHub
16 changed files with 13 additions and 1498 deletions

3
.gitmodules vendored
View File

@@ -0,0 +1,3 @@
[submodule "src/benchmark"]
path = src/benchmark
url = https://github.com/google/benchmark.git

View File

@@ -76,11 +76,8 @@ set(WARNINGS
-Wnon-virtual-dtor
-Wunused
# These are omitted just so that we can continue piecewise developing.
-Wno-error=unused-parameter
-Wno-error=zero-as-null-pointer-constant
-Wno-error=unused-variable
-Wno-error=unused-parameter
-Wno-error=return-type
-Wno-error=unused-function
-Woverloaded-virtual
-Wsign-conversion
-Wconversion
@@ -92,8 +89,7 @@ set(WARNINGS
-Wformat=2
-Wcast-qual
-Wmissing-declarations
-Wsign-promo
-Wno-error=disabled-optimization)
-Wsign-promo)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
@@ -187,9 +183,6 @@ set(SRC_FILES
ssl/Messaging.cpp
ssl/EmpWrapper.cpp
ssl/EmpWrapperAG2PC.cpp
ssl/EmpSSLSocketManager.cpp
ssl/ThreadSafeSSL.cpp
ssl/EmpThreadSocket.cpp
ssl/SSLBuffer.cpp
ssl/CounterType.cpp
ssl/NoBuffer.cpp
@@ -223,7 +216,7 @@ target_link_libraries(TLSAttestDebugAsan PRIVATE ${L_ASAN_FLAGS})
target_compile_options(TLSAttestOpt PRIVATE ${C_OPT_FLAGS})
target_link_libraries(TLSAttestOpt PRIVATE ${L_FLAGS})
set(SSLTARGETS Util EmpWrapper EmpWrapperAG2PC ThreePartyHandshake StatefulSocket TLSSocket Messaging SSL ThreadSafeSSL EmpSSLSocketManager EmpThreadSocket SSLBuffer NoBuffer CounterType)
set(SSLTARGETS Util EmpWrapper EmpWrapperAG2PC ThreePartyHandshake StatefulSocket TLSSocket Messaging SSL SSLBuffer NoBuffer CounterType)
set(NodeTARGETS Server KeyShare)
set(MTATARGETS MtA F2128MtA EmpBlockOwningSpan EmpBlockNonOwningSpan EmpBlockArray ectf PackArray)
set(2PCTARGETS CircuitSynthesis)

View File

@@ -1,61 +0,0 @@
#include "EmpSSLSocketManager.hpp"
#include "ThreadSafeSSL.hpp" // This should have everything we need in it.
// This is the file-local socket variable that we use for SSL communications
// with the client. This must be initialised via the appropriate create_ssl
// function and destroyed via the appropriate destroy_ssl function.
static ThreadSafeSSL *client_socket{nullptr};
static ThreadSafeSSL *server_socket{nullptr};
template <bool is_server> static void create_ssl(SSL *const ssl) noexcept {
assert(ssl);
if (is_server) {
assert(server_socket == nullptr);
server_socket = new ThreadSafeSSL(ssl);
} else {
assert(client_socket == nullptr);
client_socket = new ThreadSafeSSL(ssl);
}
}
template <bool is_server> static void destroy_ssl() noexcept {
if (is_server) {
delete server_socket;
server_socket = nullptr;
} else {
delete client_socket;
client_socket = nullptr;
}
}
ThreadSafeSSL *EmpSSLSocketManager::get_ssl_server() noexcept {
return server_socket;
}
ThreadSafeSSL *EmpSSLSocketManager::get_ssl_client() noexcept {
return client_socket;
}
void EmpSSLSocketManager::destroy_ssl_server() noexcept { destroy_ssl<true>(); }
void EmpSSLSocketManager::destroy_ssl_client() noexcept {
destroy_ssl<false>();
}
void EmpSSLSocketManager::create_ssl_client(SSL *const ssl) noexcept {
create_ssl<false>(ssl);
}
void EmpSSLSocketManager::create_ssl_server(SSL *const ssl) noexcept {
create_ssl<true>(ssl);
}
unsigned EmpSSLSocketManager::register_new_socket_server() noexcept {
assert(server_socket);
return server_socket->register_new_socket();
}
unsigned EmpSSLSocketManager::register_new_socket_client() noexcept {
assert(client_socket);
return client_socket->register_new_socket();
}

View File

@@ -1,112 +0,0 @@
#ifndef INCLUDED_EMPSSLSOCKETMANAGER_HPP
#define INCLUDED_EMPSSLSOCKETMANAGER_HPP
/**
EmpSSLSocketManagerManager.
\brief This component realises a shared SSL socket that can be shared
across multiple threads. This component exists here to allow separation from the
EmpThreadSocket in client mode and in server mode.
Essentially, this component
allows the caller to setup a global socket that can be used exclusively in
client mode or exclusively in server mode. Notably, though, these cannot overlap
for threading reasons. As a result, we separate these sockets out here. In
practical terms this probably matters very little, but for testing purposes this
is a big win.
As a usage principle this namespace operates similarly to most C-style functions
for creating new objects. To instantiate a new ThreadSafeSocket, call the
appropriate create_ssl_ function. Similarly, to destroy a ThreadSafeSocket, call
the appropriate destroy_ssl_ function. The exact function you'll use depends on
the use case: if you don't want to make this decision, use the appropriate type
defined in EmpThreadSocket.hpp.
**/
struct ssl_st; // Forward declaration for nicer compilation.
typedef struct ssl_st SSL; // N.B This has to be a typedef.
class ThreadSafeSSL; // Forward declaration for nicer compilation.
namespace EmpSSLSocketManager {
/**
create_ssl_client. This function creates a new global ThreadSafeSSL object
for use by clients. The underlying conneciton used is provided by the `ssl`
argument. This exists solely to allow all EmpThreadSockets to use the same SSL
connection. This function does not throw. Note that this function will
assert(false) if there is an already existing ThreadSafeSSL connection when this
function is called.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerCreateSSL
@param[in] ssl: the SSL connection to use. Must not be null.
**/
void create_ssl_client(SSL *const ssl) noexcept;
/**
create_ssl_server. This function creates a new global ThreadSafeSSL object for
use by servers. The underlying conneciton used is provided by the `ssl`
argument. This exists solely to allow all EmpThreadSockets to use the same SSL
connection. This function does not throw. Note that this function will
assert(false) if there is an already existing ThreadSafeSSL connection when this
function is called.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerCreateSSL
@param[in] ssl: the SSL connection to use. Must not be null.
**/
void create_ssl_server(SSL *const ssl) noexcept;
/**
get_ssl_client. This function returns a pointer to the underlying client
ThreadSafeSSL object. This function never returns a pointer that is null.
This function does not throw and ideally should only be used during testing.
@return a non-null pointer to the client's ThreadSafeSSL object.
**/
ThreadSafeSSL *get_ssl_client() noexcept;
/**
get_ssl_server. This function returns a pointer to the underlying server
ThreadSafeSSL object. This function never returns a pointer that is null.
This function does not throw and ideally should only be used during testing.
@return a non-null pointer to the server's ThreadSafeSSL object.
**/
ThreadSafeSSL *get_ssl_server() noexcept;
/**
destroy_ssl_server. This function destroys the underlying server
ThreadSafeSSL object. This function does not throw. Note that this function does
not free the underlying SSL connection.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerDestroySSL
@remarks This function nulls out the previous server ThreadSafeSSL object.
This is primarily for testing. Note that this should not be relied upon.
**/
void destroy_ssl_server() noexcept;
/**
destroy_ssl_client. This function destroys the underlying client
ThreadSafeSSL object. This function does not throw. Note that this function does
not free the underlying SSL connection.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerDestroySSL
@remarks This function nulls out the previous client ThreadSafeSSL object. This
is primarily for testing. Note that this should not be relied upon.
**/
void destroy_ssl_client() noexcept;
/**
register_new_socket_server. This function registers a new socket with the
ThreadSafeSSL object for the server. This function does not throw. Note that
this function asserts that the underlying ThreadSafeSSL object is set for
safety.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerRegisterNewSocket
@return the tag for the new server socket.
**/
unsigned register_new_socket_server() noexcept;
/**
register_new_socket_client. This function registers a new socket with the
ThreadSafeSSL object for the client. This function does not throw. Note that
this function asserts that the underlying ThreadSafeSSL object is set for
safety.
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerRegisterNewSocket
@return the tag for the new client socket.
**/
unsigned register_new_socket_client() noexcept;
} // namespace EmpSSLSocketManager
#endif

View File

@@ -1,45 +0,0 @@
#include "../doctest.h"
#include "EmpSSLSocketManager.hpp"
#include "TestUtil.hpp"
//! [EmpSSLSocketManagerDefaults]
TEST_CASE("defaults") {
// These should be true.
CHECK(EmpSSLSocketManager::get_ssl_client() == nullptr);
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
}
//! [EmpSSLSocketManagerDefaults]
//! [EmpSSLSocketManagerCreateSSL]
TEST_CASE("e2e") {
auto context = CreateContextWithTestCertificate(TLS_method());
REQUIRE(context);
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
REQUIRE(ssl);
//! [EmpSSLSocketManagerCreateSSL]
EmpSSLSocketManager::create_ssl_client(ssl.get());
CHECK(EmpSSLSocketManager::get_ssl_client() != nullptr);
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
EmpSSLSocketManager::create_ssl_server(ssl.get());
CHECK(EmpSSLSocketManager::get_ssl_server() != nullptr);
CHECK(EmpSSLSocketManager::get_ssl_client() !=
EmpSSLSocketManager::get_ssl_server());
//! [EmpSSLSocketManagerCreateSSL]
//! [EmpSSLSocketManagerRegisterNewSocket]
for (unsigned i = 0; i < 100; i++) {
CHECK(EmpSSLSocketManager::register_new_socket_client() == i);
CHECK(EmpSSLSocketManager::register_new_socket_server() == i);
}
//! [EmpSSLSocketManagerRegisterNewSocket]
//! [EmpSSLSocketManagerDestroySSL]
EmpSSLSocketManager::destroy_ssl_client();
CHECK(EmpSSLSocketManager::get_ssl_client() == nullptr);
CHECK(EmpSSLSocketManager::get_ssl_server() != nullptr);
EmpSSLSocketManager::destroy_ssl_server();
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
//! [EmpSSLSocketManagerDestroySSL]
}

View File

@@ -1 +0,0 @@
#include "EmpThreadSocket.hpp"

View File

@@ -1,230 +0,0 @@
#ifndef INCLUDED_EMPTHREADSOCKET_HPP
#define INCLUDED_EMPTHREADSOCKET_HPP
#include "../emp-tool/emp-tool/io/io_channel.h" // This contains the declarations for emp things.
#include "EmpSSLSocketManager.hpp" // Needed for multiplexed sends and receives.
#include "ThreadSafeSSL.hpp" // Needed for multiplexed sends and receives.
#include "openssl/base.h" // Needed for various declarations.
#include "ssl/internal.h" // Needed for various declarations.
#include "NoBuffer.hpp"
#include "SSLBuffer.hpp"
/**
EmpThreadSocket. This component realises a thread-safe wrapper around sockets
for EMP AG2PC (https://github.com/emp-toolkit/emp-ag2pc). This wrapper exists
because EMP AG2PC allows one to specify an arbitrary number of threads for
setup, but in a TLS context this may be unrealistically expensive (or, at best,
confusing). Equally, using a single socket and manually restricting to using a
single thread seems to cause some deadlocks in the I/O code, which implies there
is some sort of implicit assumption that threading is used.
To circumvent this issue, we use the following strategy to ensure that
multiple threads can access the same socket whilst also playing nicely with
existing EMP code:
1. The EmpThreadSocket class acts as a barrier class for EMP AG2PC.
2. Mechanically, this class contains two things: a reference to a thread-safe
TLS socket and a tag. The reference is simple: it is simply a thread safe way to
do I/O without modifying EMP.
The tag is less straightforward: conceptually, it is a ticket which specifies
when this particular EmpThreadSocket was created. Those of you who have shopped
in the British Shoe shop Clarks will recognise this system: it is a numbered tag
that specifies when an event happened, or the number of a shopper in the queue.
The reason for this is because we want to be able to "tag" our messages as they
go into the outside world to make sure they are delivered to the right thread on
the other side of the network.
The reference to the thread-safe TLS socket deals with all of the mechanics
of doing this: the important part is that each EmpThreadSocket has a way of
knowing which messages are due for it, and those which aren't.
@remarks The underlying assumption here is that the EmpThreadSockets are all
made from one thread (i.e the setup is done socket-by-socket). This is important
because otherwise the connection ordering will get broken.
@remarks For more details on e.g the template invocation here, please see
EmpWrapper.hpp
@remarks If you want to use this class, you _must_ call
EmpThreadSocket::set_ssl _first_. This sets up a global ThreadSafeSSL object
that handles all multiplexing. To protect against this, the constructor will
assert against this invariant. Similarly, once all EmpThreadSockets have
outserved their usefulness, you must also call the
EmpThreadSocket::free_ssl function. We assert against this when creating new
objects, so this will be fairly obvious in debug mode if this is not set.
@tparam is_server: true if this class is a server class, false otherwise.
@tparam BufferType: type of buffering scheme to use. See SSLBuffer.hpp for more.
**/
template <bool server, typename BufferType>
class EmpThreadSocket
: public emp::IOChannel<EmpThreadSocket<server, BufferType>> {
// These constructors are explicitly deleted to prevent them being called
// accidentally.
EmpThreadSocket() = delete;
EmpThreadSocket(EmpThreadSocket &) = delete;
EmpThreadSocket(EmpThreadSocket &&) = delete;
public:
/**
is_server. This variable indicates if this particular type of
EmpThreadSocket is a server socket or not. This variable is unused and only
exists to satisfy the requirements of EmpAG2PC.
**/
static constexpr bool is_server = server;
/**
addr. This variable is a placeholder for EMPAG2PC. This isn't actually
used for anything.
**/
inline static const std::string addr = "";
/**
port. This variable is a placeholder for EMPAG2PC. This isn't actually used
for anything.
**/
inline static const int port = 0;
/**
EmpThreadSocket. This is the default constructor for this class.
This constructor exists solely to satisfy the API requirements of
emp::IOChannel, and as a result none of the arguments do anything. This
function never throws. This function sets up the tag of this socket.
@snippet EmpThreadSocket.t.cpp EmpThreadSocketConstruct
**/
inline EmpThreadSocket(const char *const, const int, bool) noexcept;
/**
send_data_internal. This is a wrapper function that is called by EMP's
IOChannel type. This function accepts a valid void* pointer, `data`, some
amount of data `nbyte` and sends `data` over the underlying SSL channel. This
function does not throw.
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSendDataInternal.
@tparam T: the type of nbyte.
@param[in] data: the data to be sent. This must be a non-null pointer.
@param[in] nbyte: the number of bytes to send. This value should be
non-zero.
@remarks Note that this function does not report any errors. This is
because IOChannel's parent function does not allow the reporting of any
errors.
**/
template <typename T>
inline void send_data_internal(const void *const data,
const T nbyte) noexcept;
/**
recv_data_internal. This is a wrapper function that is called by
EMP's IOChannel type. This function accepts a `void *` pointer, `data`,
some amount of data `nbyte` and reads `nbyte`s into `data` using the
underlying SSL object. This function does not throw any exceptions.
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSendDataInternalTests.
@tparam T: the type of nbyte.
@param[in] data: the buffer to store the read data. This pointer must be
non-null.
@param[in] nbyte: the number of bytes to read. This value should be non-zero,
but we do not enforce this, instead relying on BoringSSL to do this for us.
@remarks Note that this function does not report any errors. This is
because (unlike BoringSSL) the IOChannel's parent function does not allow the
reporting of any errors. Instead, it simply prints to stderr. This seems hard
to use in a browser setting, so we ignore it here.
**/
template <typename T>
inline void recv_data_internal(void *const data, const T nbyte) noexcept;
/**
flush. This function is a wrapper function that is called by EMP's
IOChannel type. This function flushes to the output if a buffering scheme is
used and does nothing otherwise.
This function does not throw. In case where the socket does not support
buffering, this function also does not modify `this` object.
@remarks This function is only enabled if the BufferType does not support
buffering. The template parameter `T` is here to allow SFINAE to work (C++ has
some complicated rules around SFINAE and dependent templates).
**/
template <typename T = void>
inline std::enable_if_t<!BufferType::can_buffer(), T> flush() const noexcept;
/**
flush. This function is a wrapper function that is called by EMP's
IOChannel type. This function flushes to the output if a buffering scheme is
used and does nothing otherwise.
This function does not throw. In case where the socket does not support
buffering, this function also does not modify `this` object.
@remarks This function is only enabled if the BufferType supports buffering.
The template parameter `T` is here to allow SFINAE to work (C++ has some
complicated rules around SFINAE and dependent templates).
**/
template <typename T = void>
inline std::enable_if_t<BufferType::can_buffer(), T> flush() noexcept;
/**
set_ssl. This static function sets the global ThreadSafeSSL object to the
`ssl` argument. This exists solely to allow all EmpThreadSockets to use the
same SSL connection. This function does not throw.
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSetSSL
@param[in] ssl: the ssl connection to use. Must not be null.
**/
static inline void set_ssl(SSL *const ssl) noexcept;
/**
destroy_ssl. This static function destroys the global ThreadSafeSSL object.
@snippet EmpThreadSocket.t.cpp EmpThreadSocketDestroySSL
**/
static inline void destroy_ssl() noexcept;
/**
get_socket. This returns a copy of the socket pointer that this class uses.
This exists solely to allow us to pass static variables from the .cpp file to
the .inl file. This function does not throw. This should only be used during
testing.
@return a copy of the socket object.
**/
static inline ThreadSafeSSL *get_socket() noexcept;
/**
get_tag. This function returns a copy of the tag assigned to this socket.
This function does not throw or modify this object.
@return a copy of the tag of this object.
**/
inline unsigned get_tag() const noexcept;
private:
/**
tag. This tag identifies the thread of this particular socket. This is used
to make sure that messages reach the right thread.
**/
unsigned tag;
/**
buffer. This is the internal write buffer. This buffer is used to store
outgoing writes to mitigate delays and reduce the number of system calls.
**/
BufferType buffer;
};
// Inline definitions live here.
#include "EmpThreadSocket.inl"
/**
EmpClientSocket. This class is a thread safe SSL socket that is to be used
by clients. For more information, see EmpThreadSocket in EmpThreadSocket.hpp
**/
using EmpClientSocket = EmpThreadSocket<false, SSLBuffer>;
/**
EmpServerSocket. This class is a thread safe SSL socket that is to be used
by server. For more information, see EmpThreadSocket in EmpThreadSocket.hpp
**/
using EmpServerSocket = EmpThreadSocket<true, SSLBuffer>;
// This template class simply returns the type of EmpThreadSocket that should
// be used. This is primarily for neatness.
template <bool type> struct EmpSocketDispatch;
template <> struct EmpSocketDispatch<false> { using type = EmpClientSocket; };
template <> struct EmpSocketDispatch<true> { using type = EmpServerSocket; };
#endif

View File

@@ -1,117 +0,0 @@
#ifndef INCLUDED_EMPTHREADSOCKET_HPP
#error Do not include EmpThreadSocket.inl without EmpThreadSocket.hpp
#endif
template <bool is_server, typename BufferType>
EmpThreadSocket<is_server, BufferType>::EmpThreadSocket(const char *const,
const int,
bool) noexcept
// N.B the EmpSSLSocketManager does the assertion on if the underlying
// object exists.
: tag{(is_server) ? EmpSSLSocketManager::register_new_socket_server()
: EmpSSLSocketManager::register_new_socket_client()},
buffer{} {}
template <bool is_server, typename BufferType>
template <typename T>
void EmpThreadSocket<is_server, BufferType>::send_data_internal(
const void *const data, const T nbyte) noexcept {
// We only accept integral sizes here.
static_assert(
std::is_integral_v<T>,
"Error: send_data_internal can only be instantiated with integral types");
// For casting purposes we only use a fixed-type here.
using SizeType = ThreadSafeSSL::SizeType;
// It turns out that emp can actually call this with a zero size.
// If that happens, just quit
if (nbyte == 0) {
return;
}
// We also need to make sure we aren't sending too much.
assert(nbyte <= std::numeric_limits<SizeType>::max());
// If buffering is supported, then buffer.
if constexpr (BufferType::can_buffer()) {
buffer.buffer_data(data, static_cast<SizeType>(nbyte));
} else {
auto *socket = EmpThreadSocket::get_socket();
assert(socket);
socket->send(tag, data, static_cast<SizeType>(nbyte));
}
}
template <bool is_server, typename BufferType>
template <typename T>
void EmpThreadSocket<is_server, BufferType>::recv_data_internal(
void *const data, const T nbyte) noexcept {
// We only accept integral sizes here.
static_assert(
std::is_integral_v<T>,
"Error: send_data_internal can only be instantiated with integral types");
// For casting purposes we only use a fixed-type here.
using SizeType = ThreadSafeSSL::SizeType;
// It turns out that emp can actually call this with a zero size.
// If that happens, just quit.
if (nbyte == 0) {
return;
}
// We also need to make sure we aren't reading too much.
assert(nbyte <= std::numeric_limits<SizeType>::max());
auto *socket = EmpThreadSocket::get_socket();
assert(socket);
socket->recv(tag, data, static_cast<SizeType>(nbyte));
}
template <bool is_server, typename BufferType>
template <typename T>
inline typename std::enable_if_t<BufferType::can_buffer(), T>
EmpThreadSocket<is_server, BufferType>::flush() noexcept {
static_assert(std::is_same_v<T, void>,
"Error: can only instantiate can_buffer with T = void");
const auto size = buffer.size();
if (size == 0)
return;
auto *socket = EmpThreadSocket::get_socket();
assert(socket);
socket->send(tag, buffer.data(), size);
buffer.clear();
}
template <bool is_server, typename BufferType>
template <typename T>
inline typename std::enable_if_t<!BufferType::can_buffer(), T>
EmpThreadSocket<is_server, BufferType>::flush() const noexcept {
static_assert(std::is_same_v<T, void>,
"Error: can only instantiate can_buffer with T = void");
}
template <bool is_server, typename BufferType>
inline void
EmpThreadSocket<is_server, BufferType>::set_ssl(SSL *const ssl) noexcept {
(is_server) ? EmpSSLSocketManager::create_ssl_server(ssl)
: EmpSSLSocketManager::create_ssl_client(ssl);
}
template <bool is_server, typename BufferType>
inline void EmpThreadSocket<is_server, BufferType>::destroy_ssl() noexcept {
(is_server) ? EmpSSLSocketManager::destroy_ssl_server()
: EmpSSLSocketManager::destroy_ssl_client();
}
template <bool is_server, typename BufferType>
inline ThreadSafeSSL *
EmpThreadSocket<is_server, BufferType>::get_socket() noexcept {
return (is_server) ? EmpSSLSocketManager::get_ssl_server()
: EmpSSLSocketManager::get_ssl_client();
}
template <bool is_server, typename BufferType>
inline unsigned
EmpThreadSocket<is_server, BufferType>::get_tag() const noexcept {
return tag;
}

View File

@@ -1,117 +0,0 @@
#include "../doctest.h"
#include "EmpThreadSocket.hpp"
#include "TLSSocket.hpp"
#define SOCKET_SETUP
#include "TestUtil.hpp"
#include <thread>
// Just for testing across multiple types.
template <typename T> struct EmpThreadSocketTestClass { using type = T; };
//! [EmpThreadSocketSetSSL]
TEST_CASE_TEMPLATE("set_ssl", socket_type,
EmpThreadSocketTestClass<EmpClientSocket>,
EmpThreadSocketTestClass<EmpServerSocket>) {
auto context = CreateContextWithTestCertificate(TLS_method());
REQUIRE(context);
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
REQUIRE(ssl);
using SocketType = typename socket_type::type;
CHECK(SocketType::get_socket() == nullptr);
SocketType::set_ssl(ssl.get());
CHECK(SocketType::get_socket() != nullptr);
CHECK(SocketType::get_socket()->get_ssl() == ssl.get());
//! [EmpThreadSocketDestroySSL]
CHECK(SocketType::get_socket() != nullptr);
SocketType::destroy_ssl();
CHECK(SocketType::get_socket() == nullptr);
//! [EmpThreadSocketDestroySSL]
}
//! [EmpThreadSocketSetSSL]
//! [EmpThreadSocketConstruct]
TEST_CASE_TEMPLATE("construct", socket_type,
EmpThreadSocketTestClass<EmpClientSocket>,
EmpThreadSocketTestClass<EmpServerSocket>) {
auto context = CreateContextWithTestCertificate(TLS_method());
REQUIRE(context);
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
REQUIRE(ssl);
using SocketType = typename socket_type::type;
SocketType::set_ssl(ssl.get());
SUBCASE("Just one") {
SocketType sock("", 0, true); // The arguments here don't matter.
CHECK(sock.get_socket()->get_ssl() == ssl.get());
CHECK(sock.get_tag() == 0);
}
SUBCASE("Many") {
for (unsigned i = 0; i < 100; i++) {
SocketType sock("", 0, true); // The arguments here don't matter.
CHECK(sock.get_socket()->get_ssl() == ssl.get());
CHECK(sock.get_tag() == i);
}
}
SocketType::destroy_ssl();
}
//! [EmpThreadSocketConstruct]
// This test isn't templated because it isn't needed.
TEST_CASE("send_recv") {
auto context = CreateContextWithTestCertificate(TLS_method());
REQUIRE(context);
std::unique_ptr<TLSSocket> server, client;
// Setup the connections.
REQUIRE(setup_sockets(context, server, client));
// Now we'll practice sending from the client to the server.
// We'll do this using an existing connection from a ThreadSafeSSL object.
EmpSSLSocketManager::create_ssl_server(server->get_ssl_object());
EmpSSLSocketManager::create_ssl_client(client->get_ssl_object());
// Make some data to send. This is fixed data, but there's no reason for that.
std::array<uint8_t, 100> data;
std::iota(data.begin(), data.end(), 0);
EmpClientSocket client_socket("", 0,
true); // The arguments here don't matter.
EmpServerSocket server_socket("", 0,
true); // The arguments here don't matter.
auto client_code = [&]() {
//! [EmpThreadSocketSendDataInternal]
client_socket.send_data_internal(data.data(), data.size());
client_socket.flush();
//! [EmpThreadSocketSendDataInternal]
//! [EmpThreadSocketRecvDataInternal]
std::array<uint8_t, 100> in;
client_socket.recv_data_internal(in.data(), in.size());
CHECK(in == data);
//! [EmpThreadSocketRecvDataInternal]
};
auto server_code = [&]() {
std::array<uint8_t, 100> in;
server_socket.recv_data_internal(in.data(), in.size());
CHECK(in == data);
server_socket.send_data_internal(in.data(), in.size());
server_socket.flush();
};
std::thread server_thread(server_code);
client_code();
server_thread.join();
// Have to destroy these at the end.
EmpSSLSocketManager::destroy_ssl_server();
EmpSSLSocketManager::destroy_ssl_client();
}

View File

@@ -1,6 +1,5 @@
#include "../doctest.h"
#include "EmpWrapperAG2PC.hpp"
#include "SSLScope.hpp"
#include "TLSSocket.hpp"
#define SOCKET_SETUP
#include "../mta/F2128MtA.hpp"
@@ -305,8 +304,6 @@ TEST_CASE("aes_gcm_derivation") {
REQUIRE(context);
std::unique_ptr<TLSSocket> client, server;
REQUIRE(setup_sockets(context, server, client));
SSLScope<true> server_scope(server->get_ssl_object());
SSLScope<false> client_scope(client->get_ssl_object());
std::unique_ptr<EmpWrapperAG2PC> client_circ;
std::unique_ptr<EmpWrapperAG2PC> server_circ;
@@ -450,8 +447,6 @@ TEST_CASE_TEMPLATE(
REQUIRE(context);
std::unique_ptr<TLSSocket> client, server;
REQUIRE(setup_sockets(context, server, client));
SSLScope<true> server_scope(server->get_ssl_object());
SSLScope<false> client_scope(client->get_ssl_object());
std::unique_ptr<EmpWrapperAG2PC> client_circ;
std::unique_ptr<EmpWrapperAG2PC> server_circ;
@@ -693,8 +688,6 @@ TEST_CASE("derive_combined_traffic") {
REQUIRE(context);
std::unique_ptr<TLSSocket> client, server;
REQUIRE(setup_sockets(context, server, client));
SSLScope<true> server_scope(server->get_ssl_object());
SSLScope<false> client_scope(client->get_ssl_object());
std::unique_ptr<EmpWrapperAG2PC> client_circ;
std::unique_ptr<EmpWrapperAG2PC> server_circ;

View File

@@ -1,7 +1,8 @@
#ifndef INCLUDED_SSLBUFFER_HPP
#define INCLUDED_SSLBUFFER_HPP
#include "ThreadSafeSSL.hpp"
#include "openssl/base.h"
#include "ssl/internal.h"
#include <cassert>
#include <iostream>
#include <vector>
@@ -27,10 +28,9 @@ template <size_t flush_size = SSL3_RT_MAX_PLAIN_LENGTH, bool hold_flush = true>
class SSLBufferPolicy {
public:
/**
SizeType. We use the same SizeType type as ThreadSafeSSL for compatibility
purposes.
SizeType. We forward declare a size type for compatibility's sake.
**/
using SizeType = ThreadSafeSSL::SizeType;
using SizeType = unsigned;
/**
has_data. This function always returns true. It should be used to indicate

View File

@@ -1,24 +0,0 @@
#ifndef INCLUDED_SSLSCOPE_HPP
#define INCLUDED_SSLSCOPE_HPP
#include "EmpThreadSocket.hpp"
/**
SSLScope. This struct exists solely to provide a RAII wrapper for socket
management inside this project. In essence, because the tagged sockets
(EmpThreadSocket) all rely upon global variables, it's easier to tie the
lifetime of those sockets to the lifetime of the global socket object. In our
case, we want to be able to tie the lifetime of those sockets to the Server
object that owns the underlying socket. This class just makes all of that
easier.
**/
template <bool is_server> class SSLScope {
using SocketType = typename EmpSocketDispatch<is_server>::type;
public:
SSLScope(SSL *const ssl) noexcept { SocketType::set_ssl(ssl); }
~SSLScope() noexcept { SocketType::destroy_ssl(); }
};
#endif

View File

@@ -1,232 +0,0 @@
#include "ThreadSafeSSL.hpp"
#include "Util.hpp" // Needed for generic I/O routines.
#include <algorithm>
#include <iostream>
// Our datagrams in this file have the following two leading entries:
// 1. The tag where the message is due to be stored, and
// 2. The number of bytes in the message.
// We arbitrarily assume this fits into 64 bits. This should be true, as we
// don't expect either party to have more than 256 threads available, or for
// anyone to be sending messages larger than 2^56 bytes long.
struct Packed {
unsigned header : 8;
size_t size : 56;
};
// Check the packet header fits exactly into 64 bits.
static_assert(sizeof(Packed) == sizeof(uint64_t));
ThreadSafeSSL::ThreadSafeSSL(SSL *const ssl_in) noexcept
: count{}, socket_lock{}, ssl{ssl_in}, registered_in{},
registered_out{}, incoming{} {
assert(ssl);
// N.B there's a constexpr way to do this initialisation, but it's rather
// long. If this is too slow, we can fix it.
std::fill(registered_in.begin(), registered_in.end(),
ThreadSafeSSL::tombstone);
std::fill(registered_out.begin(), registered_out.end(),
ThreadSafeSSL::tombstone);
}
SSL *ThreadSafeSSL::get_ssl() noexcept { return ssl; }
unsigned ThreadSafeSSL::register_new_socket() noexcept {
assert(count + 1 < ThreadSafeSSL::max_size);
return count++;
}
unsigned ThreadSafeSSL::find_first_tombstone() const noexcept {
// Broken out for readability.
constexpr auto predicate = [](const uint8_t value) {
return value == ThreadSafeSSL::tombstone;
};
const auto it = std::find_if(std::cbegin(registered_out),
std::cbegin(registered_out) + count, predicate);
if (it == std::cbegin(registered_out) + count) {
return ThreadSafeSSL::tombstone;
}
// This cast is safe because we have an upper limit on the size of
// registered_in.
return static_cast<unsigned>(std::distance(std::cbegin(registered_out), it));
}
void ThreadSafeSSL::recv(const unsigned tag, void *const data,
const unsigned nbyte) noexcept {
// Precondition: data must be a valid pointer.
assert(data);
// Make sure the caller hasn't messed up.
assert(tag < ThreadSafeSSL::max_size);
assert(tag < count);
// First thing we do is lock the entire class.
std::lock_guard<std::mutex> lock(socket_lock);
// Now we have three situations:
// 1. We already have a message ready for us. Let's not read from the socket:
// we'll just read from our slot and return that.
if (incoming[tag].size() > 0) {
auto buffer = incoming[tag].data();
// Stop us from reading too much data.
assert(nbyte <= incoming[tag].size());
memcpy(data, buffer, nbyte);
incoming[tag].erase(incoming[tag].begin(), incoming[tag].begin() + nbyte);
return;
}
// The other cases are:
// 2. We've entered this function with a thread that has never received data
// before.
// 3. We've entered this function with a thread that has received data.
// before, but there's no messages in the buffer for us.
// In any case, we need to read the header to find out.
Packed header;
[[maybe_unused]] const auto header_bytes =
SSL_read(ssl, &header, sizeof(header));
// Only in debug.
assert(header_bytes == sizeof(header));
// We actually handle both cases in a single loop.
size_t amount_read_for_this = 0;
while (true) {
// This is just an abbreviation.
const auto dest_tag = registered_in[header.header];
// If the message is for us, we just read it already.
if (dest_tag == tag) {
// We might have more data to read here than we wanted. To stop us from
// over running the buffer if that's the case, we:
// 1. Read the first nbytes into the `data` buffer, and
// 2. Read the rest into the buffer for this tag. This will be read into
// the next time the tag calls this particular function.
if (nbyte >= header.size) {
Util::process_data(ssl,
static_cast<char *>(data) + amount_read_for_this,
header.size, SSL_read);
amount_read_for_this += header.size;
} else {
// If we've hit this case then there's nothing in _our_ buffer already,
// so we can just resize and copy from the beginning.
Util::process_data(ssl, static_cast<char *>(data), nbyte, SSL_read);
incoming[tag].resize(header.size - nbyte);
Util::process_data(ssl, incoming[tag].data(), header.size - nbyte,
SSL_read);
amount_read_for_this += nbyte;
}
} else if (dest_tag == ThreadSafeSSL::tombstone) {
// Case 2: a new thread has appeared.
// It's possible that we've read on a thread that has already been
// assigned to another thread (e.g on the other side) but we've received
// a new message anyway. We patch that here by finding the first thread on
// this node that hasn't yet been assigned to a thread on the other node.
const auto thread = (registered_out[tag] == ThreadSafeSSL::tombstone)
? tag
: find_first_tombstone();
// This will fire if there's no free threads on our side to deal with those
// messages.
assert(thread != ThreadSafeSSL::tombstone);
const auto thread_as_uint8 = static_cast<uint8_t>(thread);
// Set up the forwarding tables.
registered_out[thread] = header.header;
registered_in[header.header] = thread_as_uint8;
// Read into this tag's buffer if appropriate.
if (thread == tag) {
Util::process_data(ssl, static_cast<char *>(data),
static_cast<size_t>(header.size), SSL_read);
amount_read_for_this = static_cast<unsigned>(header.size);
} else {
incoming[thread].resize(header.size);
Util::process_data(ssl, incoming[thread].data(),
static_cast<size_t>(header.size), SSL_read);
}
// Now we have to tell the other thread who they are speaking to.
[[maybe_unused]] const auto written_bytes =
SSL_write(ssl, &thread_as_uint8, sizeof(thread_as_uint8));
assert(written_bytes > 0);
} else {
// Otherwise, we are going to append into the buffer that the message is
// due for. To do that, we have to reallocate enough memory to be able to
// read into the buffer for that particular tag.
const auto end_pos = incoming[dest_tag].size();
incoming[dest_tag].resize(end_pos + header.size);
Util::process_data(ssl, incoming[dest_tag].data() + end_pos,
static_cast<std::size_t>(nbyte), SSL_read);
}
// If we've read everything, then quit.
if (amount_read_for_this == nbyte) {
return;
}
// Otherwise, there must be data left, so quit.
[[maybe_unused]] const auto head_bytes =
SSL_read(ssl, &header, sizeof(header));
assert(head_bytes == sizeof(header));
// We'll deal with everything else on the next iteration.
}
}
void ThreadSafeSSL::send(const unsigned tag, const void *const data,
const unsigned nbyte) noexcept {
// Precondition: cannot send null data.
assert(data);
// To prevent the caller from messing up.
assert(tag < ThreadSafeSSL::max_size);
assert(tag < count);
// Lock the class to prevent race conditions.
std::lock_guard<std::mutex> lock(socket_lock);
// This code mirrors the code in recv quite closely.
// Here, though, we only have two cases:
// 1. We have a thread that is new.
// 2. We have a thread that is already established.
// Remarkably the code is mostly the same: the difference is if we
// need to wait for the other node to send a message back to us.
// In either case we need to use a header for writing.
// N.B This cast is fine as tag can fit into a uint8_t.
Packed header{static_cast<uint8_t>(tag), nbyte};
[[maybe_unused]] const auto bytes_written =
SSL_write(ssl, &header, sizeof(header));
assert(bytes_written > 0);
// Now we just write everything else out too.
Util::process_data(ssl, static_cast<const char *>(data),
static_cast<size_t>(nbyte), SSL_write);
// Now we need to check if this was the first contact with another thread or
// not.
if (registered_out[tag] == ThreadSafeSSL::tombstone) {
// We just read the tag back from the other side.
// N.B This assumes that both parties have the same endianness!
uint8_t other_tag;
[[maybe_unused]] const auto header_bytes =
SSL_read(ssl, &other_tag, sizeof(other_tag));
assert(header_bytes == sizeof(other_tag));
// Now we set up the tables.
registered_out[tag] = other_tag;
// This cast is safe because we statically enforce a maximum size on
// `tag`.
registered_in[other_tag] = static_cast<uint8_t>(tag);
}
}
bool ThreadSafeSSL::is_registered_in(const unsigned tag) const noexcept {
assert(tag < ThreadSafeSSL::max_size);
assert(tag < count);
return registered_in[tag] != ThreadSafeSSL::tombstone;
}
bool ThreadSafeSSL::is_registered_out(const unsigned tag) const noexcept {
assert(tag < ThreadSafeSSL::max_size);
assert(tag < count);
return registered_in[tag] != ThreadSafeSSL::tombstone;
}

View File

@@ -1,204 +0,0 @@
#ifndef INCLUDED_THREADSAFESSL_HPP
#define INCLUDED_THREADSAFESSL_HPP
#include "openssl/base.h"
#include "ssl/internal.h"
#include <array>
#include <atomic>
#include <mutex>
#include <vector>
/**
ThreadSafeSSL. This component realises a thread-safe SSL connection.
This component exists solely to allow the re-use of an SSL connection inside
certain parts of EMP.
Details: this socket is essentially a multiplexing socket.
As an analogy, you can consider two buildings full of people.
Normally the fastest way for these people to communicate would be to directly
phone or email each other, but in absence of that we can use a postal
service.
Assuming we only know the address, how can we make sure that the messages get
to the right people?
One simplifying assumption is that all of the people in both buildings are
capable of the same work. As a result, it just matters that the same two people
always speak in a pair, rather than to other people.
Let's label the buildings as A and B respectively. We'll assume that both
sides have the same number of people, labelled 0,..., n-1.
When person 0 in building A sends a message to building B, they will
tag their message with `0` before sending. Then, some person `i` in building
B will collect that message and note that they are dealing with person 0. Before
person `i` deals with the message, they will send a short note back to person
`0` that tells them they are dealing with person `i`. This is primarily for
consistency across both buildings.
Iterating this process, we can then build a multiplexed socket for multiple
threads.
@remarks As an implementation detail, this class sends all messages prepended
with a 64-bit tag. This tag is divided into 8-bits for the ID (i.e the person
number in the above example) and the rest is for the length. Whilst this is an
artificial restriction, in practice it seems reasonable: 2^56 bytes is 72
petabytes, which seems unrealistic. The 2^8 = 256 thread count also seems
reasonable, although for other reasons we actually restrict this to 255 threads.
@remarks This class assumes that both platforms have the same endianness:
there's no format independent encoding in this class. We can add this later if
needed.
**/
class ThreadSafeSSL {
public:
/**
SizeType. This is the type that is used as a size parameter in send and
recv. This declaration is here so that callers know what type to which they
should cast their size parameters.
**/
using SizeType = unsigned;
/**
register_new_socket. This function increments the number of sockets
associated with this SSL object and returns the tag for the caller. The value
returned is always the previous value of `count`: this means that the first
caller gets 0, the second caller gets 1 and so on. This function does not
throw.
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisterNewSocket
@return the tag associated with the calling socket.
**/
unsigned register_new_socket() noexcept;
/**
send. This function sends `nbyte` of `data` to the thread associated with
`tag`. This function does not throw. Note that similarly to other socket
classes that interface with EMP, this function does not return any error
information.
@param[in] tag: the tag associated with the calling socket. Must be less
than ThreadSafeSSL::max_size.
@param[in] data: the data to be sent. Must not be null.
@param[in] nbyte: the number of bytes to send.
**/
void send(const unsigned tag, const void *const data,
const SizeType nbyte) noexcept;
/**
recv. This function receives at most `nbyte` of data meant for the thread
associated with `tag`. This function does not throw. Note that similarly to
other socket classes that interface with EMP, this function does not return
any error information.
@param[in] tag: the tag associated with the calling socket. Must be less than
ThreadSafeSSL::max_size.
@param[in] data: the location to store the incoming data. Must not be null. In
debug builds, we assert to this.
@param[in] nbyte: the number of bytes to read.
**/
void recv(const unsigned tag, void *const data,
const SizeType nbyte) noexcept;
/**
ThreadSafeSSL. This constructor builds the socket. This constructor does
not throw.
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLConstructor
@param[in] ssl_in: the input ssl connection to use. Must not be null.
**/
ThreadSafeSSL(SSL *const ssl_in) noexcept;
/**
get_ssl. Returns a copy of the SSL object associated with `this`
ThreadSafeSSL object. This function never returns a null pointer and never
throws.
@snippet ThreadSafeSocket.t.cpp ThreadSafeSSLConstructor
@returns a non-null copy of the `ssl` object.
**/
SSL *get_ssl() noexcept;
/**
is_registered_in. This function returns true if `tag` is registered as
incoming at this node. This means that the thread with `tag` as their
identifier has a thread on this node that they consistently communicate with.
This function does not throw.
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisteredInOut
@param[in] tag: the tag that we are looking up. Must be less than
ThreadSafeSSL::max_size.
@return true if tag has a corresponding thread that it communicates with,
false otherwise.
**/
bool is_registered_in(const unsigned tag) const noexcept;
/**
is_registered_out. This function returns true if `tag` is registered as
outgoing at this node. This means that the thread with `tag` as their
identifier has a thread on the node at the other end of the channel that they
consistently communicate with.
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisteredInOut
@param[in] tag: the tag that we are looking up. Must be less than
ThreadSafeSSL::max_size.
@return true if tag has a corresponding thread that it communicates with,
false otherwise.
**/
bool is_registered_out(const unsigned tag) const noexcept;
private:
/**
max_size. This is the maximum number of threads we support in this class.
This is actually set here to make certain declarations nicer to write.
**/
static constexpr uint8_t max_size = 254;
/**
tombstone. This is the value that is set in the registered_in and
registered_out arrays to denote that the thread has not yet been seen or
registered. This is by default max_size + 1.
**/
static constexpr uint8_t tombstone = max_size + 1;
/**
count. This is the counter of the number of sockets that have been
registered with `this` socket. Note that despite the unsigned nature of this
counter the actual size is at most 255, which we assert to in the
register_new_socket function in release modes.
**/
std::atomic<unsigned> count;
/**
socket_lock. This is the lock for this class. All operations that involve
any variables of this class use this lock.
**/
std::mutex socket_lock;
/**
ssl. This is the ssl connection to use. Note that this is never null
after this socket has been connected.
**/
SSL *ssl;
/**
registered_in. This array contains the IDs of all threads that have been
registered _in_ at this socket. If an entry here is not set to `tombstone`,
then it means that it has been registered with a thread. In more detail, if
registered_in[i] is not `tombstone`, then the value of registered_in[i] is the
thread on this node that thread `i` (on another node) communicates with.
**/
std::array<uint8_t, max_size> registered_in;
/**
registered_out. This array contains the IDs of all threads that have been
registered as sending a message from this socket. If registered_out[i] !=
tombstone, then thread `i` here is communicating with the thread with ID
registered_out[i] on another node.
**/
std::array<uint8_t, max_size> registered_out;
/**
incoming. This vector of vector contains temporary storage for each
thread. In particular, incoming[i] contains any read messages for thread
`i`. This may be populated by other threads to prevent locking.
**/
std::array<std::vector<char>, max_size> incoming;
unsigned find_first_tombstone() const noexcept;
};
#endif

View File

@@ -1,332 +0,0 @@
#include "../doctest.h"
#include "TLSSocket.hpp"
#include "ThreadSafeSSL.hpp"
#include <numeric>
#define SOCKET_SETUP
#include "TestUtil.hpp"
#include <thread>
//! [ThreadSafeSSLConstructor]
TEST_CASE("constructor") {
auto context = CreateContextWithTestCertificate(TLS_method());
assert(context);
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
ThreadSafeSSL tssl(ssl.get());
CHECK(tssl.get_ssl() == ssl.get());
}
//! [ThreadSafeSSLConstructor]
//! [ThreadSafeSSLRegisterNewSocket]
TEST_CASE("constructor") {
auto context = CreateContextWithTestCertificate(TLS_method());
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
ThreadSafeSSL tssl(ssl.get());
SUBCASE("linear") {
CHECK(tssl.register_new_socket() == 0);
CHECK(tssl.register_new_socket() == 1);
CHECK(tssl.register_new_socket() == 2);
}
SUBCASE("loop") {
for (unsigned i = 0; i < 100; i++) {
CHECK(tssl.register_new_socket() == i);
}
}
}
//! [ThreadSafeSSLRegisterNewSocket]
//! [ThreadSafeSSLRegisteredIn]
TEST_CASE("registered_in") {
// This function checks that a registered thread is actually in the initial
// table.
auto context = CreateContextWithTestCertificate(TLS_method());
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
std::unique_ptr<TLSSocket> server, client;
REQUIRE(setup_sockets(context, server, client));
}
//! [ThreadSafeSSLRegisteredIn]
//! [ThreadSafeSSLSend]
TEST_CASE("send") {
// This just checks that the sending code works as we'd expect.
auto context = CreateContextWithTestCertificate(TLS_method());
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
std::unique_ptr<TLSSocket> server, client;
REQUIRE(setup_sockets(context, server, client));
// We'll use the server as the TSSL object.
ThreadSafeSSL tssl(server->get_ssl_object());
SUBCASE("random data, first time") {
const auto tag = tssl.register_new_socket();
std::array<uint8_t, 20> data;
for (unsigned i = 0; i < data.size(); i++) {
data[i] = static_cast<uint8_t>(rand());
}
auto send_code = [&]() { tssl.send(tag, data.data(), data.size()); };
CHECK(!tssl.is_registered_out(tag));
std::thread server_code(send_code);
// The message is split into two portions: the header first, then the
// message itself.
struct Packed {
unsigned header : 8;
size_t size : 56;
};
Packed header;
const auto nr_bytes =
SSL_read(client->get_ssl_object(), &header, sizeof(header));
CHECK(nr_bytes == sizeof(header));
CHECK((unsigned)header.header == tag);
CHECK((size_t)header.size == data.size());
// Now we expect the message to be of size 20 * sizeof(uint8_t).
uint8_t arr[data.size()];
const auto nr_bytes_mes =
SSL_read(client->get_ssl_object(), arr, sizeof(arr));
REQUIRE(nr_bytes_mes == sizeof(arr));
CHECK(memcmp(arr, data.data(), data.size()) == 0);
// Now to get the server to terminate we need to write a tag back.
// We'll just use 0.
const uint8_t temp_tag = 0;
CHECK(!tssl.is_registered_in(temp_tag));
const auto written =
SSL_write(client->get_ssl_object(), &temp_tag, sizeof(uint8_t));
REQUIRE(written == sizeof(uint8_t));
// We want to make sure the socket actually terminates.
server_code.join();
// Check that the tags are actually in the table.
//! [ThreadSafeSSLRegisteredInOut]
CHECK(tssl.is_registered_in(temp_tag));
CHECK(tssl.is_registered_out(tag));
//! [ThreadSafeSSLRegisteredInOut]
}
}
//! [ThreadSafeSSLSend]
//! [ThreadSafeSSLMultiSend]
TEST_CASE("multi_send") {
// This test case checks that both send and recv work together.
// We do this by first establishing a set of connections between two tssl
// objects and then checking that the reads work as expected.
auto context = CreateContextWithTestCertificate(TLS_method());
std::unique_ptr<TLSSocket> server, client;
REQUIRE(setup_sockets(context, server, client));
ThreadSafeSSL client_tssl(client->get_ssl_object());
ThreadSafeSSL server_tssl(server->get_ssl_object());
// Check that establishing the connection works.
auto server_code = [&]() {
for (unsigned i = 0; i < 10; i++) {
// This counts from 0, so it's essentially `i`.
server_tssl.register_new_socket();
// Here we are going for a 1:1 mapping. This is just for ease of testing.
CHECK(!server_tssl.is_registered_out(i));
CHECK(!server_tssl.is_registered_in(i));
server_tssl.send(i, &i, sizeof(i));
CHECK(server_tssl.is_registered_out(i));
CHECK(server_tssl.is_registered_in(i));
}
};
std::thread server_establish(server_code);
unsigned tmp_storage;
for (unsigned i = 0; i < 10; i++) {
client_tssl.register_new_socket();
// Here we are going for a 1:1 mapping. This is just for ease of testing.
CHECK(!client_tssl.is_registered_out(i));
CHECK(!client_tssl.is_registered_in(i));
client_tssl.recv(i, &tmp_storage, sizeof(tmp_storage));
CHECK(tmp_storage == i);
CHECK(client_tssl.is_registered_out(i));
CHECK(client_tssl.is_registered_in(i));
}
server_establish.join();
// Now we are going to send `i`, but in reverse. This is entirely to check
// that the buffering works properly.
SUBCASE("Sequential") {
auto server_send_backwards = [&]() {
for (int i = 9; i >= 0; i--) {
unsigned as_unsigned = static_cast<unsigned>(i);
server_tssl.send(as_unsigned, &as_unsigned, sizeof(as_unsigned));
}
};
std::thread server_backwards_thread(server_send_backwards);
for (unsigned i = 0; i < 10; i++) {
client_tssl.recv(i, &tmp_storage, sizeof(tmp_storage));
CHECK(tmp_storage == i);
}
server_backwards_thread.join();
}
SUBCASE("small_threaded") {
// Now we'll do it from many threads at once, to many threads at once.
// This is primarily to check that the buffering etc is actually thread
// safe.
std::array<std::thread, 10> client_threads;
std::array<std::thread, 10> server_threads;
auto client_send_code = [&](const unsigned i) {
unsigned t_tmp_storage;
client_tssl.recv(i, &t_tmp_storage, sizeof(t_tmp_storage));
CHECK(t_tmp_storage == i);
};
auto server_send_code = [&](const unsigned i) {
server_tssl.send(i, &i, sizeof(i));
};
for (unsigned i = 0; i < 10; i++) {
client_threads[i] = std::thread(client_send_code, i);
server_threads[i] = std::thread(server_send_code, 9 - i);
}
// N.B these joins are in separate threads to make sure that
// the data from server thread has been sent before we require
// any client threads to stop.
for (unsigned i = 0; i < 10; i++) {
server_threads[i].join();
}
for (unsigned i = 0; i < 10; i++) {
client_threads[i].join();
}
}
SUBCASE("large_threaded") {
// Now we'll do it from many threads at once, to many threads at once, but
// with lots of data. This is primarily to check whether we actually
// can handle large batches of data.
std::array<std::thread, 10> client_threads;
std::array<std::thread, 10> server_threads;
using BufType = std::array<char, 2 * SSL3_RT_MAX_PLAIN_LENGTH>;
auto client_send_code = [&](const unsigned i) {
BufType data;
std::iota(data.begin(), data.end(), i);
BufType in;
client_tssl.recv(i, in.data(), sizeof(in));
CHECK(in == data);
};
auto server_send_code = [&](const unsigned i) {
BufType data;
std::iota(data.begin(), data.end(), i);
server_tssl.send(i, data.data(), sizeof(data));
};
for (unsigned i = 0; i < 10; i++) {
client_threads[i] = std::thread(client_send_code, i);
server_threads[i] = std::thread(server_send_code, 9 - i);
}
// N.B these joins are in separate threads to make sure that
// the data from server thread has been sent before we require
// any client threads to stop.
for (unsigned i = 0; i < 10; i++) {
server_threads[i].join();
}
for (unsigned i = 0; i < 10; i++) {
client_threads[i].join();
}
}
}
//! [ThreadSafeSSLMultiSend]
//! [ThreadSafeSSLMismatchedSizeSends]
TEST_CASE("mismatched_size_send") {
// This test case is to make sure that if one thread sends more data than is
// expected to the receiver that everything still works.
auto context = CreateContextWithTestCertificate(TLS_method());
std::unique_ptr<TLSSocket> server, client;
REQUIRE(setup_sockets(context, server, client));
ThreadSafeSSL client_tssl(client->get_ssl_object());
ThreadSafeSSL server_tssl(server->get_ssl_object());
// We're going to do everything over tag 0. This means that the threads that
// are used for communicating both have a tag of 0. This is just for ease of
// testing: the behaviour we're simulating doesn't really require a particular
// thread to be used.
constexpr unsigned tag = 0;
auto server_code = [&]() {
// Just register a single new socket.
server_tssl.register_new_socket();
CHECK(!server_tssl.is_registered_out(tag));
CHECK(!server_tssl.is_registered_in(tag));
server_tssl.send(tag, &tag, sizeof(tag));
CHECK(server_tssl.is_registered_out(tag));
CHECK(server_tssl.is_registered_in(tag));
};
std::thread server_establish(server_code);
// Now we'll do the same setup for the client.
unsigned tmp_storage;
client_tssl.register_new_socket();
CHECK(!client_tssl.is_registered_out(tag));
CHECK(!client_tssl.is_registered_in(tag));
client_tssl.recv(tag, &tmp_storage, sizeof(tmp_storage));
CHECK(tmp_storage == tag);
CHECK(client_tssl.is_registered_out(tag));
CHECK(client_tssl.is_registered_in(tag));
server_establish.join();
// Now we want to do the actual testing: sending mismatched data over.
SUBCASE("small mismatch") {
constexpr std::array<unsigned, 2> test_arr{1, 2};
auto server_send = [&]() {
server_tssl.send(tag, test_arr.data(),
test_arr.size() * sizeof(unsigned));
};
std::thread server_send_thread(server_send);
unsigned recv{};
for (unsigned i = 0; i < test_arr.size(); i++) {
client_tssl.recv(tag, &recv, sizeof(unsigned));
CHECK(recv == test_arr[i]);
}
server_send_thread.join();
}
SUBCASE("large mismatch") {
std::vector<unsigned> vals(10000);
for (auto &val : vals) {
val = static_cast<unsigned>(rand());
}
auto server_send = [&]() {
server_tssl.send(tag, vals.data(),
static_cast<unsigned>(vals.size() * sizeof(unsigned)));
};
std::thread server_send_thread(server_send);
unsigned recv{};
for (unsigned i = 0; i < vals.size(); i++) {
client_tssl.recv(tag, &recv, sizeof(unsigned));
CHECK(recv == vals[i]);
}
server_send_thread.join();
}
}
//! [ThreadSafeSSLMismatchedSizeSends]

View File

@@ -796,9 +796,10 @@ static bool run_traffic_circuit_internal(
std::copy(output.cbegin() + offset,
output.cbegin() + offset + out.client_iv.size(),
out.client_iv.begin());
offset += out.client_iv.size() + out.server_key_share.size();
// These are also fine to cast.
offset += unsigned(out.client_iv.size() + out.server_key_share.size());
std::copy(output.cbegin() + offset, output.cend(), out.server_iv.begin());
offset += out.server_iv.size();
offset += unsigned(out.server_iv.size());
} else {
const auto copy_func = [&](auto &dest,
const unsigned inc = sizeof(decltype(dest))) {