mirror of
https://github.com/brave-experiments/DiStefano.git
synced 2026-01-09 12:17:54 -05:00
Merge pull request #6 from brave-experiments/fix-compilation
Fix compilation
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -0,0 +1,3 @@
|
||||
[submodule "src/benchmark"]
|
||||
path = src/benchmark
|
||||
url = https://github.com/google/benchmark.git
|
||||
|
||||
@@ -76,11 +76,8 @@ set(WARNINGS
|
||||
-Wnon-virtual-dtor
|
||||
-Wunused
|
||||
# These are omitted just so that we can continue piecewise developing.
|
||||
-Wno-error=unused-parameter
|
||||
-Wno-error=zero-as-null-pointer-constant
|
||||
-Wno-error=unused-variable
|
||||
-Wno-error=unused-parameter
|
||||
-Wno-error=return-type
|
||||
-Wno-error=unused-function
|
||||
-Woverloaded-virtual
|
||||
-Wsign-conversion
|
||||
-Wconversion
|
||||
@@ -92,8 +89,7 @@ set(WARNINGS
|
||||
-Wformat=2
|
||||
-Wcast-qual
|
||||
-Wmissing-declarations
|
||||
-Wsign-promo
|
||||
-Wno-error=disabled-optimization)
|
||||
-Wsign-promo)
|
||||
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
@@ -187,9 +183,6 @@ set(SRC_FILES
|
||||
ssl/Messaging.cpp
|
||||
ssl/EmpWrapper.cpp
|
||||
ssl/EmpWrapperAG2PC.cpp
|
||||
ssl/EmpSSLSocketManager.cpp
|
||||
ssl/ThreadSafeSSL.cpp
|
||||
ssl/EmpThreadSocket.cpp
|
||||
ssl/SSLBuffer.cpp
|
||||
ssl/CounterType.cpp
|
||||
ssl/NoBuffer.cpp
|
||||
@@ -223,7 +216,7 @@ target_link_libraries(TLSAttestDebugAsan PRIVATE ${L_ASAN_FLAGS})
|
||||
target_compile_options(TLSAttestOpt PRIVATE ${C_OPT_FLAGS})
|
||||
target_link_libraries(TLSAttestOpt PRIVATE ${L_FLAGS})
|
||||
|
||||
set(SSLTARGETS Util EmpWrapper EmpWrapperAG2PC ThreePartyHandshake StatefulSocket TLSSocket Messaging SSL ThreadSafeSSL EmpSSLSocketManager EmpThreadSocket SSLBuffer NoBuffer CounterType)
|
||||
set(SSLTARGETS Util EmpWrapper EmpWrapperAG2PC ThreePartyHandshake StatefulSocket TLSSocket Messaging SSL SSLBuffer NoBuffer CounterType)
|
||||
set(NodeTARGETS Server KeyShare)
|
||||
set(MTATARGETS MtA F2128MtA EmpBlockOwningSpan EmpBlockNonOwningSpan EmpBlockArray ectf PackArray)
|
||||
set(2PCTARGETS CircuitSynthesis)
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
#include "EmpSSLSocketManager.hpp"
|
||||
#include "ThreadSafeSSL.hpp" // This should have everything we need in it.
|
||||
|
||||
// This is the file-local socket variable that we use for SSL communications
|
||||
// with the client. This must be initialised via the appropriate create_ssl
|
||||
// function and destroyed via the appropriate destroy_ssl function.
|
||||
static ThreadSafeSSL *client_socket{nullptr};
|
||||
static ThreadSafeSSL *server_socket{nullptr};
|
||||
|
||||
template <bool is_server> static void create_ssl(SSL *const ssl) noexcept {
|
||||
assert(ssl);
|
||||
if (is_server) {
|
||||
assert(server_socket == nullptr);
|
||||
server_socket = new ThreadSafeSSL(ssl);
|
||||
} else {
|
||||
assert(client_socket == nullptr);
|
||||
client_socket = new ThreadSafeSSL(ssl);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool is_server> static void destroy_ssl() noexcept {
|
||||
if (is_server) {
|
||||
delete server_socket;
|
||||
server_socket = nullptr;
|
||||
} else {
|
||||
delete client_socket;
|
||||
client_socket = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ThreadSafeSSL *EmpSSLSocketManager::get_ssl_server() noexcept {
|
||||
return server_socket;
|
||||
}
|
||||
|
||||
ThreadSafeSSL *EmpSSLSocketManager::get_ssl_client() noexcept {
|
||||
return client_socket;
|
||||
}
|
||||
|
||||
void EmpSSLSocketManager::destroy_ssl_server() noexcept { destroy_ssl<true>(); }
|
||||
|
||||
void EmpSSLSocketManager::destroy_ssl_client() noexcept {
|
||||
destroy_ssl<false>();
|
||||
}
|
||||
|
||||
void EmpSSLSocketManager::create_ssl_client(SSL *const ssl) noexcept {
|
||||
create_ssl<false>(ssl);
|
||||
}
|
||||
|
||||
void EmpSSLSocketManager::create_ssl_server(SSL *const ssl) noexcept {
|
||||
create_ssl<true>(ssl);
|
||||
}
|
||||
|
||||
unsigned EmpSSLSocketManager::register_new_socket_server() noexcept {
|
||||
assert(server_socket);
|
||||
return server_socket->register_new_socket();
|
||||
}
|
||||
|
||||
unsigned EmpSSLSocketManager::register_new_socket_client() noexcept {
|
||||
assert(client_socket);
|
||||
return client_socket->register_new_socket();
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
#ifndef INCLUDED_EMPSSLSOCKETMANAGER_HPP
|
||||
#define INCLUDED_EMPSSLSOCKETMANAGER_HPP
|
||||
|
||||
/**
|
||||
EmpSSLSocketManagerManager.
|
||||
\brief This component realises a shared SSL socket that can be shared
|
||||
across multiple threads. This component exists here to allow separation from the
|
||||
EmpThreadSocket in client mode and in server mode.
|
||||
|
||||
Essentially, this component
|
||||
allows the caller to setup a global socket that can be used exclusively in
|
||||
client mode or exclusively in server mode. Notably, though, these cannot overlap
|
||||
for threading reasons. As a result, we separate these sockets out here. In
|
||||
practical terms this probably matters very little, but for testing purposes this
|
||||
is a big win.
|
||||
|
||||
As a usage principle this namespace operates similarly to most C-style functions
|
||||
for creating new objects. To instantiate a new ThreadSafeSocket, call the
|
||||
appropriate create_ssl_ function. Similarly, to destroy a ThreadSafeSocket, call
|
||||
the appropriate destroy_ssl_ function. The exact function you'll use depends on
|
||||
the use case: if you don't want to make this decision, use the appropriate type
|
||||
defined in EmpThreadSocket.hpp.
|
||||
**/
|
||||
|
||||
struct ssl_st; // Forward declaration for nicer compilation.
|
||||
typedef struct ssl_st SSL; // N.B This has to be a typedef.
|
||||
class ThreadSafeSSL; // Forward declaration for nicer compilation.
|
||||
|
||||
namespace EmpSSLSocketManager {
|
||||
|
||||
/**
|
||||
create_ssl_client. This function creates a new global ThreadSafeSSL object
|
||||
for use by clients. The underlying conneciton used is provided by the `ssl`
|
||||
argument. This exists solely to allow all EmpThreadSockets to use the same SSL
|
||||
connection. This function does not throw. Note that this function will
|
||||
assert(false) if there is an already existing ThreadSafeSSL connection when this
|
||||
function is called.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerCreateSSL
|
||||
@param[in] ssl: the SSL connection to use. Must not be null.
|
||||
**/
|
||||
void create_ssl_client(SSL *const ssl) noexcept;
|
||||
|
||||
/**
|
||||
create_ssl_server. This function creates a new global ThreadSafeSSL object for
|
||||
use by servers. The underlying conneciton used is provided by the `ssl`
|
||||
argument. This exists solely to allow all EmpThreadSockets to use the same SSL
|
||||
connection. This function does not throw. Note that this function will
|
||||
assert(false) if there is an already existing ThreadSafeSSL connection when this
|
||||
function is called.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerCreateSSL
|
||||
@param[in] ssl: the SSL connection to use. Must not be null.
|
||||
**/
|
||||
void create_ssl_server(SSL *const ssl) noexcept;
|
||||
|
||||
/**
|
||||
get_ssl_client. This function returns a pointer to the underlying client
|
||||
ThreadSafeSSL object. This function never returns a pointer that is null.
|
||||
This function does not throw and ideally should only be used during testing.
|
||||
@return a non-null pointer to the client's ThreadSafeSSL object.
|
||||
**/
|
||||
ThreadSafeSSL *get_ssl_client() noexcept;
|
||||
|
||||
/**
|
||||
get_ssl_server. This function returns a pointer to the underlying server
|
||||
ThreadSafeSSL object. This function never returns a pointer that is null.
|
||||
This function does not throw and ideally should only be used during testing.
|
||||
@return a non-null pointer to the server's ThreadSafeSSL object.
|
||||
**/
|
||||
ThreadSafeSSL *get_ssl_server() noexcept;
|
||||
|
||||
/**
|
||||
destroy_ssl_server. This function destroys the underlying server
|
||||
ThreadSafeSSL object. This function does not throw. Note that this function does
|
||||
not free the underlying SSL connection.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerDestroySSL
|
||||
@remarks This function nulls out the previous server ThreadSafeSSL object.
|
||||
This is primarily for testing. Note that this should not be relied upon.
|
||||
**/
|
||||
void destroy_ssl_server() noexcept;
|
||||
|
||||
/**
|
||||
destroy_ssl_client. This function destroys the underlying client
|
||||
ThreadSafeSSL object. This function does not throw. Note that this function does
|
||||
not free the underlying SSL connection.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerDestroySSL
|
||||
@remarks This function nulls out the previous client ThreadSafeSSL object. This
|
||||
is primarily for testing. Note that this should not be relied upon.
|
||||
**/
|
||||
void destroy_ssl_client() noexcept;
|
||||
|
||||
/**
|
||||
register_new_socket_server. This function registers a new socket with the
|
||||
ThreadSafeSSL object for the server. This function does not throw. Note that
|
||||
this function asserts that the underlying ThreadSafeSSL object is set for
|
||||
safety.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerRegisterNewSocket
|
||||
@return the tag for the new server socket.
|
||||
**/
|
||||
unsigned register_new_socket_server() noexcept;
|
||||
/**
|
||||
register_new_socket_client. This function registers a new socket with the
|
||||
ThreadSafeSSL object for the client. This function does not throw. Note that
|
||||
this function asserts that the underlying ThreadSafeSSL object is set for
|
||||
safety.
|
||||
@snippet EmpSSLSocketManager.t.cpp EmpSSLSocketManagerRegisterNewSocket
|
||||
@return the tag for the new client socket.
|
||||
**/
|
||||
unsigned register_new_socket_client() noexcept;
|
||||
|
||||
} // namespace EmpSSLSocketManager
|
||||
|
||||
#endif
|
||||
@@ -1,45 +0,0 @@
|
||||
#include "../doctest.h"
|
||||
#include "EmpSSLSocketManager.hpp"
|
||||
#include "TestUtil.hpp"
|
||||
|
||||
//! [EmpSSLSocketManagerDefaults]
|
||||
TEST_CASE("defaults") {
|
||||
// These should be true.
|
||||
CHECK(EmpSSLSocketManager::get_ssl_client() == nullptr);
|
||||
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
|
||||
}
|
||||
//! [EmpSSLSocketManagerDefaults]
|
||||
|
||||
//! [EmpSSLSocketManagerCreateSSL]
|
||||
TEST_CASE("e2e") {
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
REQUIRE(context);
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
REQUIRE(ssl);
|
||||
|
||||
//! [EmpSSLSocketManagerCreateSSL]
|
||||
EmpSSLSocketManager::create_ssl_client(ssl.get());
|
||||
CHECK(EmpSSLSocketManager::get_ssl_client() != nullptr);
|
||||
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
|
||||
|
||||
EmpSSLSocketManager::create_ssl_server(ssl.get());
|
||||
CHECK(EmpSSLSocketManager::get_ssl_server() != nullptr);
|
||||
CHECK(EmpSSLSocketManager::get_ssl_client() !=
|
||||
EmpSSLSocketManager::get_ssl_server());
|
||||
//! [EmpSSLSocketManagerCreateSSL]
|
||||
|
||||
//! [EmpSSLSocketManagerRegisterNewSocket]
|
||||
for (unsigned i = 0; i < 100; i++) {
|
||||
CHECK(EmpSSLSocketManager::register_new_socket_client() == i);
|
||||
CHECK(EmpSSLSocketManager::register_new_socket_server() == i);
|
||||
}
|
||||
//! [EmpSSLSocketManagerRegisterNewSocket]
|
||||
|
||||
//! [EmpSSLSocketManagerDestroySSL]
|
||||
EmpSSLSocketManager::destroy_ssl_client();
|
||||
CHECK(EmpSSLSocketManager::get_ssl_client() == nullptr);
|
||||
CHECK(EmpSSLSocketManager::get_ssl_server() != nullptr);
|
||||
EmpSSLSocketManager::destroy_ssl_server();
|
||||
CHECK(EmpSSLSocketManager::get_ssl_server() == nullptr);
|
||||
//! [EmpSSLSocketManagerDestroySSL]
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
#include "EmpThreadSocket.hpp"
|
||||
@@ -1,230 +0,0 @@
|
||||
#ifndef INCLUDED_EMPTHREADSOCKET_HPP
|
||||
#define INCLUDED_EMPTHREADSOCKET_HPP
|
||||
|
||||
#include "../emp-tool/emp-tool/io/io_channel.h" // This contains the declarations for emp things.
|
||||
#include "EmpSSLSocketManager.hpp" // Needed for multiplexed sends and receives.
|
||||
#include "ThreadSafeSSL.hpp" // Needed for multiplexed sends and receives.
|
||||
#include "openssl/base.h" // Needed for various declarations.
|
||||
#include "ssl/internal.h" // Needed for various declarations.
|
||||
|
||||
#include "NoBuffer.hpp"
|
||||
#include "SSLBuffer.hpp"
|
||||
|
||||
/**
|
||||
EmpThreadSocket. This component realises a thread-safe wrapper around sockets
|
||||
for EMP AG2PC (https://github.com/emp-toolkit/emp-ag2pc). This wrapper exists
|
||||
because EMP AG2PC allows one to specify an arbitrary number of threads for
|
||||
setup, but in a TLS context this may be unrealistically expensive (or, at best,
|
||||
confusing). Equally, using a single socket and manually restricting to using a
|
||||
single thread seems to cause some deadlocks in the I/O code, which implies there
|
||||
is some sort of implicit assumption that threading is used.
|
||||
|
||||
To circumvent this issue, we use the following strategy to ensure that
|
||||
multiple threads can access the same socket whilst also playing nicely with
|
||||
existing EMP code:
|
||||
|
||||
1. The EmpThreadSocket class acts as a barrier class for EMP AG2PC.
|
||||
2. Mechanically, this class contains two things: a reference to a thread-safe
|
||||
TLS socket and a tag. The reference is simple: it is simply a thread safe way to
|
||||
do I/O without modifying EMP.
|
||||
|
||||
The tag is less straightforward: conceptually, it is a ticket which specifies
|
||||
when this particular EmpThreadSocket was created. Those of you who have shopped
|
||||
in the British Shoe shop Clarks will recognise this system: it is a numbered tag
|
||||
that specifies when an event happened, or the number of a shopper in the queue.
|
||||
The reason for this is because we want to be able to "tag" our messages as they
|
||||
go into the outside world to make sure they are delivered to the right thread on
|
||||
the other side of the network.
|
||||
|
||||
The reference to the thread-safe TLS socket deals with all of the mechanics
|
||||
of doing this: the important part is that each EmpThreadSocket has a way of
|
||||
knowing which messages are due for it, and those which aren't.
|
||||
|
||||
@remarks The underlying assumption here is that the EmpThreadSockets are all
|
||||
made from one thread (i.e the setup is done socket-by-socket). This is important
|
||||
because otherwise the connection ordering will get broken.
|
||||
@remarks For more details on e.g the template invocation here, please see
|
||||
EmpWrapper.hpp
|
||||
@remarks If you want to use this class, you _must_ call
|
||||
EmpThreadSocket::set_ssl _first_. This sets up a global ThreadSafeSSL object
|
||||
that handles all multiplexing. To protect against this, the constructor will
|
||||
assert against this invariant. Similarly, once all EmpThreadSockets have
|
||||
outserved their usefulness, you must also call the
|
||||
EmpThreadSocket::free_ssl function. We assert against this when creating new
|
||||
objects, so this will be fairly obvious in debug mode if this is not set.
|
||||
@tparam is_server: true if this class is a server class, false otherwise.
|
||||
@tparam BufferType: type of buffering scheme to use. See SSLBuffer.hpp for more.
|
||||
**/
|
||||
template <bool server, typename BufferType>
|
||||
class EmpThreadSocket
|
||||
: public emp::IOChannel<EmpThreadSocket<server, BufferType>> {
|
||||
|
||||
// These constructors are explicitly deleted to prevent them being called
|
||||
// accidentally.
|
||||
EmpThreadSocket() = delete;
|
||||
EmpThreadSocket(EmpThreadSocket &) = delete;
|
||||
EmpThreadSocket(EmpThreadSocket &&) = delete;
|
||||
|
||||
public:
|
||||
/**
|
||||
is_server. This variable indicates if this particular type of
|
||||
EmpThreadSocket is a server socket or not. This variable is unused and only
|
||||
exists to satisfy the requirements of EmpAG2PC.
|
||||
**/
|
||||
static constexpr bool is_server = server;
|
||||
|
||||
/**
|
||||
addr. This variable is a placeholder for EMPAG2PC. This isn't actually
|
||||
used for anything.
|
||||
**/
|
||||
inline static const std::string addr = "";
|
||||
|
||||
/**
|
||||
port. This variable is a placeholder for EMPAG2PC. This isn't actually used
|
||||
for anything.
|
||||
**/
|
||||
inline static const int port = 0;
|
||||
|
||||
/**
|
||||
EmpThreadSocket. This is the default constructor for this class.
|
||||
This constructor exists solely to satisfy the API requirements of
|
||||
emp::IOChannel, and as a result none of the arguments do anything. This
|
||||
function never throws. This function sets up the tag of this socket.
|
||||
@snippet EmpThreadSocket.t.cpp EmpThreadSocketConstruct
|
||||
**/
|
||||
inline EmpThreadSocket(const char *const, const int, bool) noexcept;
|
||||
|
||||
/**
|
||||
send_data_internal. This is a wrapper function that is called by EMP's
|
||||
IOChannel type. This function accepts a valid void* pointer, `data`, some
|
||||
amount of data `nbyte` and sends `data` over the underlying SSL channel. This
|
||||
function does not throw.
|
||||
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSendDataInternal.
|
||||
@tparam T: the type of nbyte.
|
||||
@param[in] data: the data to be sent. This must be a non-null pointer.
|
||||
@param[in] nbyte: the number of bytes to send. This value should be
|
||||
non-zero.
|
||||
@remarks Note that this function does not report any errors. This is
|
||||
because IOChannel's parent function does not allow the reporting of any
|
||||
errors.
|
||||
**/
|
||||
template <typename T>
|
||||
inline void send_data_internal(const void *const data,
|
||||
const T nbyte) noexcept;
|
||||
|
||||
/**
|
||||
recv_data_internal. This is a wrapper function that is called by
|
||||
EMP's IOChannel type. This function accepts a `void *` pointer, `data`,
|
||||
some amount of data `nbyte` and reads `nbyte`s into `data` using the
|
||||
underlying SSL object. This function does not throw any exceptions.
|
||||
|
||||
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSendDataInternalTests.
|
||||
@tparam T: the type of nbyte.
|
||||
@param[in] data: the buffer to store the read data. This pointer must be
|
||||
non-null.
|
||||
@param[in] nbyte: the number of bytes to read. This value should be non-zero,
|
||||
but we do not enforce this, instead relying on BoringSSL to do this for us.
|
||||
|
||||
@remarks Note that this function does not report any errors. This is
|
||||
because (unlike BoringSSL) the IOChannel's parent function does not allow the
|
||||
reporting of any errors. Instead, it simply prints to stderr. This seems hard
|
||||
to use in a browser setting, so we ignore it here.
|
||||
**/
|
||||
template <typename T>
|
||||
inline void recv_data_internal(void *const data, const T nbyte) noexcept;
|
||||
|
||||
/**
|
||||
flush. This function is a wrapper function that is called by EMP's
|
||||
IOChannel type. This function flushes to the output if a buffering scheme is
|
||||
used and does nothing otherwise.
|
||||
|
||||
This function does not throw. In case where the socket does not support
|
||||
buffering, this function also does not modify `this` object.
|
||||
@remarks This function is only enabled if the BufferType does not support
|
||||
buffering. The template parameter `T` is here to allow SFINAE to work (C++ has
|
||||
some complicated rules around SFINAE and dependent templates).
|
||||
**/
|
||||
template <typename T = void>
|
||||
inline std::enable_if_t<!BufferType::can_buffer(), T> flush() const noexcept;
|
||||
|
||||
/**
|
||||
flush. This function is a wrapper function that is called by EMP's
|
||||
IOChannel type. This function flushes to the output if a buffering scheme is
|
||||
used and does nothing otherwise.
|
||||
|
||||
This function does not throw. In case where the socket does not support
|
||||
buffering, this function also does not modify `this` object.
|
||||
@remarks This function is only enabled if the BufferType supports buffering.
|
||||
The template parameter `T` is here to allow SFINAE to work (C++ has some
|
||||
complicated rules around SFINAE and dependent templates).
|
||||
**/
|
||||
template <typename T = void>
|
||||
inline std::enable_if_t<BufferType::can_buffer(), T> flush() noexcept;
|
||||
|
||||
/**
|
||||
set_ssl. This static function sets the global ThreadSafeSSL object to the
|
||||
`ssl` argument. This exists solely to allow all EmpThreadSockets to use the
|
||||
same SSL connection. This function does not throw.
|
||||
@snippet EmpThreadSocket.t.cpp EmpThreadSocketSetSSL
|
||||
@param[in] ssl: the ssl connection to use. Must not be null.
|
||||
|
||||
**/
|
||||
static inline void set_ssl(SSL *const ssl) noexcept;
|
||||
|
||||
/**
|
||||
destroy_ssl. This static function destroys the global ThreadSafeSSL object.
|
||||
@snippet EmpThreadSocket.t.cpp EmpThreadSocketDestroySSL
|
||||
**/
|
||||
static inline void destroy_ssl() noexcept;
|
||||
/**
|
||||
get_socket. This returns a copy of the socket pointer that this class uses.
|
||||
This exists solely to allow us to pass static variables from the .cpp file to
|
||||
the .inl file. This function does not throw. This should only be used during
|
||||
testing.
|
||||
@return a copy of the socket object.
|
||||
**/
|
||||
static inline ThreadSafeSSL *get_socket() noexcept;
|
||||
|
||||
/**
|
||||
get_tag. This function returns a copy of the tag assigned to this socket.
|
||||
This function does not throw or modify this object.
|
||||
@return a copy of the tag of this object.
|
||||
**/
|
||||
inline unsigned get_tag() const noexcept;
|
||||
|
||||
private:
|
||||
/**
|
||||
tag. This tag identifies the thread of this particular socket. This is used
|
||||
to make sure that messages reach the right thread.
|
||||
**/
|
||||
unsigned tag;
|
||||
|
||||
/**
|
||||
buffer. This is the internal write buffer. This buffer is used to store
|
||||
outgoing writes to mitigate delays and reduce the number of system calls.
|
||||
**/
|
||||
BufferType buffer;
|
||||
};
|
||||
|
||||
// Inline definitions live here.
|
||||
#include "EmpThreadSocket.inl"
|
||||
|
||||
/**
|
||||
EmpClientSocket. This class is a thread safe SSL socket that is to be used
|
||||
by clients. For more information, see EmpThreadSocket in EmpThreadSocket.hpp
|
||||
**/
|
||||
using EmpClientSocket = EmpThreadSocket<false, SSLBuffer>;
|
||||
|
||||
/**
|
||||
EmpServerSocket. This class is a thread safe SSL socket that is to be used
|
||||
by server. For more information, see EmpThreadSocket in EmpThreadSocket.hpp
|
||||
**/
|
||||
using EmpServerSocket = EmpThreadSocket<true, SSLBuffer>;
|
||||
|
||||
// This template class simply returns the type of EmpThreadSocket that should
|
||||
// be used. This is primarily for neatness.
|
||||
template <bool type> struct EmpSocketDispatch;
|
||||
template <> struct EmpSocketDispatch<false> { using type = EmpClientSocket; };
|
||||
template <> struct EmpSocketDispatch<true> { using type = EmpServerSocket; };
|
||||
|
||||
#endif
|
||||
@@ -1,117 +0,0 @@
|
||||
#ifndef INCLUDED_EMPTHREADSOCKET_HPP
|
||||
#error Do not include EmpThreadSocket.inl without EmpThreadSocket.hpp
|
||||
#endif
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
EmpThreadSocket<is_server, BufferType>::EmpThreadSocket(const char *const,
|
||||
const int,
|
||||
bool) noexcept
|
||||
// N.B the EmpSSLSocketManager does the assertion on if the underlying
|
||||
// object exists.
|
||||
: tag{(is_server) ? EmpSSLSocketManager::register_new_socket_server()
|
||||
: EmpSSLSocketManager::register_new_socket_client()},
|
||||
buffer{} {}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
template <typename T>
|
||||
void EmpThreadSocket<is_server, BufferType>::send_data_internal(
|
||||
const void *const data, const T nbyte) noexcept {
|
||||
// We only accept integral sizes here.
|
||||
static_assert(
|
||||
std::is_integral_v<T>,
|
||||
"Error: send_data_internal can only be instantiated with integral types");
|
||||
|
||||
// For casting purposes we only use a fixed-type here.
|
||||
using SizeType = ThreadSafeSSL::SizeType;
|
||||
|
||||
// It turns out that emp can actually call this with a zero size.
|
||||
// If that happens, just quit
|
||||
if (nbyte == 0) {
|
||||
return;
|
||||
}
|
||||
// We also need to make sure we aren't sending too much.
|
||||
assert(nbyte <= std::numeric_limits<SizeType>::max());
|
||||
|
||||
// If buffering is supported, then buffer.
|
||||
if constexpr (BufferType::can_buffer()) {
|
||||
buffer.buffer_data(data, static_cast<SizeType>(nbyte));
|
||||
} else {
|
||||
auto *socket = EmpThreadSocket::get_socket();
|
||||
assert(socket);
|
||||
socket->send(tag, data, static_cast<SizeType>(nbyte));
|
||||
}
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
template <typename T>
|
||||
void EmpThreadSocket<is_server, BufferType>::recv_data_internal(
|
||||
void *const data, const T nbyte) noexcept {
|
||||
|
||||
// We only accept integral sizes here.
|
||||
static_assert(
|
||||
std::is_integral_v<T>,
|
||||
"Error: send_data_internal can only be instantiated with integral types");
|
||||
|
||||
// For casting purposes we only use a fixed-type here.
|
||||
using SizeType = ThreadSafeSSL::SizeType;
|
||||
// It turns out that emp can actually call this with a zero size.
|
||||
// If that happens, just quit.
|
||||
if (nbyte == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We also need to make sure we aren't reading too much.
|
||||
assert(nbyte <= std::numeric_limits<SizeType>::max());
|
||||
auto *socket = EmpThreadSocket::get_socket();
|
||||
assert(socket);
|
||||
socket->recv(tag, data, static_cast<SizeType>(nbyte));
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
template <typename T>
|
||||
inline typename std::enable_if_t<BufferType::can_buffer(), T>
|
||||
EmpThreadSocket<is_server, BufferType>::flush() noexcept {
|
||||
static_assert(std::is_same_v<T, void>,
|
||||
"Error: can only instantiate can_buffer with T = void");
|
||||
const auto size = buffer.size();
|
||||
if (size == 0)
|
||||
return;
|
||||
auto *socket = EmpThreadSocket::get_socket();
|
||||
assert(socket);
|
||||
socket->send(tag, buffer.data(), size);
|
||||
buffer.clear();
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
template <typename T>
|
||||
inline typename std::enable_if_t<!BufferType::can_buffer(), T>
|
||||
EmpThreadSocket<is_server, BufferType>::flush() const noexcept {
|
||||
static_assert(std::is_same_v<T, void>,
|
||||
"Error: can only instantiate can_buffer with T = void");
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
inline void
|
||||
EmpThreadSocket<is_server, BufferType>::set_ssl(SSL *const ssl) noexcept {
|
||||
(is_server) ? EmpSSLSocketManager::create_ssl_server(ssl)
|
||||
: EmpSSLSocketManager::create_ssl_client(ssl);
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
inline void EmpThreadSocket<is_server, BufferType>::destroy_ssl() noexcept {
|
||||
(is_server) ? EmpSSLSocketManager::destroy_ssl_server()
|
||||
: EmpSSLSocketManager::destroy_ssl_client();
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
inline ThreadSafeSSL *
|
||||
EmpThreadSocket<is_server, BufferType>::get_socket() noexcept {
|
||||
return (is_server) ? EmpSSLSocketManager::get_ssl_server()
|
||||
: EmpSSLSocketManager::get_ssl_client();
|
||||
}
|
||||
|
||||
template <bool is_server, typename BufferType>
|
||||
inline unsigned
|
||||
EmpThreadSocket<is_server, BufferType>::get_tag() const noexcept {
|
||||
return tag;
|
||||
}
|
||||
@@ -1,117 +0,0 @@
|
||||
#include "../doctest.h"
|
||||
#include "EmpThreadSocket.hpp"
|
||||
#include "TLSSocket.hpp"
|
||||
#define SOCKET_SETUP
|
||||
#include "TestUtil.hpp"
|
||||
#include <thread>
|
||||
|
||||
// Just for testing across multiple types.
|
||||
template <typename T> struct EmpThreadSocketTestClass { using type = T; };
|
||||
|
||||
//! [EmpThreadSocketSetSSL]
|
||||
TEST_CASE_TEMPLATE("set_ssl", socket_type,
|
||||
EmpThreadSocketTestClass<EmpClientSocket>,
|
||||
EmpThreadSocketTestClass<EmpServerSocket>) {
|
||||
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
REQUIRE(context);
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
REQUIRE(ssl);
|
||||
|
||||
using SocketType = typename socket_type::type;
|
||||
|
||||
CHECK(SocketType::get_socket() == nullptr);
|
||||
SocketType::set_ssl(ssl.get());
|
||||
CHECK(SocketType::get_socket() != nullptr);
|
||||
CHECK(SocketType::get_socket()->get_ssl() == ssl.get());
|
||||
|
||||
//! [EmpThreadSocketDestroySSL]
|
||||
CHECK(SocketType::get_socket() != nullptr);
|
||||
SocketType::destroy_ssl();
|
||||
CHECK(SocketType::get_socket() == nullptr);
|
||||
//! [EmpThreadSocketDestroySSL]
|
||||
}
|
||||
//! [EmpThreadSocketSetSSL]
|
||||
|
||||
//! [EmpThreadSocketConstruct]
|
||||
TEST_CASE_TEMPLATE("construct", socket_type,
|
||||
EmpThreadSocketTestClass<EmpClientSocket>,
|
||||
EmpThreadSocketTestClass<EmpServerSocket>) {
|
||||
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
REQUIRE(context);
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
REQUIRE(ssl);
|
||||
using SocketType = typename socket_type::type;
|
||||
|
||||
SocketType::set_ssl(ssl.get());
|
||||
|
||||
SUBCASE("Just one") {
|
||||
SocketType sock("", 0, true); // The arguments here don't matter.
|
||||
CHECK(sock.get_socket()->get_ssl() == ssl.get());
|
||||
CHECK(sock.get_tag() == 0);
|
||||
}
|
||||
|
||||
SUBCASE("Many") {
|
||||
for (unsigned i = 0; i < 100; i++) {
|
||||
SocketType sock("", 0, true); // The arguments here don't matter.
|
||||
CHECK(sock.get_socket()->get_ssl() == ssl.get());
|
||||
CHECK(sock.get_tag() == i);
|
||||
}
|
||||
}
|
||||
|
||||
SocketType::destroy_ssl();
|
||||
}
|
||||
//! [EmpThreadSocketConstruct]
|
||||
|
||||
// This test isn't templated because it isn't needed.
|
||||
TEST_CASE("send_recv") {
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
REQUIRE(context);
|
||||
std::unique_ptr<TLSSocket> server, client;
|
||||
// Setup the connections.
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
|
||||
// Now we'll practice sending from the client to the server.
|
||||
// We'll do this using an existing connection from a ThreadSafeSSL object.
|
||||
EmpSSLSocketManager::create_ssl_server(server->get_ssl_object());
|
||||
EmpSSLSocketManager::create_ssl_client(client->get_ssl_object());
|
||||
|
||||
// Make some data to send. This is fixed data, but there's no reason for that.
|
||||
std::array<uint8_t, 100> data;
|
||||
std::iota(data.begin(), data.end(), 0);
|
||||
|
||||
EmpClientSocket client_socket("", 0,
|
||||
true); // The arguments here don't matter.
|
||||
EmpServerSocket server_socket("", 0,
|
||||
true); // The arguments here don't matter.
|
||||
|
||||
auto client_code = [&]() {
|
||||
//! [EmpThreadSocketSendDataInternal]
|
||||
client_socket.send_data_internal(data.data(), data.size());
|
||||
client_socket.flush();
|
||||
//! [EmpThreadSocketSendDataInternal]
|
||||
|
||||
//! [EmpThreadSocketRecvDataInternal]
|
||||
std::array<uint8_t, 100> in;
|
||||
client_socket.recv_data_internal(in.data(), in.size());
|
||||
CHECK(in == data);
|
||||
//! [EmpThreadSocketRecvDataInternal]
|
||||
};
|
||||
|
||||
auto server_code = [&]() {
|
||||
std::array<uint8_t, 100> in;
|
||||
server_socket.recv_data_internal(in.data(), in.size());
|
||||
CHECK(in == data);
|
||||
server_socket.send_data_internal(in.data(), in.size());
|
||||
server_socket.flush();
|
||||
};
|
||||
|
||||
std::thread server_thread(server_code);
|
||||
client_code();
|
||||
server_thread.join();
|
||||
|
||||
// Have to destroy these at the end.
|
||||
EmpSSLSocketManager::destroy_ssl_server();
|
||||
EmpSSLSocketManager::destroy_ssl_client();
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "../doctest.h"
|
||||
#include "EmpWrapperAG2PC.hpp"
|
||||
#include "SSLScope.hpp"
|
||||
#include "TLSSocket.hpp"
|
||||
#define SOCKET_SETUP
|
||||
#include "../mta/F2128MtA.hpp"
|
||||
@@ -305,8 +304,6 @@ TEST_CASE("aes_gcm_derivation") {
|
||||
REQUIRE(context);
|
||||
std::unique_ptr<TLSSocket> client, server;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
SSLScope<true> server_scope(server->get_ssl_object());
|
||||
SSLScope<false> client_scope(client->get_ssl_object());
|
||||
|
||||
std::unique_ptr<EmpWrapperAG2PC> client_circ;
|
||||
std::unique_ptr<EmpWrapperAG2PC> server_circ;
|
||||
@@ -450,8 +447,6 @@ TEST_CASE_TEMPLATE(
|
||||
REQUIRE(context);
|
||||
std::unique_ptr<TLSSocket> client, server;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
SSLScope<true> server_scope(server->get_ssl_object());
|
||||
SSLScope<false> client_scope(client->get_ssl_object());
|
||||
|
||||
std::unique_ptr<EmpWrapperAG2PC> client_circ;
|
||||
std::unique_ptr<EmpWrapperAG2PC> server_circ;
|
||||
@@ -693,8 +688,6 @@ TEST_CASE("derive_combined_traffic") {
|
||||
REQUIRE(context);
|
||||
std::unique_ptr<TLSSocket> client, server;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
SSLScope<true> server_scope(server->get_ssl_object());
|
||||
SSLScope<false> client_scope(client->get_ssl_object());
|
||||
|
||||
std::unique_ptr<EmpWrapperAG2PC> client_circ;
|
||||
std::unique_ptr<EmpWrapperAG2PC> server_circ;
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef INCLUDED_SSLBUFFER_HPP
|
||||
#define INCLUDED_SSLBUFFER_HPP
|
||||
|
||||
#include "ThreadSafeSSL.hpp"
|
||||
#include "openssl/base.h"
|
||||
#include "ssl/internal.h"
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@@ -27,10 +28,9 @@ template <size_t flush_size = SSL3_RT_MAX_PLAIN_LENGTH, bool hold_flush = true>
|
||||
class SSLBufferPolicy {
|
||||
public:
|
||||
/**
|
||||
SizeType. We use the same SizeType type as ThreadSafeSSL for compatibility
|
||||
purposes.
|
||||
SizeType. We forward declare a size type for compatibility's sake.
|
||||
**/
|
||||
using SizeType = ThreadSafeSSL::SizeType;
|
||||
using SizeType = unsigned;
|
||||
|
||||
/**
|
||||
has_data. This function always returns true. It should be used to indicate
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
#ifndef INCLUDED_SSLSCOPE_HPP
|
||||
#define INCLUDED_SSLSCOPE_HPP
|
||||
|
||||
#include "EmpThreadSocket.hpp"
|
||||
|
||||
/**
|
||||
SSLScope. This struct exists solely to provide a RAII wrapper for socket
|
||||
management inside this project. In essence, because the tagged sockets
|
||||
(EmpThreadSocket) all rely upon global variables, it's easier to tie the
|
||||
lifetime of those sockets to the lifetime of the global socket object. In our
|
||||
case, we want to be able to tie the lifetime of those sockets to the Server
|
||||
object that owns the underlying socket. This class just makes all of that
|
||||
easier.
|
||||
**/
|
||||
|
||||
template <bool is_server> class SSLScope {
|
||||
using SocketType = typename EmpSocketDispatch<is_server>::type;
|
||||
|
||||
public:
|
||||
SSLScope(SSL *const ssl) noexcept { SocketType::set_ssl(ssl); }
|
||||
~SSLScope() noexcept { SocketType::destroy_ssl(); }
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,232 +0,0 @@
|
||||
#include "ThreadSafeSSL.hpp"
|
||||
#include "Util.hpp" // Needed for generic I/O routines.
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
// Our datagrams in this file have the following two leading entries:
|
||||
// 1. The tag where the message is due to be stored, and
|
||||
// 2. The number of bytes in the message.
|
||||
// We arbitrarily assume this fits into 64 bits. This should be true, as we
|
||||
// don't expect either party to have more than 256 threads available, or for
|
||||
// anyone to be sending messages larger than 2^56 bytes long.
|
||||
struct Packed {
|
||||
unsigned header : 8;
|
||||
size_t size : 56;
|
||||
};
|
||||
|
||||
// Check the packet header fits exactly into 64 bits.
|
||||
static_assert(sizeof(Packed) == sizeof(uint64_t));
|
||||
|
||||
ThreadSafeSSL::ThreadSafeSSL(SSL *const ssl_in) noexcept
|
||||
: count{}, socket_lock{}, ssl{ssl_in}, registered_in{},
|
||||
registered_out{}, incoming{} {
|
||||
assert(ssl);
|
||||
// N.B there's a constexpr way to do this initialisation, but it's rather
|
||||
// long. If this is too slow, we can fix it.
|
||||
std::fill(registered_in.begin(), registered_in.end(),
|
||||
ThreadSafeSSL::tombstone);
|
||||
std::fill(registered_out.begin(), registered_out.end(),
|
||||
ThreadSafeSSL::tombstone);
|
||||
}
|
||||
|
||||
SSL *ThreadSafeSSL::get_ssl() noexcept { return ssl; }
|
||||
|
||||
unsigned ThreadSafeSSL::register_new_socket() noexcept {
|
||||
assert(count + 1 < ThreadSafeSSL::max_size);
|
||||
return count++;
|
||||
}
|
||||
|
||||
unsigned ThreadSafeSSL::find_first_tombstone() const noexcept {
|
||||
// Broken out for readability.
|
||||
constexpr auto predicate = [](const uint8_t value) {
|
||||
return value == ThreadSafeSSL::tombstone;
|
||||
};
|
||||
|
||||
const auto it = std::find_if(std::cbegin(registered_out),
|
||||
std::cbegin(registered_out) + count, predicate);
|
||||
if (it == std::cbegin(registered_out) + count) {
|
||||
return ThreadSafeSSL::tombstone;
|
||||
}
|
||||
|
||||
// This cast is safe because we have an upper limit on the size of
|
||||
// registered_in.
|
||||
return static_cast<unsigned>(std::distance(std::cbegin(registered_out), it));
|
||||
}
|
||||
|
||||
void ThreadSafeSSL::recv(const unsigned tag, void *const data,
|
||||
const unsigned nbyte) noexcept {
|
||||
|
||||
// Precondition: data must be a valid pointer.
|
||||
assert(data);
|
||||
|
||||
// Make sure the caller hasn't messed up.
|
||||
assert(tag < ThreadSafeSSL::max_size);
|
||||
assert(tag < count);
|
||||
|
||||
// First thing we do is lock the entire class.
|
||||
std::lock_guard<std::mutex> lock(socket_lock);
|
||||
|
||||
// Now we have three situations:
|
||||
// 1. We already have a message ready for us. Let's not read from the socket:
|
||||
// we'll just read from our slot and return that.
|
||||
if (incoming[tag].size() > 0) {
|
||||
auto buffer = incoming[tag].data();
|
||||
// Stop us from reading too much data.
|
||||
assert(nbyte <= incoming[tag].size());
|
||||
memcpy(data, buffer, nbyte);
|
||||
incoming[tag].erase(incoming[tag].begin(), incoming[tag].begin() + nbyte);
|
||||
return;
|
||||
}
|
||||
|
||||
// The other cases are:
|
||||
// 2. We've entered this function with a thread that has never received data
|
||||
// before.
|
||||
// 3. We've entered this function with a thread that has received data.
|
||||
// before, but there's no messages in the buffer for us.
|
||||
// In any case, we need to read the header to find out.
|
||||
Packed header;
|
||||
[[maybe_unused]] const auto header_bytes =
|
||||
SSL_read(ssl, &header, sizeof(header));
|
||||
// Only in debug.
|
||||
assert(header_bytes == sizeof(header));
|
||||
|
||||
// We actually handle both cases in a single loop.
|
||||
size_t amount_read_for_this = 0;
|
||||
while (true) {
|
||||
// This is just an abbreviation.
|
||||
const auto dest_tag = registered_in[header.header];
|
||||
// If the message is for us, we just read it already.
|
||||
if (dest_tag == tag) {
|
||||
// We might have more data to read here than we wanted. To stop us from
|
||||
// over running the buffer if that's the case, we:
|
||||
// 1. Read the first nbytes into the `data` buffer, and
|
||||
// 2. Read the rest into the buffer for this tag. This will be read into
|
||||
// the next time the tag calls this particular function.
|
||||
if (nbyte >= header.size) {
|
||||
Util::process_data(ssl,
|
||||
static_cast<char *>(data) + amount_read_for_this,
|
||||
header.size, SSL_read);
|
||||
amount_read_for_this += header.size;
|
||||
} else {
|
||||
// If we've hit this case then there's nothing in _our_ buffer already,
|
||||
// so we can just resize and copy from the beginning.
|
||||
Util::process_data(ssl, static_cast<char *>(data), nbyte, SSL_read);
|
||||
incoming[tag].resize(header.size - nbyte);
|
||||
Util::process_data(ssl, incoming[tag].data(), header.size - nbyte,
|
||||
SSL_read);
|
||||
amount_read_for_this += nbyte;
|
||||
}
|
||||
} else if (dest_tag == ThreadSafeSSL::tombstone) {
|
||||
// Case 2: a new thread has appeared.
|
||||
// It's possible that we've read on a thread that has already been
|
||||
// assigned to another thread (e.g on the other side) but we've received
|
||||
// a new message anyway. We patch that here by finding the first thread on
|
||||
// this node that hasn't yet been assigned to a thread on the other node.
|
||||
const auto thread = (registered_out[tag] == ThreadSafeSSL::tombstone)
|
||||
? tag
|
||||
: find_first_tombstone();
|
||||
// This will fire if there's no free threads on our side to deal with those
|
||||
// messages.
|
||||
assert(thread != ThreadSafeSSL::tombstone);
|
||||
const auto thread_as_uint8 = static_cast<uint8_t>(thread);
|
||||
// Set up the forwarding tables.
|
||||
registered_out[thread] = header.header;
|
||||
registered_in[header.header] = thread_as_uint8;
|
||||
|
||||
// Read into this tag's buffer if appropriate.
|
||||
if (thread == tag) {
|
||||
Util::process_data(ssl, static_cast<char *>(data),
|
||||
static_cast<size_t>(header.size), SSL_read);
|
||||
amount_read_for_this = static_cast<unsigned>(header.size);
|
||||
} else {
|
||||
incoming[thread].resize(header.size);
|
||||
Util::process_data(ssl, incoming[thread].data(),
|
||||
static_cast<size_t>(header.size), SSL_read);
|
||||
}
|
||||
|
||||
// Now we have to tell the other thread who they are speaking to.
|
||||
[[maybe_unused]] const auto written_bytes =
|
||||
SSL_write(ssl, &thread_as_uint8, sizeof(thread_as_uint8));
|
||||
assert(written_bytes > 0);
|
||||
} else {
|
||||
// Otherwise, we are going to append into the buffer that the message is
|
||||
// due for. To do that, we have to reallocate enough memory to be able to
|
||||
// read into the buffer for that particular tag.
|
||||
const auto end_pos = incoming[dest_tag].size();
|
||||
incoming[dest_tag].resize(end_pos + header.size);
|
||||
Util::process_data(ssl, incoming[dest_tag].data() + end_pos,
|
||||
static_cast<std::size_t>(nbyte), SSL_read);
|
||||
}
|
||||
|
||||
// If we've read everything, then quit.
|
||||
if (amount_read_for_this == nbyte) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, there must be data left, so quit.
|
||||
[[maybe_unused]] const auto head_bytes =
|
||||
SSL_read(ssl, &header, sizeof(header));
|
||||
assert(head_bytes == sizeof(header));
|
||||
// We'll deal with everything else on the next iteration.
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadSafeSSL::send(const unsigned tag, const void *const data,
|
||||
const unsigned nbyte) noexcept {
|
||||
|
||||
// Precondition: cannot send null data.
|
||||
assert(data);
|
||||
|
||||
// To prevent the caller from messing up.
|
||||
assert(tag < ThreadSafeSSL::max_size);
|
||||
assert(tag < count);
|
||||
|
||||
// Lock the class to prevent race conditions.
|
||||
std::lock_guard<std::mutex> lock(socket_lock);
|
||||
|
||||
// This code mirrors the code in recv quite closely.
|
||||
// Here, though, we only have two cases:
|
||||
// 1. We have a thread that is new.
|
||||
// 2. We have a thread that is already established.
|
||||
|
||||
// Remarkably the code is mostly the same: the difference is if we
|
||||
// need to wait for the other node to send a message back to us.
|
||||
|
||||
// In either case we need to use a header for writing.
|
||||
// N.B This cast is fine as tag can fit into a uint8_t.
|
||||
Packed header{static_cast<uint8_t>(tag), nbyte};
|
||||
[[maybe_unused]] const auto bytes_written =
|
||||
SSL_write(ssl, &header, sizeof(header));
|
||||
assert(bytes_written > 0);
|
||||
|
||||
// Now we just write everything else out too.
|
||||
Util::process_data(ssl, static_cast<const char *>(data),
|
||||
static_cast<size_t>(nbyte), SSL_write);
|
||||
|
||||
// Now we need to check if this was the first contact with another thread or
|
||||
// not.
|
||||
if (registered_out[tag] == ThreadSafeSSL::tombstone) {
|
||||
// We just read the tag back from the other side.
|
||||
// N.B This assumes that both parties have the same endianness!
|
||||
uint8_t other_tag;
|
||||
[[maybe_unused]] const auto header_bytes =
|
||||
SSL_read(ssl, &other_tag, sizeof(other_tag));
|
||||
assert(header_bytes == sizeof(other_tag));
|
||||
// Now we set up the tables.
|
||||
registered_out[tag] = other_tag;
|
||||
// This cast is safe because we statically enforce a maximum size on
|
||||
// `tag`.
|
||||
registered_in[other_tag] = static_cast<uint8_t>(tag);
|
||||
}
|
||||
}
|
||||
|
||||
bool ThreadSafeSSL::is_registered_in(const unsigned tag) const noexcept {
|
||||
assert(tag < ThreadSafeSSL::max_size);
|
||||
assert(tag < count);
|
||||
return registered_in[tag] != ThreadSafeSSL::tombstone;
|
||||
}
|
||||
|
||||
bool ThreadSafeSSL::is_registered_out(const unsigned tag) const noexcept {
|
||||
assert(tag < ThreadSafeSSL::max_size);
|
||||
assert(tag < count);
|
||||
return registered_in[tag] != ThreadSafeSSL::tombstone;
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
#ifndef INCLUDED_THREADSAFESSL_HPP
|
||||
#define INCLUDED_THREADSAFESSL_HPP
|
||||
|
||||
#include "openssl/base.h"
|
||||
#include "ssl/internal.h"
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
/**
|
||||
ThreadSafeSSL. This component realises a thread-safe SSL connection.
|
||||
This component exists solely to allow the re-use of an SSL connection inside
|
||||
certain parts of EMP.
|
||||
|
||||
Details: this socket is essentially a multiplexing socket.
|
||||
As an analogy, you can consider two buildings full of people.
|
||||
Normally the fastest way for these people to communicate would be to directly
|
||||
phone or email each other, but in absence of that we can use a postal
|
||||
service.
|
||||
|
||||
Assuming we only know the address, how can we make sure that the messages get
|
||||
to the right people?
|
||||
|
||||
One simplifying assumption is that all of the people in both buildings are
|
||||
capable of the same work. As a result, it just matters that the same two people
|
||||
always speak in a pair, rather than to other people.
|
||||
|
||||
Let's label the buildings as A and B respectively. We'll assume that both
|
||||
sides have the same number of people, labelled 0,..., n-1.
|
||||
|
||||
When person 0 in building A sends a message to building B, they will
|
||||
tag their message with `0` before sending. Then, some person `i` in building
|
||||
B will collect that message and note that they are dealing with person 0. Before
|
||||
person `i` deals with the message, they will send a short note back to person
|
||||
`0` that tells them they are dealing with person `i`. This is primarily for
|
||||
consistency across both buildings.
|
||||
|
||||
Iterating this process, we can then build a multiplexed socket for multiple
|
||||
threads.
|
||||
|
||||
|
||||
@remarks As an implementation detail, this class sends all messages prepended
|
||||
with a 64-bit tag. This tag is divided into 8-bits for the ID (i.e the person
|
||||
number in the above example) and the rest is for the length. Whilst this is an
|
||||
artificial restriction, in practice it seems reasonable: 2^56 bytes is 72
|
||||
petabytes, which seems unrealistic. The 2^8 = 256 thread count also seems
|
||||
reasonable, although for other reasons we actually restrict this to 255 threads.
|
||||
|
||||
@remarks This class assumes that both platforms have the same endianness:
|
||||
there's no format independent encoding in this class. We can add this later if
|
||||
needed.
|
||||
**/
|
||||
|
||||
class ThreadSafeSSL {
|
||||
|
||||
public:
|
||||
/**
|
||||
SizeType. This is the type that is used as a size parameter in send and
|
||||
recv. This declaration is here so that callers know what type to which they
|
||||
should cast their size parameters.
|
||||
**/
|
||||
using SizeType = unsigned;
|
||||
|
||||
/**
|
||||
register_new_socket. This function increments the number of sockets
|
||||
associated with this SSL object and returns the tag for the caller. The value
|
||||
returned is always the previous value of `count`: this means that the first
|
||||
caller gets 0, the second caller gets 1 and so on. This function does not
|
||||
throw.
|
||||
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisterNewSocket
|
||||
@return the tag associated with the calling socket.
|
||||
**/
|
||||
unsigned register_new_socket() noexcept;
|
||||
|
||||
/**
|
||||
send. This function sends `nbyte` of `data` to the thread associated with
|
||||
`tag`. This function does not throw. Note that similarly to other socket
|
||||
classes that interface with EMP, this function does not return any error
|
||||
information.
|
||||
@param[in] tag: the tag associated with the calling socket. Must be less
|
||||
than ThreadSafeSSL::max_size.
|
||||
@param[in] data: the data to be sent. Must not be null.
|
||||
@param[in] nbyte: the number of bytes to send.
|
||||
**/
|
||||
void send(const unsigned tag, const void *const data,
|
||||
const SizeType nbyte) noexcept;
|
||||
/**
|
||||
recv. This function receives at most `nbyte` of data meant for the thread
|
||||
associated with `tag`. This function does not throw. Note that similarly to
|
||||
other socket classes that interface with EMP, this function does not return
|
||||
any error information.
|
||||
@param[in] tag: the tag associated with the calling socket. Must be less than
|
||||
ThreadSafeSSL::max_size.
|
||||
@param[in] data: the location to store the incoming data. Must not be null. In
|
||||
debug builds, we assert to this.
|
||||
@param[in] nbyte: the number of bytes to read.
|
||||
**/
|
||||
void recv(const unsigned tag, void *const data,
|
||||
const SizeType nbyte) noexcept;
|
||||
|
||||
/**
|
||||
ThreadSafeSSL. This constructor builds the socket. This constructor does
|
||||
not throw.
|
||||
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLConstructor
|
||||
@param[in] ssl_in: the input ssl connection to use. Must not be null.
|
||||
**/
|
||||
ThreadSafeSSL(SSL *const ssl_in) noexcept;
|
||||
|
||||
/**
|
||||
get_ssl. Returns a copy of the SSL object associated with `this`
|
||||
ThreadSafeSSL object. This function never returns a null pointer and never
|
||||
throws.
|
||||
@snippet ThreadSafeSocket.t.cpp ThreadSafeSSLConstructor
|
||||
@returns a non-null copy of the `ssl` object.
|
||||
**/
|
||||
SSL *get_ssl() noexcept;
|
||||
|
||||
/**
|
||||
is_registered_in. This function returns true if `tag` is registered as
|
||||
incoming at this node. This means that the thread with `tag` as their
|
||||
identifier has a thread on this node that they consistently communicate with.
|
||||
This function does not throw.
|
||||
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisteredInOut
|
||||
@param[in] tag: the tag that we are looking up. Must be less than
|
||||
ThreadSafeSSL::max_size.
|
||||
@return true if tag has a corresponding thread that it communicates with,
|
||||
false otherwise.
|
||||
**/
|
||||
bool is_registered_in(const unsigned tag) const noexcept;
|
||||
|
||||
/**
|
||||
is_registered_out. This function returns true if `tag` is registered as
|
||||
outgoing at this node. This means that the thread with `tag` as their
|
||||
identifier has a thread on the node at the other end of the channel that they
|
||||
consistently communicate with.
|
||||
@snippet ThreadSafeSSL.t.cpp ThreadSafeSSLRegisteredInOut
|
||||
@param[in] tag: the tag that we are looking up. Must be less than
|
||||
ThreadSafeSSL::max_size.
|
||||
@return true if tag has a corresponding thread that it communicates with,
|
||||
false otherwise.
|
||||
**/
|
||||
bool is_registered_out(const unsigned tag) const noexcept;
|
||||
|
||||
private:
|
||||
/**
|
||||
max_size. This is the maximum number of threads we support in this class.
|
||||
This is actually set here to make certain declarations nicer to write.
|
||||
**/
|
||||
static constexpr uint8_t max_size = 254;
|
||||
|
||||
/**
|
||||
tombstone. This is the value that is set in the registered_in and
|
||||
registered_out arrays to denote that the thread has not yet been seen or
|
||||
registered. This is by default max_size + 1.
|
||||
**/
|
||||
static constexpr uint8_t tombstone = max_size + 1;
|
||||
|
||||
/**
|
||||
count. This is the counter of the number of sockets that have been
|
||||
registered with `this` socket. Note that despite the unsigned nature of this
|
||||
counter the actual size is at most 255, which we assert to in the
|
||||
register_new_socket function in release modes.
|
||||
**/
|
||||
std::atomic<unsigned> count;
|
||||
|
||||
/**
|
||||
socket_lock. This is the lock for this class. All operations that involve
|
||||
any variables of this class use this lock.
|
||||
**/
|
||||
std::mutex socket_lock;
|
||||
/**
|
||||
ssl. This is the ssl connection to use. Note that this is never null
|
||||
after this socket has been connected.
|
||||
**/
|
||||
SSL *ssl;
|
||||
|
||||
/**
|
||||
registered_in. This array contains the IDs of all threads that have been
|
||||
registered _in_ at this socket. If an entry here is not set to `tombstone`,
|
||||
then it means that it has been registered with a thread. In more detail, if
|
||||
registered_in[i] is not `tombstone`, then the value of registered_in[i] is the
|
||||
thread on this node that thread `i` (on another node) communicates with.
|
||||
**/
|
||||
std::array<uint8_t, max_size> registered_in;
|
||||
/**
|
||||
registered_out. This array contains the IDs of all threads that have been
|
||||
registered as sending a message from this socket. If registered_out[i] !=
|
||||
tombstone, then thread `i` here is communicating with the thread with ID
|
||||
registered_out[i] on another node.
|
||||
**/
|
||||
std::array<uint8_t, max_size> registered_out;
|
||||
|
||||
/**
|
||||
incoming. This vector of vector contains temporary storage for each
|
||||
thread. In particular, incoming[i] contains any read messages for thread
|
||||
`i`. This may be populated by other threads to prevent locking.
|
||||
**/
|
||||
std::array<std::vector<char>, max_size> incoming;
|
||||
|
||||
unsigned find_first_tombstone() const noexcept;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,332 +0,0 @@
|
||||
#include "../doctest.h"
|
||||
#include "TLSSocket.hpp"
|
||||
#include "ThreadSafeSSL.hpp"
|
||||
#include <numeric>
|
||||
#define SOCKET_SETUP
|
||||
#include "TestUtil.hpp"
|
||||
#include <thread>
|
||||
|
||||
//! [ThreadSafeSSLConstructor]
|
||||
TEST_CASE("constructor") {
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
assert(context);
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
ThreadSafeSSL tssl(ssl.get());
|
||||
CHECK(tssl.get_ssl() == ssl.get());
|
||||
}
|
||||
//! [ThreadSafeSSLConstructor]
|
||||
|
||||
//! [ThreadSafeSSLRegisterNewSocket]
|
||||
TEST_CASE("constructor") {
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
ThreadSafeSSL tssl(ssl.get());
|
||||
|
||||
SUBCASE("linear") {
|
||||
CHECK(tssl.register_new_socket() == 0);
|
||||
CHECK(tssl.register_new_socket() == 1);
|
||||
CHECK(tssl.register_new_socket() == 2);
|
||||
}
|
||||
|
||||
SUBCASE("loop") {
|
||||
for (unsigned i = 0; i < 100; i++) {
|
||||
CHECK(tssl.register_new_socket() == i);
|
||||
}
|
||||
}
|
||||
}
|
||||
//! [ThreadSafeSSLRegisterNewSocket]
|
||||
|
||||
//! [ThreadSafeSSLRegisteredIn]
|
||||
TEST_CASE("registered_in") {
|
||||
// This function checks that a registered thread is actually in the initial
|
||||
// table.
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
std::unique_ptr<TLSSocket> server, client;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
}
|
||||
//! [ThreadSafeSSLRegisteredIn]
|
||||
|
||||
//! [ThreadSafeSSLSend]
|
||||
TEST_CASE("send") {
|
||||
// This just checks that the sending code works as we'd expect.
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
bssl::UniquePtr<SSL> ssl(SSL_new(context.get()));
|
||||
std::unique_ptr<TLSSocket> server, client;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
|
||||
// We'll use the server as the TSSL object.
|
||||
ThreadSafeSSL tssl(server->get_ssl_object());
|
||||
|
||||
SUBCASE("random data, first time") {
|
||||
|
||||
const auto tag = tssl.register_new_socket();
|
||||
std::array<uint8_t, 20> data;
|
||||
for (unsigned i = 0; i < data.size(); i++) {
|
||||
data[i] = static_cast<uint8_t>(rand());
|
||||
}
|
||||
|
||||
auto send_code = [&]() { tssl.send(tag, data.data(), data.size()); };
|
||||
CHECK(!tssl.is_registered_out(tag));
|
||||
std::thread server_code(send_code);
|
||||
// The message is split into two portions: the header first, then the
|
||||
// message itself.
|
||||
struct Packed {
|
||||
unsigned header : 8;
|
||||
size_t size : 56;
|
||||
};
|
||||
|
||||
Packed header;
|
||||
const auto nr_bytes =
|
||||
SSL_read(client->get_ssl_object(), &header, sizeof(header));
|
||||
CHECK(nr_bytes == sizeof(header));
|
||||
CHECK((unsigned)header.header == tag);
|
||||
CHECK((size_t)header.size == data.size());
|
||||
|
||||
// Now we expect the message to be of size 20 * sizeof(uint8_t).
|
||||
uint8_t arr[data.size()];
|
||||
const auto nr_bytes_mes =
|
||||
SSL_read(client->get_ssl_object(), arr, sizeof(arr));
|
||||
REQUIRE(nr_bytes_mes == sizeof(arr));
|
||||
CHECK(memcmp(arr, data.data(), data.size()) == 0);
|
||||
|
||||
// Now to get the server to terminate we need to write a tag back.
|
||||
// We'll just use 0.
|
||||
const uint8_t temp_tag = 0;
|
||||
CHECK(!tssl.is_registered_in(temp_tag));
|
||||
const auto written =
|
||||
SSL_write(client->get_ssl_object(), &temp_tag, sizeof(uint8_t));
|
||||
REQUIRE(written == sizeof(uint8_t));
|
||||
// We want to make sure the socket actually terminates.
|
||||
server_code.join();
|
||||
|
||||
// Check that the tags are actually in the table.
|
||||
//! [ThreadSafeSSLRegisteredInOut]
|
||||
CHECK(tssl.is_registered_in(temp_tag));
|
||||
CHECK(tssl.is_registered_out(tag));
|
||||
//! [ThreadSafeSSLRegisteredInOut]
|
||||
}
|
||||
}
|
||||
//! [ThreadSafeSSLSend]
|
||||
|
||||
//! [ThreadSafeSSLMultiSend]
|
||||
TEST_CASE("multi_send") {
|
||||
// This test case checks that both send and recv work together.
|
||||
// We do this by first establishing a set of connections between two tssl
|
||||
// objects and then checking that the reads work as expected.
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
std::unique_ptr<TLSSocket> server, client;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
|
||||
ThreadSafeSSL client_tssl(client->get_ssl_object());
|
||||
ThreadSafeSSL server_tssl(server->get_ssl_object());
|
||||
|
||||
// Check that establishing the connection works.
|
||||
auto server_code = [&]() {
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
// This counts from 0, so it's essentially `i`.
|
||||
server_tssl.register_new_socket();
|
||||
// Here we are going for a 1:1 mapping. This is just for ease of testing.
|
||||
CHECK(!server_tssl.is_registered_out(i));
|
||||
CHECK(!server_tssl.is_registered_in(i));
|
||||
server_tssl.send(i, &i, sizeof(i));
|
||||
CHECK(server_tssl.is_registered_out(i));
|
||||
CHECK(server_tssl.is_registered_in(i));
|
||||
}
|
||||
};
|
||||
|
||||
std::thread server_establish(server_code);
|
||||
|
||||
unsigned tmp_storage;
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_tssl.register_new_socket();
|
||||
// Here we are going for a 1:1 mapping. This is just for ease of testing.
|
||||
CHECK(!client_tssl.is_registered_out(i));
|
||||
CHECK(!client_tssl.is_registered_in(i));
|
||||
client_tssl.recv(i, &tmp_storage, sizeof(tmp_storage));
|
||||
CHECK(tmp_storage == i);
|
||||
CHECK(client_tssl.is_registered_out(i));
|
||||
CHECK(client_tssl.is_registered_in(i));
|
||||
}
|
||||
|
||||
server_establish.join();
|
||||
|
||||
// Now we are going to send `i`, but in reverse. This is entirely to check
|
||||
// that the buffering works properly.
|
||||
SUBCASE("Sequential") {
|
||||
auto server_send_backwards = [&]() {
|
||||
for (int i = 9; i >= 0; i--) {
|
||||
unsigned as_unsigned = static_cast<unsigned>(i);
|
||||
server_tssl.send(as_unsigned, &as_unsigned, sizeof(as_unsigned));
|
||||
}
|
||||
};
|
||||
|
||||
std::thread server_backwards_thread(server_send_backwards);
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_tssl.recv(i, &tmp_storage, sizeof(tmp_storage));
|
||||
CHECK(tmp_storage == i);
|
||||
}
|
||||
|
||||
server_backwards_thread.join();
|
||||
}
|
||||
|
||||
SUBCASE("small_threaded") {
|
||||
// Now we'll do it from many threads at once, to many threads at once.
|
||||
// This is primarily to check that the buffering etc is actually thread
|
||||
// safe.
|
||||
std::array<std::thread, 10> client_threads;
|
||||
std::array<std::thread, 10> server_threads;
|
||||
|
||||
auto client_send_code = [&](const unsigned i) {
|
||||
unsigned t_tmp_storage;
|
||||
client_tssl.recv(i, &t_tmp_storage, sizeof(t_tmp_storage));
|
||||
CHECK(t_tmp_storage == i);
|
||||
};
|
||||
|
||||
auto server_send_code = [&](const unsigned i) {
|
||||
server_tssl.send(i, &i, sizeof(i));
|
||||
};
|
||||
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_threads[i] = std::thread(client_send_code, i);
|
||||
server_threads[i] = std::thread(server_send_code, 9 - i);
|
||||
}
|
||||
|
||||
// N.B these joins are in separate threads to make sure that
|
||||
// the data from server thread has been sent before we require
|
||||
// any client threads to stop.
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
server_threads[i].join();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_threads[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
SUBCASE("large_threaded") {
|
||||
// Now we'll do it from many threads at once, to many threads at once, but
|
||||
// with lots of data. This is primarily to check whether we actually
|
||||
// can handle large batches of data.
|
||||
std::array<std::thread, 10> client_threads;
|
||||
std::array<std::thread, 10> server_threads;
|
||||
|
||||
using BufType = std::array<char, 2 * SSL3_RT_MAX_PLAIN_LENGTH>;
|
||||
|
||||
auto client_send_code = [&](const unsigned i) {
|
||||
BufType data;
|
||||
std::iota(data.begin(), data.end(), i);
|
||||
BufType in;
|
||||
client_tssl.recv(i, in.data(), sizeof(in));
|
||||
CHECK(in == data);
|
||||
};
|
||||
|
||||
auto server_send_code = [&](const unsigned i) {
|
||||
BufType data;
|
||||
std::iota(data.begin(), data.end(), i);
|
||||
server_tssl.send(i, data.data(), sizeof(data));
|
||||
};
|
||||
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_threads[i] = std::thread(client_send_code, i);
|
||||
server_threads[i] = std::thread(server_send_code, 9 - i);
|
||||
}
|
||||
|
||||
// N.B these joins are in separate threads to make sure that
|
||||
// the data from server thread has been sent before we require
|
||||
// any client threads to stop.
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
server_threads[i].join();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < 10; i++) {
|
||||
client_threads[i].join();
|
||||
}
|
||||
}
|
||||
}
|
||||
//! [ThreadSafeSSLMultiSend]
|
||||
|
||||
//! [ThreadSafeSSLMismatchedSizeSends]
|
||||
TEST_CASE("mismatched_size_send") {
|
||||
// This test case is to make sure that if one thread sends more data than is
|
||||
// expected to the receiver that everything still works.
|
||||
|
||||
auto context = CreateContextWithTestCertificate(TLS_method());
|
||||
std::unique_ptr<TLSSocket> server, client;
|
||||
REQUIRE(setup_sockets(context, server, client));
|
||||
|
||||
ThreadSafeSSL client_tssl(client->get_ssl_object());
|
||||
ThreadSafeSSL server_tssl(server->get_ssl_object());
|
||||
|
||||
// We're going to do everything over tag 0. This means that the threads that
|
||||
// are used for communicating both have a tag of 0. This is just for ease of
|
||||
// testing: the behaviour we're simulating doesn't really require a particular
|
||||
// thread to be used.
|
||||
constexpr unsigned tag = 0;
|
||||
|
||||
auto server_code = [&]() {
|
||||
// Just register a single new socket.
|
||||
server_tssl.register_new_socket();
|
||||
CHECK(!server_tssl.is_registered_out(tag));
|
||||
CHECK(!server_tssl.is_registered_in(tag));
|
||||
server_tssl.send(tag, &tag, sizeof(tag));
|
||||
CHECK(server_tssl.is_registered_out(tag));
|
||||
CHECK(server_tssl.is_registered_in(tag));
|
||||
};
|
||||
|
||||
std::thread server_establish(server_code);
|
||||
|
||||
// Now we'll do the same setup for the client.
|
||||
|
||||
unsigned tmp_storage;
|
||||
client_tssl.register_new_socket();
|
||||
CHECK(!client_tssl.is_registered_out(tag));
|
||||
CHECK(!client_tssl.is_registered_in(tag));
|
||||
client_tssl.recv(tag, &tmp_storage, sizeof(tmp_storage));
|
||||
CHECK(tmp_storage == tag);
|
||||
CHECK(client_tssl.is_registered_out(tag));
|
||||
CHECK(client_tssl.is_registered_in(tag));
|
||||
|
||||
server_establish.join();
|
||||
|
||||
// Now we want to do the actual testing: sending mismatched data over.
|
||||
SUBCASE("small mismatch") {
|
||||
constexpr std::array<unsigned, 2> test_arr{1, 2};
|
||||
auto server_send = [&]() {
|
||||
server_tssl.send(tag, test_arr.data(),
|
||||
test_arr.size() * sizeof(unsigned));
|
||||
};
|
||||
|
||||
std::thread server_send_thread(server_send);
|
||||
unsigned recv{};
|
||||
for (unsigned i = 0; i < test_arr.size(); i++) {
|
||||
client_tssl.recv(tag, &recv, sizeof(unsigned));
|
||||
CHECK(recv == test_arr[i]);
|
||||
}
|
||||
|
||||
server_send_thread.join();
|
||||
}
|
||||
|
||||
SUBCASE("large mismatch") {
|
||||
std::vector<unsigned> vals(10000);
|
||||
for (auto &val : vals) {
|
||||
val = static_cast<unsigned>(rand());
|
||||
}
|
||||
|
||||
auto server_send = [&]() {
|
||||
server_tssl.send(tag, vals.data(),
|
||||
static_cast<unsigned>(vals.size() * sizeof(unsigned)));
|
||||
};
|
||||
|
||||
std::thread server_send_thread(server_send);
|
||||
unsigned recv{};
|
||||
for (unsigned i = 0; i < vals.size(); i++) {
|
||||
client_tssl.recv(tag, &recv, sizeof(unsigned));
|
||||
CHECK(recv == vals[i]);
|
||||
}
|
||||
|
||||
server_send_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
//! [ThreadSafeSSLMismatchedSizeSends]
|
||||
@@ -796,9 +796,10 @@ static bool run_traffic_circuit_internal(
|
||||
std::copy(output.cbegin() + offset,
|
||||
output.cbegin() + offset + out.client_iv.size(),
|
||||
out.client_iv.begin());
|
||||
offset += out.client_iv.size() + out.server_key_share.size();
|
||||
// These are also fine to cast.
|
||||
offset += unsigned(out.client_iv.size() + out.server_key_share.size());
|
||||
std::copy(output.cbegin() + offset, output.cend(), out.server_iv.begin());
|
||||
offset += out.server_iv.size();
|
||||
offset += unsigned(out.server_iv.size());
|
||||
} else {
|
||||
const auto copy_func = [&](auto &dest,
|
||||
const unsigned inc = sizeof(decltype(dest))) {
|
||||
|
||||
Reference in New Issue
Block a user