Replaced custom networking with Boost::Asio.

This commit is contained in:
Robert J. Hansen
2019-09-02 20:58:58 -04:00
parent 9bc12d4369
commit 299432fe28
6 changed files with 157 additions and 376 deletions

View File

@@ -1,27 +1,14 @@
cmake_minimum_required(VERSION 3.4.0)
find_package (PythonInterp REQUIRED)
project(nsrlsvr)
set(VERSION "1.6.1")
set(VERSION "1.7.0")
set(PACKAGE_VERSION ${VERSION})
set(Boost_USE_STATIC_LIBS OFF)
set(Boost_USE_MULTITHREADED ON)
set(Boost_USE_STATIC_RUNTIME OFF)
find_package(Boost 1.60.0 REQUIRED COMPONENTS program_options)
find_package(PythonInterp REQUIRED)
find_package(Threads REQUIRED)
find_package(Boost 1.66.0 REQUIRED COMPONENTS program_options system)
include(GNUInstallDirs)
set(PKGDATADIR ${CMAKE_INSTALL_FULL_DATADIR}/nsrlsvr)
add_subdirectory(src)
add_subdirectory(man1)
set(CPACK_GENERATOR "RPM")
set(CPACK_RPM_PACKAGE_SUMMARY "A server for forensics triage using NIST's NSRL RDS.")
set(CPACK_RPM_PACKAGE_NAME "nsrlsvr")
set(CPACK_PACKAGE_NAME "nsrlsvr")
set(CPACK_RPM_PACKAGE_VERSION ${PACKAGE_VERSION})
set(CPACK_PACKAGE_VERSION ${PACKAGE_VERSION})
set(CPACK_PACKAGE_RELEASE "1")
set(CPACK_RPM_PACKAGE_LICENSE "ISC")
set(CPACK_RPM_PACKAGE_REQUIRES "boost-program-options >= 1.60.0")
set(CPACK_PACKAGE_DESCRIPTION_FILE "${CMAKE_CURRENT_SOURCE_DIR}/description.txt")
set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}-${CPACK_PACKAGE_RELEASE}.${CMAKE_SYSTEM_PROCESSOR}")
set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION /usr/share/man /usr/share/man/man1)
include(CPack)
add_subdirectory(man1)

View File

@@ -1,7 +1,6 @@
set(CMAKE_BUILD_TYPE Debug)
include_directories(${Boost_INCLUDE_DIRS})
add_executable(nsrlsvr handler.cc main.cc to_pair64.cc)
target_link_libraries(nsrlsvr ${Boost_PROGRAM_OPTIONS_LIBRARY})
target_link_libraries(nsrlsvr ${Boost_LIBRARIES} Threads::Threads)
set_property(TARGET nsrlsvr PROPERTY CXX_STANDARD 14)
set_property(TARGET nsrlsvr PROPERTY CXX_STANDARD_REQUIRED true)
add_definitions(-DPKGDATADIR="${PKGDATADIR}")

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2015-2016, Robert J. Hansen <rjh@sixdemonbag.org>
Copyright (c) 2015-2019, Robert J. Hansen <rjh@sixdemonbag.org>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
@@ -15,243 +15,150 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#include "main.h"
#include <iostream>
#include <algorithm>
#include <array>
#include <exception>
#include <inttypes.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <syslog.h>
#include <vector>
/* Additional defines necessary on Linux: */
#ifdef __linux__
#include <unistd.h> // because Fedora has lately taken to being weird
#endif
/* And Apple requires sys/uio.h for some reason. */
#ifdef __APPLE__
#include <sys/uio.h>
#include <unistd.h>
#endif
#include <sstream>
#include <boost/tokenizer.hpp>
using std::string;
using std::find;
using std::find_if;
using std::transform;
using std::vector;
using std::remove;
using std::exception;
using std::binary_search;
using std::pair;
using std::copy;
using std::back_inserter;
using std::array;
using std::fill;
using std::getline;
using std::stringstream;
using std::to_string;
using boost::asio::ip::tcp;
using boost::char_separator;
using boost::tokenizer;
// defined in main.cc
extern const vector<pair64>& hashes;
namespace {
class NetworkTimeout : public std::exception {
public:
virtual const char* what() const noexcept { return "network timeout"; }
enum class Command {
Version = 0,
Bye = 1,
Status = 2,
Query = 3,
Upshift = 4,
Downshift = 5,
Unknown = 6
};
class NetworkError : public std::exception {
public:
virtual const char* what() const noexcept { return "network error"; }
};
string
read_line(const int32_t sockfd, int timeout = 15)
{
static vector<char> buffer;
static array<char, 8192> rdbuf;
struct pollfd pfd;
struct timeval start;
struct timeval now;
time_t elapsed_time;
ssize_t bytes_received;
constexpr auto MEGABYTE = 1 << 20;
if (buffer.capacity() < MEGABYTE)
buffer.reserve(MEGABYTE);
// Step zero: check to see if there's already a string in the
// input queue awaiting a read.
auto iter = find(buffer.begin(), buffer.end(), '\n');
if (iter != buffer.end()) {
vector<char> newbuf(buffer.begin(), iter);
buffer.erase(buffer.begin(), iter + 1);
newbuf.erase(remove(newbuf.begin(), newbuf.end(), '\r'), newbuf.end());
return string(newbuf.begin(), newbuf.end());
}
// Per POSIX, this can only err if we access invalid memory.
// Since start is always valid, there's no problem here and
// no need to check gettimeofday's return code.
gettimeofday(&start, nullptr);
now.tv_sec = start.tv_sec;
now.tv_usec = start.tv_usec;
elapsed_time = now.tv_sec - start.tv_sec;
while ((elapsed_time < timeout)) {
pfd.fd = sockfd;
pfd.events = POLLIN;
pfd.revents = 0;
fill(rdbuf.begin(), rdbuf.end(), 0);
if ((buffer.size() > MEGABYTE) || (-1 == poll(&pfd, 1, 1000)) || (pfd.revents & POLLERR) || (pfd.revents & POLLHUP) || (pfd.revents & POLLNVAL)) {
log(LogLevel::ALERT, "network error: ");
if (buffer.size() > MEGABYTE) {
log(LogLevel::ALERT, "buffer too large");
}
if (pfd.revents & POLLERR) {
log(LogLevel::ALERT, "POLLERR");
}
if (pfd.revents & POLLHUP) {
log(LogLevel::ALERT, "POLLHUP");
}
if (pfd.revents & POLLNVAL) {
log(LogLevel::ALERT, "POLLNVAL");
}
throw NetworkError();
}
if (pfd.revents & POLLIN) {
bytes_received = recvfrom(sockfd, &rdbuf[0], rdbuf.size(), 0, NULL, 0);
if (0 == bytes_received) {
log(LogLevel::ALERT, "read_line read on closed socket");
throw NetworkError();
}
copy(rdbuf.begin(), rdbuf.begin() + bytes_received,
back_inserter(buffer));
}
iter = find(buffer.begin(), buffer.end(), '\n');
if (iter != buffer.end()) {
string line(buffer.begin(), iter);
if (line.at(line.size() - 1) == '\r') {
line = string(line.begin(), line.end() - 1);
}
buffer.erase(buffer.begin(), iter + 1);
return line;
}
gettimeofday(&now, nullptr);
elapsed_time = now.tv_sec - start.tv_sec;
}
throw NetworkTimeout();
}
void write_line(const int32_t sockfd, string&& line)
{
string output = line + "\r\n";
const char* msg = output.c_str();
if (-1 == send(sockfd, reinterpret_cast<const void*>(msg), output.size(), 0))
throw NetworkError();
}
auto tokenize(string&& line, char character = ' ')
auto tokenize(const string& line)
{
vector<string> rv;
transform(line.begin(), line.end(), line.begin(), toupper);
auto begin = find_if(line.cbegin(), line.cend(), [&](auto x) { return x != character; });
auto end = (begin != line.cend()) ? find(begin + 1, line.cend(), character)
: line.cend();
while (begin != line.cend()) {
rv.emplace_back(string{ begin, end });
if (end == line.cend()) {
break;
}
begin = find_if(end + 1, line.cend(), [&](auto x) { return x != character; });
end = (begin != line.cend()) ? find(begin + 1, line.cend(), character)
: line.cend();
char_separator<char> sep(" ");
tokenizer<char_separator<char>> tokens(line, sep);
for (const auto& t : tokens) {
rv.emplace_back(t);
}
return rv;
}
string
generate_response(vector<string>::const_iterator begin,
vector<string>::const_iterator end)
bool is_present_in_hashes(const string& hash)
{
string rv = "OK ";
return binary_search(hashes.cbegin(), hashes.cend(), to_pair64(hash));
}
for (auto i = begin; i != end; ++i) {
bool present = binary_search(hashes.cbegin(), hashes.cend(), to_pair64(*i));
auto getCommand(const string& cmdstring)
{
string localcmd = "";
transform(cmdstring.cbegin(), cmdstring.cend(), back_inserter(localcmd), ::toupper);
rv += present ? "1" : "0";
}
return rv;
auto cmd = Command::Unknown;
if (localcmd == "VERSION:")
cmd = Command::Version;
else if (localcmd == "BYE")
cmd = Command::Bye;
else if (localcmd == "STATUS")
cmd = Command::Status;
else if (localcmd == "QUERY")
cmd = Command::Query;
else if (localcmd == "UPSHIFT")
cmd = Command::Upshift;
else if (localcmd == "DOWNSHIFT")
cmd = Command::Downshift;
return cmd;
}
}
void handle_client(const int32_t fd)
void handle_client(tcp::iostream& stream)
{
enum class Command {
Version = 0,
Bye = 1,
Status = 2,
Query = 3,
Upshift = 4,
Downshift = 5,
Unknown = 6
};
const string ipaddr = stream.socket().remote_endpoint().address().to_string();
unsigned long long queries = 0;
try {
auto commands = tokenize(read_line(fd));
while (true) {
auto cmdstring = commands.at(0);
Command cmd = Command::Unknown;
while (stream) {
string line;
getline(stream, line);
if (line.size() == 0) return;
if (cmdstring == "VERSION:")
cmd = Command::Version;
else if (cmdstring == "BYE")
cmd = Command::Bye;
else if (cmdstring == "STATUS")
cmd = Command::Status;
else if (cmdstring == "QUERY")
cmd = Command::Query;
else if (cmdstring == "UPSHIFT")
cmd = Command::Upshift;
else if (cmdstring == "DOWNSHIFT")
cmd = Command::Downshift;
// trim leading/following whitespace
auto end_ws = line.find_last_not_of("\t\n\v\f\r ");
if (end_ws != string::npos) {
line.erase(end_ws + 1);
}
auto front_ws = line.find_first_not_of("\t\n\v\f\r ");
if (front_ws > 0) {
line.erase(0, front_ws);
}
switch (cmd) {
auto commands = tokenize(line);
switch (getCommand(commands.at(0))) {
case Command::Version:
write_line(fd, "OK");
stream << "OK\r\n";
break;
case Command::Bye:
return;
case Command::Status:
write_line(fd, "NOT SUPPORTED");
stream << "NOT SUPPORTED\r\n";
break;
case Command::Query:
write_line(fd,
generate_response(commands.begin() + 1, commands.end()));
{
stringstream rv;
rv << "OK ";
for (size_t idx = 1 ; idx < commands.size(); ++idx)
rv << (is_present_in_hashes(commands.at(idx)) ? "1" : "0");
rv << "\r\n";
queries += (commands.size() - 1);
stream << rv.str();
break;
}
case Command::Upshift:
write_line(fd, "NOT OK");
stream << "NOT OK\r\n";
break;
case Command::Downshift:
write_line(fd, "NOT OK");
stream << "NOT OK\r\n";
break;
case Command::Unknown:
write_line(fd, "NOT OK");
stream << "NOT OK\r\n";
return;
}
commands = tokenize(read_line(fd));
}
} catch (std::exception&) {
// Do nothing: just end the function, which will drop the connection.
}
catch (std::exception& e) {
log(LogLevel::ALERT, string("Error: ") + e.what());
// swallow the exception: we'll close the connection
// automagically on exit
//
// fall-through here to function returb
}
stringstream status_msg;
status_msg << ipaddr << " closed session after " << queries
<< " queries";
log(LogLevel::ALERT, status_msg.str());
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2015-2016, Robert J. Hansen <rjh@sixdemonbag.org>
Copyright (c) 2015-2019, Robert J. Hansen <rjh@sixdemonbag.org>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
@@ -16,8 +16,8 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#include "main.h"
#include <algorithm>
#include <arpa/inet.h>
#include <boost/program_options.hpp>
#include <boost/asio.hpp>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@@ -33,11 +33,6 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#include <unistd.h>
#include <vector>
#ifdef __FreeBSD__
#include <netinet/in.h>
#include <sys/socket.h>
#endif
using std::string;
using std::transform;
using std::ifstream;
@@ -57,6 +52,7 @@ using boost::program_options::store;
using boost::program_options::parse_command_line;
using boost::program_options::notify;
using boost::program_options::value;
using boost::asio::ip::tcp;
namespace {
vector<pair64> hash_set;
@@ -131,7 +127,7 @@ void load_hashes()
hash_count += 1;
if (0 == hash_count % 1000000) {
string howmany{ to_string(hash_count / 1000000) };
log(LogLevel::ALERT, "loaded " + howmany + " million hashes");
log(LogLevel::INFO, "loaded " + howmany + " million hashes");
}
} catch (std::bad_alloc&) {
log(LogLevel::ALERT, "couldn't allocate enough memory");
@@ -147,16 +143,16 @@ void load_hashes()
if (hash_set.size() > 1) {
log(LogLevel::INFO, "ensuring no duplicates");
pair64 foo{ hash_set.at(0) };
for (auto iter = (hash_set.cbegin() + 1); iter != hash_set.cend(); ++iter) {
if (foo == *iter) {
if (*(iter - 1) == *iter) {
log(LogLevel::ALERT, "hash file contains duplicates -- "
"shutting down!");
exit(EXIT_FAILURE);
}
foo = *iter;
}
}
log(LogLevel::INFO, "successfully loaded hashes");
}
/** Converts this process into a well-behaved UNIX daemon.*/
@@ -191,36 +187,6 @@ void daemonize()
close(STDERR_FILENO);
}
/** Creates a server socket to listen for client connections. */
auto make_socket()
{
/* If anything in here is surprising, please check the standard
literature to make sure you understand TCP/IP. */
sockaddr_in server;
memset(static_cast<void*>(&server), 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = htonl(INADDR_ANY);
server.sin_port = htons(port);
const auto sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (sock < 0) {
log(LogLevel::WARN, "couldn't create a server socket");
exit(EXIT_FAILURE);
}
if (0 > bind(sock, reinterpret_cast<sockaddr*>(&server), sizeof(server))) {
log(LogLevel::WARN, "couldn't bind to port");
exit(EXIT_FAILURE);
}
if (0 > listen(sock, 20)) {
log(LogLevel::WARN, "couldn't listen for clients");
exit(EXIT_FAILURE);
}
log(LogLevel::INFO, "ready for clients");
return sock;
}
/** Parse command-line options.
@param argc argc from main()
@param argv argv from main()
@@ -329,6 +295,8 @@ void log(const LogLevel level, const string&& msg)
*/
int main(int argc, char* argv[])
{
static_assert(sizeof(unsigned long long) == 8,
"wait, what kind of system is this?");
parse_options(argc, argv);
if (!dry_run)
@@ -336,12 +304,6 @@ int main(int argc, char* argv[])
load_hashes();
int32_t client_sock{ 0 };
int32_t svr_sock{ make_socket() };
sockaddr_in client;
sockaddr* client_addr = reinterpret_cast<sockaddr*>(&client);
socklen_t client_length{ sizeof(client) };
// The following line helps avoid zombie processes. Normally parents
// need to reap their children in order to prevent zombie processes;
// if SIGCHLD is set to SIG_IGN, though, the processes can terminate
@@ -350,89 +312,26 @@ int main(int argc, char* argv[])
if (dry_run)
return EXIT_SUCCESS;
boost::asio::io_service io_service;
tcp::endpoint endpoint(tcp::v4(), port);
tcp::acceptor acceptor(io_service, endpoint);
while (true) {
if (0 > (client_sock = accept(svr_sock, client_addr, &client_length))) {
log(LogLevel::WARN, "could not accept connection");
switch (errno) {
case EAGAIN:
log(LogLevel::WARN, "-- EAGAIN");
break;
case ECONNABORTED:
log(LogLevel::WARN, "-- ECONNABORTED");
break;
case EINTR:
log(LogLevel::WARN, "-- EINTR");
break;
case EINVAL:
log(LogLevel::WARN, "-- EINVAL");
break;
case EMFILE:
log(LogLevel::WARN, "-- EMFILE");
break;
case ENFILE:
log(LogLevel::WARN, "-- ENFILE");
break;
case ENOTSOCK:
log(LogLevel::WARN, "-- ENOTSOCK");
break;
case EOPNOTSUPP:
log(LogLevel::WARN, "-- EOPNOTSUPP");
break;
case ENOBUFS:
log(LogLevel::WARN, "-- ENOBUFS");
break;
case ENOMEM:
log(LogLevel::WARN, "-- ENOMEM");
break;
case EPROTO:
log(LogLevel::WARN, "-- EPROTO");
break;
default:
log(LogLevel::WARN, "-- EUNKNOWN");
break;
}
tcp::iostream stream;
boost::system::error_code error;
acceptor.accept(*stream.rdbuf(), error);
if (error) {
continue;
}
string ipaddr{ inet_ntoa(client.sin_addr) };
string ipaddr = stream.socket().remote_endpoint().address().to_string();
log(LogLevel::ALERT, string("accepted a client: ") + ipaddr);
if (0 == fork()) {
log(LogLevel::ALERT, "calling handle_client");
handle_client(client_sock);
if (-1 == close(client_sock)) {
log(LogLevel::WARN, string("Could not close client: ") + ipaddr);
switch (errno) {
case EBADF:
log(LogLevel::WARN, "-- EBADF");
break;
case EINTR:
log(LogLevel::WARN, "-- EINTR");
break;
case EIO:
log(LogLevel::WARN, "-- EIO");
break;
}
} else {
log(LogLevel::ALERT, string("closed client ") + ipaddr);
}
handle_client(stream);
return 0;
} else {
if (-1 == close(client_sock)) {
log(LogLevel::WARN, string("Parent could not close client: ") + ipaddr);
switch (errno) {
case EBADF:
log(LogLevel::WARN, "-- EBADF");
break;
case EINTR:
log(LogLevel::WARN, "-- EINTR");
break;
case EIO:
log(LogLevel::WARN, "-- EIO");
break;
}
}
}
}

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2015-2016, Robert J. Hansen <rjh@sixdemonbag.org>
Copyright (c) 2015-2019, Robert J. Hansen <rjh@sixdemonbag.org>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
@@ -17,13 +17,15 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#ifndef MAIN_H
#define MAIN_H
#include <string>
#include <sys/types.h>
#include <syslog.h>
#include <utility>
#include <string>
#include <syslog.h>
#include <cstdint>
#include <boost/asio.hpp>
using pair64 = std::pair<uint64_t, uint64_t>;
// Note: C++11 guarantees an unsigned long long will be at least 64 bits.
// A compile-time assert in main.cc guarantees it will ONLY be 64 bits.
using pair64 = std::pair<unsigned long long, unsigned long long>;
enum class LogLevel
{
@@ -36,9 +38,9 @@ enum class LogLevel
};
void log(const LogLevel, const std::string&&);
void handle_client(const int32_t);
pair64 to_pair64(std::string);
std::string from_pair64(pair64);
void handle_client(boost::asio::ip::tcp::iostream& stream);
pair64 to_pair64(const std::string&);
std::string from_pair64(const pair64&);
bool operator<(const pair64& lhs, const pair64& rhs);
bool operator==(const pair64& lhs, const pair64& rhs);
bool operator>(const pair64& lhs, const pair64& rhs);

View File

@@ -1,5 +1,5 @@
/*
Copyright (c) 2015-2016, Robert J. Hansen <rjh@sixdemonbag.org>
Copyright (c) 2015-2019, Robert J. Hansen <rjh@sixdemonbag.org>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
@@ -16,68 +16,55 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#include "main.h"
#include <algorithm>
#include <ctype.h>
#include <iostream>
#include <stdexcept>
#include <regex>
#include <sstream>
#include <iomanip>
using std::pair;
using std::string;
using std::transform;
using std::make_pair;
using std::regex;
using std::regex_match;
using std::invalid_argument;
using std::stringstream;
using std::setw;
using std::setfill;
using std::hex;
string
from_pair64(pair64 input)
from_pair64(const pair64& input)
{
static string hexadecimal{ "0123456789ABCDEF" };
string left = "", right = "";
uint64_t first = input.first;
uint64_t second = input.second;
for (int i = 0; i < 16; ++i) {
left = hexadecimal[first & 0x0F] + left;
right = hexadecimal[second & 0x0F] + right;
first >>= 4;
second >>= 4;
}
return left + right;
stringstream stream;
stream << setfill('0')
<< setw(sizeof(unsigned long long) * 2)
<< hex
<< input.first
<< input.second;
return string(stream.str());
}
pair64
to_pair64(string input)
to_pair64(const string& input)
{
uint64_t left{ 0 };
uint64_t right{ 0 };
uint8_t val1{ 0 };
uint8_t val2{ 0 };
size_t index{ 0 };
char ch1{ 0 };
char ch2{ 0 };
static const regex md5_re{ "^[A-Fa-f0-9]{32}$" };
transform(input.begin(), input.end(), input.begin(), ::tolower);
if (!regex_match(input.cbegin(), input.cend(), md5_re))
throw invalid_argument("not a hash");
auto first = string(input.cbegin(), input.cbegin() + 16);
auto second = string(input.cbegin() + 16, input.cend());
auto left = std::strtoull(first.c_str(), nullptr, 16);
auto right = std::strtoull(second.c_str(), nullptr, 16);
for (index = 0; index < 16; index += 1) {
ch1 = input[index];
ch2 = input[index + 16];
val1 = (ch1 >= '0' and ch1 <= '9') ? static_cast<uint8_t>(ch1 - '0')
: static_cast<uint8_t>(ch1 - 'a') + 10;
val2 = (ch2 >= '0' and ch2 <= '9') ? static_cast<uint8_t>(ch2 - '0')
: static_cast<uint8_t>(ch2 - 'a') + 10;
left = (left << 4) + val1;
right = (right << 4) + val2;
}
return make_pair(left, right);
}
bool operator<(const pair64& lhs, const pair64& rhs)
{
return (lhs.first < rhs.first)
? true
: (lhs.first == rhs.first and lhs.second < rhs.second) ? true
: false;
return (lhs.first < rhs.first) or
(lhs.first == rhs.first and lhs.second < rhs.second);
}
bool operator==(const pair64& lhs, const pair64& rhs)
@@ -87,5 +74,5 @@ bool operator==(const pair64& lhs, const pair64& rhs)
bool operator>(const pair64& lhs, const pair64& rhs)
{
return ((not(lhs < rhs)) and (not(lhs == rhs)));
return ((!(lhs < rhs)) and (!(lhs == rhs)));
}