reworked how the server and client were structured so that they could be more usable

This commit is contained in:
Elias
2022-03-22 21:42:36 +00:00
parent cf2e1bc4f0
commit d447d874dd
18 changed files with 1118 additions and 108 deletions

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.17...3.22)
cmake_minimum_required(VERSION 3.13.4)
project(iree-wrapper VERSION 1.0 LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11)
@@ -8,5 +8,10 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_library(asio INTERFACE)
target_compile_options(asio INTERFACE ASIO_STANDALONE)
target_include_directories(asio INTERFACE thirdparty/asio/asio/include)
target_link_libraries(asio INTERFACE pthread)
add_subdirectory(thirdparty/iree EXCLUDE_FROM_ALL)
add_subdirectory(dSHARK)
add_subdirectory(dSHARK)

View File

@@ -1,78 +1,7 @@
project(lws-minimal-http-server-form-post-file C)
cmake_minimum_required(VERSION 2.8.12)
find_package(libwebsockets CONFIG REQUIRED)
list(APPEND CMAKE_MODULE_PATH ${LWS_CMAKE_DIR})
message("${LWS_CMAKE_DIR}")
message("${LIBWEBSOCKETS_DEP_LIBS}")
include(CheckCSourceCompiles)
include(LwsCheckRequirements)
project(dSHARK CXX)
cmake_minimum_required(VERSION 3.13.4)
set(_TRANSLATE_TOOL_EXECUTABLE $<TARGET_FILE:iree_tools_iree-translate>)
# Define arguments passed to iree-translate
set(_ARGS)
list(APPEND _ARGS "-iree-input-type=mhlo")
list(APPEND _ARGS "-iree-mlir-to-vm-bytecode-module")
list(APPEND _ARGS "-iree-hal-target-backends=cuda")
# Uncomment the line below to use vulkan-spirv backend
#list(APPEND _ARGS "-iree-hal-target-backends=vulkan-spirv")
list(APPEND _ARGS "${CMAKE_CURRENT_SOURCE_DIR}/simple_embedding_test.mlir")
list(APPEND _ARGS "-o")
list(APPEND _ARGS "simple_embedding_test_bytecode_module_cuda_c.vmfb")
# Translate MLIR file to VM bytecode module
add_custom_command(
OUTPUT "simple_embedding_test_bytecode_module_cuda_c.vmfb"
COMMAND ${_TRANSLATE_TOOL_EXECUTABLE} ${_ARGS}
DEPENDS iree_tools_iree-translate
)
#-------------------------------------------------------------------------------
# Embedd the VM bytcode module into a c file via `generate_embed_data`.
#-------------------------------------------------------------------------------
# Define arguments passed to generate_embed_data
set(_ARGS)
list(APPEND _ARGS "--output_header=simple_embedding_test_bytecode_module_cuda_c.h")
list(APPEND _ARGS "--output_impl=simple_embedding_test_bytecode_module_cuda_c.c")
list(APPEND _ARGS "--identifier=simple_embedding_test_bytecode_module_cuda_c")
list(APPEND _ARGS "--flatten")
list(APPEND _ARGS "simple_embedding_test_bytecode_module_cuda_c.vmfb")
# Embed VM bytecode module into c source file
add_custom_command(
OUTPUT
"simple_embedding_test_bytecode_module_cuda_c.h"
"simple_embedding_test_bytecode_module_cuda_c.c"
COMMAND generate_embed_data ${_ARGS}
DEPENDS generate_embed_data simple_embedding_test_bytecode_module_cuda_c.vmfb
)
add_library(simple_embedding_test_bytecode_module_cuda_c STATIC "")
target_sources(simple_embedding_test_bytecode_module_cuda_c
PRIVATE
simple_embedding_test_bytecode_module_cuda_c.c
simple_embedding_test_bytecode_module_cuda_c.h
)
set(SAMP lws-minimal-http-server-form-post-file)
set(SRCS minimal-http-server-form-post-file.c)
set(requirements 1)
require_lws_config(LWS_ROLE_H1 1 requirements)
require_lws_config(LWS_WITH_SERVER 1 requirements)
require_lws_config(LWS_WITH_FILE_OPS 1 requirements)
if (requirements)
add_executable(${SAMP} ${SRCS})
if (websockets_shared)
target_link_libraries(${SAMP} websockets_shared ${LIBWEBSOCKETS_DEP_LIBS} simple_embedding_test_bytecode_module_cuda_c iree_base_base iree_hal_hal iree_hal_cuda_registration_registration iree_modules_hal_hal iree_vm_vm iree_vm_bytecode_module iree_runtime_runtime iree_hal_local_loaders_system_library_loader iree_hal_local_sync_driver)
add_dependencies(${SAMP} websockets_shared)
else()
target_link_libraries(${SAMP} websockets ${LIBWEBSOCKETS_DEP_LIBS} simple_embedding_test_bytecode_module_cuda_c iree_base_base iree_hal_hal iree_hal_cuda_registration_registration iree_modules_hal_hal iree_vm_vm iree_vm_bytecode_module iree_runtime_runtime iree_hal_local_loaders_system_library_loader iree_hal_local_sync_driver)
endif()
else()
message("not requirements")
endif()
add_subdirectory(dshark_network_lib)
add_subdirectory(dshark_run_module)
add_subdirectory(dshark_client)
add_subdirectory(dshark_server)

View File

@@ -0,0 +1,9 @@
cmake_minimum_required(VERSION 3.13.4)
project(dshark_client)
add_executable(dshark_client dshark_client.cpp)
target_include_directories(dshark_client PRIVATE ../dshark_network_lib)
target_link_directories(dshark_client PRIVATE ../dshark_network_lib)
target_link_libraries(dshark_client dshark_network_lib pthread)

View File

@@ -0,0 +1,123 @@
#include <iostream>
#include "dshark_common.h"
#include "dshark_message.h"
#include "dshark_client.h"
#include "dshark_server.h"
#include <asio.hpp>
#include <fstream>
#include <iterator>
enum class dSHARKMessageType : uint32_t
{
EvaluateBinary,
ServerAccept,
ServerDeny,
ServerPing,
MessageAll,
ServerMessage,
};
class CustomClient : public dshark::client_interface<dSHARKMessageType>
{
public:
void PingServer()
{
dshark::message<dSHARKMessageType> msg;
msg.header.id = dSHARKMessageType::ServerPing;
// Caution with this...
std::chrono::system_clock::time_point timeNow = std::chrono::system_clock::now();
msg << timeNow;
Send(msg);
}
void MessageAll()
{
dshark::message<dSHARKMessageType> msg;
msg.header.id = dSHARKMessageType::MessageAll;
Send(msg);
}
void EvaluateBinary(std::string filepath)
{
std::cout << "Loading Binary" << std::endl;
std::ifstream input(filepath, std::ios::binary);
std::vector<unsigned char> buffer(std::istreambuf_iterator<char>(input), {});
dshark::message<dSHARKMessageType> msg;
msg.header.id = dSHARKMessageType::EvaluateBinary;
for (int i = 0; i < buffer.size(); i++)
{
msg << buffer[i];
}
Send(msg);
}
};
int main(int argc, char* argv[])
{
CustomClient c;
c.Connect("127.0.0.1", 60000);
bool bQuit = false;
std::cout << "Sending File " << argv[1] << std::endl;
c.EvaluateBinary((std::string)argv[1]);
while (!bQuit)
{
if (c.IsConnected())
{
if (!c.Incoming().empty())
{
auto msg = c.Incoming().pop_front().msg;
switch (msg.header.id)
{
case dSHARKMessageType::ServerAccept:
{
// Server has responded to a ping request
std::cout << "Server Accepted Connection\n";
}
break;
case dSHARKMessageType::EvaluateBinary:
{
std::cout << "Binary File Sent" << std::endl;
}
break;
case dSHARKMessageType::ServerPing:
{
// Server has responded to a ping request
std::chrono::system_clock::time_point timeNow = std::chrono::system_clock::now();
std::chrono::system_clock::time_point timeThen;
msg >> timeThen;
std::cout << "Ping: " << std::chrono::duration<double>(timeNow - timeThen).count() << "\n";
}
break;
case dSHARKMessageType::ServerMessage:
{
// Server has responded to a ping request
uint32_t clientID;
msg >> clientID;
std::cout << "Hello from [" << clientID << "]\n";
}
break;
}
}
}
else
{
std::cout << "Server Down\n";
bQuit = true;
}
}
return 0;
}

View File

@@ -1,24 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
#define IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_
#include "iree/base/api.h"
#include "iree/hal/api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
IREE_API_EXPORT iree_status_t
iree_hal_cuda_driver_module_register(iree_hal_driver_registry_t* registry, int index);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // IREE_HAL_CUDA_REGISTRATION_DRIVER_MODULE_H_

View File

@@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.13.4)
add_library(dshark_network_lib dshark_common.h dshark_queue.h dshark_message.h dshark_connection.h dshark_client.h dshark_server.h)
set_target_properties(dshark_network_lib PROPERTIES LINKER_LANGUAGE CXX)

View File

@@ -0,0 +1,96 @@
#pragma once
#include "dshark_common.h"
#include "dshark_message.h"
#include "dshark_queue.h"
#include "dshark_connection.h"
namespace dshark {
template <typename T>
class client_interface
{
public:
client_interface() : m_socket(m_context)
{
}
virtual ~client_interface()
{
Disconnect();
}
public:
bool Connect(const std::string& host, const uint16_t port)
{
try
{
// Resolve hostname/ip-address into tangiable physical address
asio::ip::tcp::resolver resolver(m_context);
asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(host, std::to_string(port));
// Create connection
m_connection = std::make_unique<connection<T>>(connection<T>::owner::client, m_context, asio::ip::tcp::socket(m_context), m_qMessagesIn);
// Tell the connection object to connect to server
m_connection->ConnectToServer(endpoints);
// Start Context Thread
thrContext = std::thread([this]() { m_context.run(); });
}
catch (std::exception& e)
{
std::cerr << "Client Exception: " << e.what() << "\n";
return false;
}
return true;
}
void Disconnect()
{
// If connection exists, and it's connected then...
if (IsConnected())
{
// ...disconnect from server gracefully
m_connection->Disconnect();
}
// Either way, we're also done with the asio context...
m_context.stop();
// ...and its thread
if (thrContext.joinable())
thrContext.join();
// Destroy the connection object
m_connection.release();
}
bool IsConnected()
{
if (m_connection) {
return m_connection->IsConnected();
}
else {
return false;
}
}
void Send(const message<T>& msg)
{
if (IsConnected())
m_connection->Send(msg);
}
// Retrieve queue of messages from server
dsqueue<owned_message<T>>& Incoming()
{
return m_qMessagesIn;
}
protected:
asio::io_context m_context;
std::thread thrContext;
asio::ip::tcp::socket m_socket;
std::unique_ptr<connection<T>> m_connection;
private:
dsqueue<owned_message<T>> m_qMessagesIn;
};
}

View File

@@ -0,0 +1,18 @@
#pragma once
#include <memory>
#include <thread>
#include <mutex>
#include <deque>
#include <optional>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cstdint>
#include <chrono>
#include <stdio.h>
#define ASIO_STANDALONE
#include <asio.hpp>
#include <asio/ts/buffer.hpp>
#include <asio/ts/internet.hpp>

View File

@@ -0,0 +1,291 @@
#pragma once
#include "dshark_common.h"
#include "dshark_message.h"
#include "dshark_queue.h"
namespace dshark {
template <typename T>
class connection : public std::enable_shared_from_this<connection<T>>
{
public:
// A connection is "owned" by either a server or a client, and its
// behaviour is slightly different bewteen the two.
enum class owner
{
server,
client
};
public:
// Constructor: Specify Owner, connect to context, transfer the socket
// Provide reference to incoming message queue
connection(owner parent, asio::io_context& asioContext, asio::ip::tcp::socket socket, dsqueue<owned_message<T>>& qIn)
: m_asioContext(asioContext), m_socket(std::move(socket)), m_qMessagesIn(qIn)
{
m_nOwnerType = parent;
}
virtual ~connection()
{}
// This ID is used system wide - its how clients will understand other clients
// exist across the whole system.
uint32_t GetID() const
{
return id;
}
public:
void ConnectToClient(uint32_t uid = 0)
{
if (m_nOwnerType == owner::server)
{
if (m_socket.is_open())
{
id = uid;
ReadHeader();
}
}
}
void ConnectToServer(const asio::ip::tcp::resolver::results_type& endpoints)
{
// Only clients can connect to servers
if (m_nOwnerType == owner::client)
{
// Request asio attempts to connect to an endpoint
asio::async_connect(m_socket, endpoints,
[this](std::error_code ec, asio::ip::tcp::endpoint endpoint)
{
if (!ec)
{
ReadHeader();
}
});
}
}
void Disconnect()
{
if (IsConnected())
asio::post(m_asioContext, [this]() { m_socket.close(); });
}
bool IsConnected() const
{
return m_socket.is_open();
}
// Prime the connection to wait for incoming messages
void StartListening()
{
}
public:
// ASYNC - Send a message, connections are one-to-one so no need to specifiy
// the target, for a client, the target is the server and vice versa
void Send(const message<T>& msg)
{
asio::post(m_asioContext,
[this, msg]()
{
// If the queue has a message in it, then we must
// assume that it is in the process of asynchronously being written.
// Either way add the message to the queue to be output. If no messages
// were available to be written, then start the process of writing the
// message at the front of the queue.
bool bWritingMessage = !m_qMessagesOut.empty();
m_qMessagesOut.push_back(msg);
if (!bWritingMessage)
{
WriteHeader();
}
});
}
private:
// ASYNC - Prime context to write a message header
void WriteHeader()
{
// If this function is called, we know the outgoing message queue must have
// at least one message to send. So allocate a transmission buffer to hold
// the message, and issue the work - asio, send these bytes
asio::async_write(m_socket, asio::buffer(&m_qMessagesOut.front().header, sizeof(message_header<T>)),
[this](std::error_code ec, std::size_t length)
{
// asio has now sent the bytes - if there was a problem
// an error would be available...
if (!ec)
{
// ... no error, so check if the message header just sent also
// has a message body...
if (m_qMessagesOut.front().body.size() > 0)
{
// ...it does, so issue the task to write the body bytes
WriteBody();
}
else
{
// ...it didnt, so we are done with this message. Remove it from
// the outgoing message queue
m_qMessagesOut.pop_front();
// If the queue is not empty, there are more messages to send, so
// make this happen by issuing the task to send the next header.
if (!m_qMessagesOut.empty())
{
WriteHeader();
}
}
}
else
{
// ...asio failed to write the message, we could analyse why but
// for now simply assume the connection has died by closing the
// socket. When a future attempt to write to this client fails due
// to the closed socket, it will be tidied up.
std::cout << "[" << id << "] Write Header Fail.\n";
m_socket.close();
}
});
}
// ASYNC - Prime context to write a message body
void WriteBody()
{
// If this function is called, a header has just been sent, and that header
// indicated a body existed for this message. Fill a transmission buffer
// with the body data, and send it!
asio::async_write(m_socket, asio::buffer(m_qMessagesOut.front().body.data(), m_qMessagesOut.front().body.size()),
[this](std::error_code ec, std::size_t length)
{
if (!ec)
{
// Sending was successful, so we are done with the message
// and remove it from the queue
m_qMessagesOut.pop_front();
// If the queue still has messages in it, then issue the task to
// send the next messages' header.
if (!m_qMessagesOut.empty())
{
WriteHeader();
}
}
else
{
// Sending failed, see WriteHeader() equivalent for description :P
std::cout << "[" << id << "] Write Body Fail.\n";
m_socket.close();
}
});
}
// ASYNC - Prime context ready to read a message header
void ReadHeader()
{
// If this function is called, we are expecting asio to wait until it receives
// enough bytes to form a header of a message. We know the headers are a fixed
// size, so allocate a transmission buffer large enough to store it. In fact,
// we will construct the message in a "temporary" message object as it's
// convenient to work with.
asio::async_read(m_socket, asio::buffer(&m_msgTemporaryIn.header, sizeof(message_header<T>)),
[this](std::error_code ec, std::size_t length)
{
if (!ec)
{
// A complete message header has been read, check if this message
// has a body to follow...
if (m_msgTemporaryIn.header.size > 0)
{
// ...it does, so allocate enough space in the messages' body
// vector, and issue asio with the task to read the body.
m_msgTemporaryIn.body.resize(m_msgTemporaryIn.header.size);
ReadBody();
}
else
{
// it doesn't, so add this bodyless message to the connections
// incoming message queue
AddToIncomingMessageQueue();
}
}
else
{
// Reading form the client went wrong, most likely a disconnect
// has occurred. Close the socket and let the system tidy it up later.
std::cout << "[" << id << "] Read Header Fail.\n";
m_socket.close();
}
});
}
// ASYNC - Prime context ready to read a message body
void ReadBody()
{
// If this function is called, a header has already been read, and that header
// request we read a body, The space for that body has already been allocated
// in the temporary message object, so just wait for the bytes to arrive...
asio::async_read(m_socket, asio::buffer(m_msgTemporaryIn.body.data(), m_msgTemporaryIn.body.size()),
[this](std::error_code ec, std::size_t length)
{
if (!ec)
{
// ...and they have! The message is now complete, so add
// the whole message to incoming queue
AddToIncomingMessageQueue();
}
else
{
// As above!
std::cout << "[" << id << "] Read Body Fail.\n";
m_socket.close();
}
});
}
// Once a full message is received, add it to the incoming queue
void AddToIncomingMessageQueue()
{
// Shove it in queue, converting it to an "owned message", by initialising
// with the a shared pointer from this connection object
if (m_nOwnerType == owner::server)
m_qMessagesIn.push_back({ this->shared_from_this(), m_msgTemporaryIn });
else
m_qMessagesIn.push_back({ nullptr, m_msgTemporaryIn });
// We must now prime the asio context to receive the next message. It
// wil just sit and wait for bytes to arrive, and the message construction
// process repeats itself. Clever huh?
ReadHeader();
}
protected:
// Each connection has a unique socket to a remote
asio::ip::tcp::socket m_socket;
// This context is shared with the whole asio instance
asio::io_context& m_asioContext;
// This queue holds all messages to be sent to the remote side
// of this connection
dsqueue<message<T>> m_qMessagesOut;
// This references the incoming queue of the parent object
dsqueue<owned_message<T>>& m_qMessagesIn;
// Incoming messages are constructed asynchronously, so we will
// store the part assembled message here, until it is ready
message<T> m_msgTemporaryIn;
// The "owner" decides how some of the connection behaves
owner m_nOwnerType = owner::server;
uint32_t id = 0;
};
}

View File

@@ -0,0 +1,90 @@
#pragma once
#include "dshark_common.h"
namespace dshark
{
template <typename T>
struct message_header
{
T id{};
uint32_t size = 0;
};
template <typename T>
struct message
{
message_header<T> header{};
std::vector<uint8_t> body;
size_t size() const
{
return body.size();
}
friend std::ostream& operator << (std::ostream& os, const message<T>& msg)
{
os << "ID:" << int(msg.header.id) << " Size:" << msg.header.size;
return os;
}
template<typename DataType>
friend message<T>& operator << (message<T>& msg, const DataType& data)
{
// Check that the type of the data being pushed is trivially copyable
static_assert(std::is_standard_layout<DataType>::value, "Data is too complex to be pushed into vector");
// Cache current size of vector, as this will be the point we insert the data
size_t i = msg.body.size();
// Resize the vector by the size of the data being pushed
msg.body.resize(msg.body.size() + sizeof(DataType));
// Physically copy the data into the newly allocated vector space
std::memcpy(msg.body.data() + i, &data, sizeof(DataType));
// Recalculate the message size
msg.header.size = msg.size();
// Return the target message so it can be "chained"
return msg;
}
template<typename DataType>
friend message<T>& operator >> (message<T>& msg, DataType& data)
{
// Check that the type of the data being pushed is trivially copyable
static_assert(std::is_standard_layout<DataType>::value, "Data is too complex to be pulled from vector");
// Cache the location towards the end of the vector where the pulled data starts
size_t i = msg.body.size() - sizeof(DataType);
// Physically copy the data from the vector into the user variable
std::memcpy(&data, msg.body.data() + i, sizeof(DataType));
// Shrink the vector to remove read bytes, and reset end position
msg.body.resize(i);
// Recalculate the message size
msg.header.size = msg.size();
// Return the target message so it can be "chained"
return msg;
}
};
template <typename T>
class connection;
template <typename T>
struct owned_message
{
std::shared_ptr<connection<T>> remote = nullptr;
message<T> msg;
// Again, a friendly string maker
friend std::ostream& operator<<(std::ostream& os, const owned_message<T>& msg)
{
os << msg.msg;
return os;
}
};
}

View File

@@ -0,0 +1,99 @@
#pragma once
#include "dshark_common.h"
namespace dshark {
template<typename T>
class dsqueue
{
public:
dsqueue() = default;
dsqueue(const dsqueue<T>&) = delete;
virtual ~dsqueue() { clear(); }
public:
const T& front()
{
std::scoped_lock lock(muxQueue);
return deqQueue.front();
}
const T& back()
{
std::scoped_lock lock(muxQueue);
return deqQueue.back();
}
T pop_front()
{
std::scoped_lock lock(muxQueue);
auto t = std::move(deqQueue.front());
deqQueue.pop_front();
return t;
}
// Removes and returns item from back of Queue
T pop_back()
{
std::scoped_lock lock(muxQueue);
auto t = std::move(deqQueue.back());
deqQueue.pop_back();
return t;
}
// Adds an item to back of Queue
void push_back(const T& item)
{
std::scoped_lock lock(muxQueue);
deqQueue.emplace_back(std::move(item));
std::unique_lock<std::mutex> ul(muxBlocking);
cvBlocking.notify_one();
}
// Adds an item to front of Queue
void push_front(const T& item)
{
std::scoped_lock lock(muxQueue);
deqQueue.emplace_front(std::move(item));
std::unique_lock<std::mutex> ul(muxBlocking);
cvBlocking.notify_one();
}
// Returns true if Queue has no items
bool empty()
{
std::scoped_lock lock(muxQueue);
return deqQueue.empty();
}
// Returns number of items in Queue
size_t count()
{
std::scoped_lock lock(muxQueue);
return deqQueue.size();
}
//Clears Queue
void clear()
{
std::scoped_lock lock(muxQueue);
return deqQueue.clear();
}
void wait()
{
while (empty())
{
std::unique_lock<std::mutex> ul(muxBlocking);
cvBlocking.wait(ul);
}
}
protected:
std::mutex muxQueue;
std::deque<T> deqQueue;
std::condition_variable cvBlocking;
std::mutex muxBlocking;
};
}

View File

@@ -0,0 +1,237 @@
#pragma once
#include "dshark_common.h"
#include "dshark_queue.h"
#include "dshark_message.h"
#include "dshark_connection.h"
namespace dshark {
template<typename T>
class server_interface
{
public:
// Create a server, ready to listen on specified port
server_interface(uint16_t port)
: m_asioAcceptor(m_asioContext, asio::ip::tcp::endpoint(asio::ip::tcp::v4(), port))
{
}
virtual ~server_interface()
{
// May as well try and tidy up
Stop();
}
// Starts the server!
bool Start()
{
try
{
// Issue a task to the asio context - This is important
// as it will prime the context with "work", and stop it
// from exiting immediately. Since this is a server, we
// want it primed ready to handle clients trying to
// connect.
WaitForClientConnection();
// Launch the asio context in its own thread
m_threadContext = std::thread([this]() { m_asioContext.run(); });
}
catch (std::exception& e)
{
// Something prohibited the server from listening
std::cerr << "[SERVER] Exception: " << e.what() << "\n";
return false;
}
std::cout << "[SERVER] Started!\n";
return true;
}
// Stops the server!
void Stop()
{
// Request the context to close
m_asioContext.stop();
// Tidy up the context thread
if (m_threadContext.joinable()) m_threadContext.join();
// Inform someone, anybody, if they care...
std::cout << "[SERVER] Stopped!\n";
}
// ASYNC - Instruct asio to wait for connection
void WaitForClientConnection()
{
// Prime context with an instruction to wait until a socket connects. This
// is the purpose of an "acceptor" object. It will provide a unique socket
// for each incoming connection attempt
m_asioAcceptor.async_accept(
[this](std::error_code ec, asio::ip::tcp::socket socket)
{
// Triggered by incoming connection request
if (!ec)
{
// Display some useful(?) information
std::cout << "[SERVER] New Connection: " << socket.remote_endpoint() << "\n";
// Create a new connection to handle this client
std::shared_ptr<connection<T>> newconn =
std::make_shared<connection<T>>(connection<T>::owner::server,
m_asioContext, std::move(socket), m_qMessagesIn);
// Give the user server a chance to deny connection
if (OnClientConnect(newconn))
{
// Connection allowed, so add to container of new connections
m_deqConnections.push_back(std::move(newconn));
// And very important! Issue a task to the connection's
// asio context to sit and wait for bytes to arrive!
m_deqConnections.back()->ConnectToClient(nIDCounter++);
std::cout << "[" << m_deqConnections.back()->GetID() << "] Connection Approved\n";
}
else
{
std::cout << "[-----] Connection Denied\n";
// Connection will go out of scope with no pending tasks, so will
// get destroyed automagically due to the wonder of smart pointers
}
}
else
{
// Error has occurred during acceptance
std::cout << "[SERVER] New Connection Error: " << ec.message() << "\n";
}
// Prime the asio context with more work - again simply wait for
// another connection...
WaitForClientConnection();
});
}
// Send a message to a specific client
void MessageClient(std::shared_ptr<connection<T>> client, const message<T>& msg)
{
// Check client is legitimate...
if (client && client->IsConnected())
{
// ...and post the message via the connection
client->Send(msg);
}
else
{
// If we cant communicate with client then we may as
// well remove the client - let the server know, it may
// be tracking it somehow
OnClientDisconnect(client);
// Off you go now, bye bye!
client.reset();
// Then physically remove it from the container
m_deqConnections.erase(
std::remove(m_deqConnections.begin(), m_deqConnections.end(), client), m_deqConnections.end());
}
}
// Send message to all clients
void MessageAllClients(const message<T>& msg, std::shared_ptr<connection<T>> pIgnoreClient = nullptr)
{
bool bInvalidClientExists = false;
// Iterate through all clients in container
for (auto& client : m_deqConnections)
{
// Check client is connected...
if (client && client->IsConnected())
{
// ..it is!
if (client != pIgnoreClient)
client->Send(msg);
}
else
{
// The client couldnt be contacted, so assume it has
// disconnected.
OnClientDisconnect(client);
client.reset();
// Set this flag to then remove dead clients from container
bInvalidClientExists = true;
}
}
// Remove dead clients, all in one go - this way, we dont invalidate the
// container as we iterated through it.
if (bInvalidClientExists)
m_deqConnections.erase(
std::remove(m_deqConnections.begin(), m_deqConnections.end(), nullptr), m_deqConnections.end());
}
// Force server to respond to incoming messages
void Update(size_t nMaxMessages = -1, bool bWait = false)
{
if (bWait) m_qMessagesIn.wait();
// Process as many messages as you can up to the value
// specified
size_t nMessageCount = 0;
while (nMessageCount < nMaxMessages && !m_qMessagesIn.empty())
{
// Grab the front message
auto msg = m_qMessagesIn.pop_front();
// Pass to message handler
OnMessage(msg.remote, msg.msg);
nMessageCount++;
}
}
protected:
// This server class should override thse functions to implement
// customised functionality
// Called when a client connects, you can veto the connection by returning false
virtual bool OnClientConnect(std::shared_ptr<connection<T>> client)
{
return false;
}
// Called when a client appears to have disconnected
virtual void OnClientDisconnect(std::shared_ptr<connection<T>> client)
{
}
// Called when a message arrives
virtual void OnMessage(std::shared_ptr<connection<T>> client, message<T>& msg)
{
}
protected:
// Thread Safe Queue for incoming message packets
dsqueue<owned_message<T>> m_qMessagesIn;
// Container of active validated connections
std::deque<std::shared_ptr<connection<T>>> m_deqConnections;
// Order of declaration is important - it is also the order of initialisation
asio::io_context m_asioContext;
std::thread m_threadContext;
// These things need an asio context
asio::ip::tcp::acceptor m_asioAcceptor; // Handles new incoming connection attempts...
// Clients will be identified in the "wider system" via an ID
uint32_t nIDCounter = 10000;
};
}

View File

@@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.13.4)
add_library(run_module run_module.h run_module.c dshark_driver_module.c)
target_link_libraries(run_module iree_base_base
iree_hal_hal
iree_hal_cuda_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
#iree_runtime_runtime
iree_hal_local_loaders_system_library_loader
iree_hal_local_sync_driver
)
set_target_properties(run_module PROPERTIES LINKER_LANGUAGE C)

View File

@@ -23,9 +23,9 @@
#include "dshark_driver_module.c"
#include "iree/base/tracing.h"
#include "iree/hal/cuda/api.h"
#include "iree/hal/cuda/cuda_driver.c"
//#include "iree/hal/cuda/cuda_driver.c"
#include "iree/hal/cuda/cuda_device.h"
#include "iree/hal/cuda/cuda_device.c"
//#include "iree/hal/cuda/cuda_device.c"
#include "iree/base/internal/file_io.h"
iree_status_t create_sample_device(iree_allocator_t host_allocator,
@@ -116,7 +116,7 @@ iree_status_t Run(char* module_file, int index) {
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
@@ -127,7 +127,7 @@ iree_status_t Run(char* module_file, int index) {
},
iree_make_const_byte_span(kFloat4, sizeof(kFloat4)), &arg0_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
iree_hal_cuda_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
iree_hal_device_allocator(device), shape, IREE_ARRAYSIZE(shape),
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
@@ -196,7 +196,7 @@ iree_status_t Run(char* module_file, int index) {
return iree_ok_status();
}
int run_module(char filename[], int index) {
int run_module(char* filename, int index) {
// fread command
//iree_sample_state_t* sample_state = setup_sample();
//iree_program_state_t* program_state = load_program(sample_state, vmfb_data, vmfb_data_length);

View File

@@ -0,0 +1,4 @@
#pragma once
#include "run_module.c"
int run_module(char* filename[], int index);

View File

@@ -0,0 +1,11 @@
cmake_minimum_required(VERSION 3.13.4)
project(dshark_server CXX C)
add_executable(dshark_server dshark_server.cpp)
target_include_directories(dshark_server PRIVATE ../dshark_network_lib ../dshark_run_module)
target_link_directories(dshark_server PRIVATE ../dshark_network_lib ../dshark_run_module)
target_link_libraries(dshark_server PRIVATE dshark_network_lib pthread run_module)
set_target_properties(dshark_server PROPERTIES LINKER_LANGUAGE CXX)

View File

@@ -0,0 +1,101 @@
#include "dshark_common.h"
#include "dshark_message.h"
#include "dshark_client.h"
#include "dshark_server.h"
#include "run_module.h"
#include <fstream>
enum class dSHARKMessageType : uint32_t
{
EvaluateBinary,
ServerAccept,
ServerDeny,
ServerPing,
MessageAll,
ServerMessage,
};
class CustomServer : public dshark::server_interface<dSHARKMessageType>
{
public:
CustomServer(uint16_t nPort) : dshark::server_interface<dSHARKMessageType>(nPort)
{
}
protected:
virtual bool OnClientConnect(std::shared_ptr<dshark::connection<dSHARKMessageType>> client)
{
dshark::message<dSHARKMessageType> msg;
msg.header.id = dSHARKMessageType::ServerAccept;
client->Send(msg);
return true;
}
// Called when a client appears to have disconnected
virtual void OnClientDisconnect(std::shared_ptr<dshark::connection<dSHARKMessageType>> client)
{
std::cout << "Removing client [" << client->GetID() << "]\n";
}
// Called when a message arrives
virtual void OnMessage(std::shared_ptr<dshark::connection<dSHARKMessageType>> client, dshark::message<dSHARKMessageType>& msg)
{
switch (msg.header.id)
{
case dSHARKMessageType::EvaluateBinary:
{
std::cout << "[" << client->GetID() << "]: evaluate binary\n";
auto myfile = std::fstream("output/file.vmfb", std::ios::out | std::ios::binary);
myfile.write((char*)&msg.body[0], 2048);
myfile.close();
char* f_ = "output/file.vmfb";
run_module(f_, 0);
client->Send(msg);
}
break;
case dSHARKMessageType::ServerPing:
{
std::cout << "[" << client->GetID() << "]: Server Ping\n";
// Simply bounce message back to client
client->Send(msg);
}
break;
case dSHARKMessageType::MessageAll:
{
std::cout << "[" << client->GetID() << "]: Message All\n";
// Construct a new message and send it to all clients
dshark::message<dSHARKMessageType> msg;
msg.header.id = dSHARKMessageType::ServerMessage;
msg << client->GetID();
MessageAllClients(msg, client);
}
break;
}
}
};
int main()
{
CustomServer server(60000);
server.Start();
while (1)
{
server.Update(-1, true);
}
return 0;
}