This commit is contained in:
2025-02-15 09:53:11 +01:00
commit 0922b77fad
44 changed files with 3384 additions and 0 deletions

24
networking/CMakeLists.txt Normal file
View File

@@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 3.15)
add_library(networking
include/Server.hpp
src/Server.cpp
include/Client.hpp
src/Client.cpp
include/Helpers.hpp
src/Helpers.cpp
src/AsyncSslTransport.cpp
include/AsyncSslTransport.hpp
src/AsyncSslClientTransport.cpp
include/AsyncSslClientTransport.hpp
include/AsyncSslServerTransport.hpp
src/AsyncSslServerTransport.cpp
)
target_include_directories(networking PUBLIC include)
find_package(OpenSSL REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_check_modules(FUSE REQUIRED IMPORTED_TARGET fuse)
target_link_libraries(networking PRIVATE utils OpenSSL::SSL OpenSSL::Crypto PkgConfig::FUSE)

View File

@@ -0,0 +1,34 @@
//
// Created by Stepan Usatiuk on 14.12.2024.
//
#ifndef ASYNCMESSAGECLIENT_HPP
#define ASYNCMESSAGECLIENT_HPP
#include "AsyncSslTransport.hpp"
class AsyncSslClientTransport : public AsyncSslTransport {
public:
AsyncSslClientTransport(SSL_CTX* ssl_ctx, int fd) : AsyncSslTransport(ssl_ctx, fd) {}
using SharedMsgPromiseT = std::shared_ptr<std::promise<std::shared_ptr<MsgWrapper>>>;
std::future<std::shared_ptr<MsgWrapper>> send_msg(std::vector<uint8_t> message);
std::vector<uint8_t> send_msg_and_wait(std::vector<uint8_t> message);
protected:
void handle_message(std::shared_ptr<MsgWrapper> msg) override;
void handle_fail() override {
stop();
std::exit(EXIT_FAILURE);
}
void before_entry() override;
private:
std::unordered_map<decltype(MsgWrapper::id), SharedMsgPromiseT> _promises;
uint64_t _msg_id = 0;
std::mutex _promises_mutex;
};
#endif // ASYNCMESSAGECLIENT_HPP

View File

@@ -0,0 +1,32 @@
//
// Created by Stepan Usatiuk on 15.12.2024.
//
#ifndef SERVERMESSAGEPUMP_HPP
#define SERVERMESSAGEPUMP_HPP
#include "AsyncSslTransport.hpp"
class AsyncSslServerTransport : public AsyncSslTransport {
public:
AsyncSslServerTransport(SSL_CTX* ssl_ctx, int fd, int client_id) : AsyncSslTransport(ssl_ctx, fd), _client_id(client_id) {}
// Null if finished
std::shared_ptr<MsgWrapper> get_msg();
protected:
void handle_message(std::shared_ptr<MsgWrapper> msg) override;
void handle_fail() override;
void before_entry() override;
private:
int _client_id;
std::deque<std::shared_ptr<MsgWrapper>> _msgs;
std::mutex _msgs_mutex;
std::condition_variable _msgs_condition;
};
#endif // SERVERMESSAGEPUMP_HPP

View File

@@ -0,0 +1,62 @@
//
// Created by Stepan Usatiuk on 14.12.2024.
//
#ifndef MESSAGEPUMP_HPP
#define MESSAGEPUMP_HPP
#include <condition_variable>
#include <deque>
#include <future>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <openssl/ssl.h>
#include "Helpers.hpp"
class AsyncSslTransport {
public:
AsyncSslTransport(SSL_CTX* ssl_ctx, int fd);
virtual ~AsyncSslTransport() = 0;
void run();
void send_message(std::shared_ptr<MsgWrapper> msg);
bool is_failed() const { return _failed; }
bool is_stopped() const { return _stopped; }
void stop();
protected:
virtual void handle_message(std::shared_ptr<MsgWrapper> msg) = 0;
virtual void handle_fail() = 0;
virtual void before_entry() {}
std::unique_ptr<SSL, decltype(&SSL_free)> _ssl{nullptr, &SSL_free};
int _fd;
private:
void thread_entry();
std::atomic<bool> _stopped = 0;
std::mutex _stopped_mutex;
std::condition_variable _stopped_condition;
std::thread _thread;
std::mutex _to_send_mutex;
int _to_send_notif_pipe[2];
std::deque<std::shared_ptr<MsgWrapper>> _to_send;
std::atomic<bool> _failed;
AsyncSslTransport(const AsyncSslTransport& other) = delete;
AsyncSslTransport(AsyncSslTransport&& other) noexcept = delete;
AsyncSslTransport& operator=(const AsyncSslTransport& other) = delete;
AsyncSslTransport& operator=(AsyncSslTransport&& other) noexcept = delete;
};
#endif // MESSAGEPUMP_HPP

View File

@@ -0,0 +1,43 @@
//
// Created by stepus53 on 11.12.24.
//
#ifndef TCPCLIENT_HPP
#define TCPCLIENT_HPP
#include <cstdint>
#include <mutex>
#include <string>
#include <memory>
#include <optional>
#include <openssl/ssl.h>
#include "AsyncSslClientTransport.hpp"
class Client {
public:
Client(uint16_t port, std::string ip, std::string cert_path, std::string key_path);
void run();
AsyncSslClientTransport& transport() { return *_transport; }
protected:
uint16_t _port;
std::string _ip;
std::string _cert_path;
std::string _key_path;
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> _ssl_ctx;
std::unique_ptr<SSL, decltype(&SSL_free)> _ssl{nullptr, &SSL_free};
int _sock;
size_t _msg_id = 0;
private:
std::optional<AsyncSslClientTransport> _transport;
};
#endif // TCPCLIENT_HPP

View File

@@ -0,0 +1,37 @@
#ifndef NETWORKING_HELPERS_H
#define NETWORKING_HELPERS_H
#include <cstdint>
#include <deque>
#include <functional>
#include <vector>
#include "Options.h"
#include "stuff.hpp"
#include <openssl/ssl.h>
using MsgIdType = uint64_t;
struct MsgWrapper {
MsgIdType id;
std::vector<uint8_t> data;
};
struct MsgHeader {
uint64_t id;
uint64_t len;
} __attribute__((packed));
namespace Helpers {
void poll_wait(int fd, bool write, int timeout = checked_cast<int>(Options::get<size_t>("timeout")) * 1000);
bool SSL_write(SSL* ctx, int fd, const std::vector<uint8_t>& buf);
void init_nonblock(int fd);
std::vector<uint8_t> SSL_read_n(SSL* ctx, int fd, size_t n);
MsgWrapper SSL_read_msg(SSL* ctx, int fd);
void SSL_send_msg(SSL* ctx, int fd, const MsgWrapper& buf);
} // namespace Helpers
#endif

View File

@@ -0,0 +1,53 @@
//
// Created by Stepan Usatiuk on 09.12.2024.
//
#ifndef TCPSERVER_HPP
#define TCPSERVER_HPP
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <openssl/ssl.h>
#include "AsyncSslServerTransport.hpp"
#include "Helpers.hpp"
struct ClientCtx {
std::optional<std::string> client_name;
AsyncSslServerTransport transport;
std::mutex ctx_mutex;
};
class Server {
public:
Server(uint16_t port, uint32_t ip, std::string cert_path, std::string key_path);
void run();
protected:
uint16_t _port;
uint32_t _ip;
std::string _cert_path;
std::string _key_path;
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> _ssl_ctx;
void process_req(int conn_fd);
virtual std::vector<uint8_t> handle_message(ClientCtx& client, std::vector<uint8_t> data) = 0;
private:
std::atomic<int> _total_req{0};
std::atomic<int> _req_in_progress{0};
std::mutex _req_in_progress_mutex;
std::condition_variable _req_in_progress_cond;
};
#endif // TCPSERVER_HPP

View File

@@ -0,0 +1,60 @@
//
// Created by Stepan Usatiuk on 14.12.2024.
//
#include "AsyncSslClientTransport.hpp"
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "Logger.h"
std::future<std::shared_ptr<MsgWrapper>> AsyncSslClientTransport::send_msg(std::vector<uint8_t> message) {
auto promise = std::make_shared<std::promise<std::shared_ptr<MsgWrapper>>>();
decltype(_msg_id) id;
{
std::lock_guard lock(_promises_mutex);
id = _msg_id++;
_promises.emplace(id, promise);
}
send_message(std::make_shared<MsgWrapper>(id, std::move(message)));
return promise->get_future();
}
std::vector<uint8_t> AsyncSslClientTransport::send_msg_and_wait(std::vector<uint8_t> message) {
auto future = send_msg(std::move(message));
return future.get()->data;
}
void AsyncSslClientTransport::before_entry() {
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Connecting"; }, Logger::INFO);
int r;
while ((r = SSL_connect(_ssl.get())) <= 0) {
ERR_print_errors_fp(stderr);
int err = SSL_get_error(_ssl.get(), r);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
Helpers::poll_wait(_fd, err == SSL_ERROR_WANT_WRITE);
continue;
}
throw OpenSSLException("SSL_connect() failed");
}
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Connected"; }, Logger::INFO);
}
void AsyncSslClientTransport::handle_message(std::shared_ptr<MsgWrapper> msg) {
std::lock_guard lock(_promises_mutex);
auto future_it = _promises.find(msg->id);
if (future_it == _promises.end()) {
Logger::log(Logger::RemoteFs, "Could not find future for msg with id " + std::to_string(msg->id),
Logger::ERROR);
return;
}
future_it->second->set_value(msg);
_promises.erase(future_it);
}

View File

@@ -0,0 +1,51 @@
//
// Created by Stepan Usatiuk on 15.12.2024.
//
#include "AsyncSslServerTransport.hpp"
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "Logger.h"
void AsyncSslServerTransport::handle_message(std::shared_ptr<MsgWrapper> msg) {
std::lock_guard lock(_msgs_mutex);
_msgs.emplace_back(msg);
_msgs_condition.notify_all();
}
void AsyncSslServerTransport::handle_fail() {
_msgs_condition.notify_all();
}
void AsyncSslServerTransport::before_entry() {
int r;
while ((r = SSL_accept(_ssl.get())) <= 0) {
ERR_print_errors_fp(stderr);
int err = SSL_get_error(_ssl.get(), r);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
Helpers::poll_wait(_fd, err == SSL_ERROR_WANT_WRITE);
continue;
}
throw OpenSSLException("SSL_accept() failed");
}
Logger::log(Logger::RemoteFs, "Client " + std::to_string(_client_id) + " connected\n", Logger::INFO);
}
std::shared_ptr<MsgWrapper> AsyncSslServerTransport::get_msg() {
std::unique_lock lock(_msgs_mutex);
_msgs_condition.wait(lock, [&] { return !_msgs.empty() || is_failed() || is_stopped(); });
if (is_failed())
return nullptr;
auto ret = _msgs.begin();
auto real_ret = *ret;
_msgs.erase(_msgs.begin());
return real_ret;
}

View File

@@ -0,0 +1,249 @@
//
// Created by Stepan Usatiuk on 14.12.2024.
//
#include "AsyncSslTransport.hpp"
#include <fcntl.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <unistd.h>
#include <iomanip>
#include "Logger.h"
#include "Serialize.hpp"
#include "stuff.hpp"
AsyncSslTransport::AsyncSslTransport(SSL_CTX* ssl_ctx, int fd) : _ssl(SSL_new(ssl_ctx), &SSL_free), _fd(fd) {
SSL_set_fd(_ssl.get(), _fd);
pipe(_to_send_notif_pipe);
Helpers::init_nonblock(_fd);
}
AsyncSslTransport::~AsyncSslTransport() {
stop();
_thread.join();
}
void AsyncSslTransport::stop() {
std::unique_lock lock(_stopped_mutex);
_stopped_condition.notify_all();
_stopped = true;
write(_to_send_notif_pipe[1], "1", 1);
}
void AsyncSslTransport::send_message(std::shared_ptr<MsgWrapper> msg) {
std::unique_lock lock(_to_send_mutex);
_to_send.push_back(std::move(msg));
write(_to_send_notif_pipe[1], "1", 1);
}
void AsyncSslTransport::thread_entry() {
try {
before_entry();
bool sending = false;
std::vector<uint8_t> to_send_buf{};
size_t cur_sent = 0;
bool reading_msg = false; // False if reading header, true if message
std::vector<uint8_t> read_buf;
size_t cur_read{};
uint64_t msg_id = 0;
size_t msg_len = sizeof(MsgHeader); // Message length to read if reading message, or header len
read_buf.resize(msg_len);
while (!_stopped) {
if (!sending) {
std::shared_ptr<MsgWrapper> to_send_now{};
std::unique_lock lock(_to_send_mutex);
auto next = _to_send.begin();
if (next != _to_send.end()) {
to_send_now = *next;
_to_send.erase(next);
}
if (to_send_now) {
MsgHeader header{};
header.id = htobe64(to_send_now->id);
header.len = htobe64(to_send_now->data.size());
to_send_buf.resize(sizeof(MsgHeader) + to_send_now->data.size());
memcpy(to_send_buf.data(), &header, sizeof(MsgHeader));
memcpy(to_send_buf.data() + sizeof(MsgHeader), to_send_now->data.data(), to_send_now->data.size());
sending = true;
Logger::log(
Logger::RemoteFs, [&](std::ostream& os) { os << "Started sending message " << to_send_now->id; },
Logger::DEBUG);
}
}
while (sending) {
size_t written_now = 0;
int ret;
if ((ret = SSL_write_ex(_ssl.get(), to_send_buf.data() + cur_sent, to_send_buf.size() - cur_sent,
&written_now)) <= 0) {
int err = SSL_get_error(_ssl.get(), ret);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
break;
}
throw OpenSSLException("write failed");
}
Logger::log(
Logger::RemoteFs,
[&](std::ostream& os) {
os << "Written " << written_now;
if (Logger::en_level(Logger::RemoteFs, Logger::TRACE)) {
os << ": ";
for (size_t i = 0; i < written_now; i++) {
os << std::setw(2) << std::setfill('0') << std::hex
<< (int) to_send_buf[i + cur_sent] << " ";
}
}
},
Logger::DEBUG);
cur_sent += written_now;
if (cur_sent == to_send_buf.size()) {
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Finished sending"; }, Logger::DEBUG);
to_send_buf.resize(0);
cur_sent = 0;
sending = false;
}
}
while (true) {
int ret;
size_t read_now = 0;
if ((ret = SSL_read_ex(_ssl.get(), read_buf.data() + cur_read, msg_len, &read_now)) <= 0) {
int err = SSL_get_error(_ssl.get(), ret);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
break;
}
throw OpenSSLException("read failed");
}
Logger::log(
Logger::RemoteFs,
[&](std::ostream& os) {
os << "Read " << read_now;
if (Logger::en_level(Logger::RemoteFs, Logger::TRACE)) {
os << ": ";
for (size_t i = 0; i < read_now; i++) {
os << std::setw(2) << std::setfill('0') << std::hex << (int) read_buf[i + cur_read]
<< " ";
}
}
},
Logger::DEBUG);
cur_read += read_now;
if (cur_read == msg_len) {
if (reading_msg) {
handle_message(std::make_shared<MsgWrapper>(msg_id, std::move(read_buf)));
reading_msg = false;
msg_len = sizeof(MsgHeader);
Logger::log(
Logger::RemoteFs,
[&](std::ostream& os) { os << "Finished receiving message " << msg_id; },
Logger::DEBUG);
} else {
MsgHeader hdr;
memcpy(&hdr, read_buf.data(), sizeof(hdr));
reading_msg = true;
msg_len = be64toh(hdr.len);
msg_id = be64toh(hdr.id);
Logger::log(
Logger::RemoteFs,
[&](std::ostream& os) { os << "Started receiving message " << msg_id; }, Logger::DEBUG);
}
read_buf.clear();
read_buf.resize(msg_len);
cur_read = 0;
}
}
pollfd fds[2];
fds[0].fd = _fd;
fds[0].events = POLLIN;
if (sending)
fds[0].events |= POLLOUT;
fds[0].revents = 0;
fds[1].fd = _to_send_notif_pipe[0];
fds[1].events = POLLIN;
fds[1].revents = 0;
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Waiting"; }, Logger::DEBUG);
if (poll(fds, 2, checked_cast<int>(Options::get<size_t>("timeout")) * 1000) < 0) {
if (errno == EINTR)
return;
throw ErrnoException("Could not poll");
}
if (!(fds[0].revents & POLLOUT) && !(fds[1].revents & POLLIN) && !(fds[0].revents & POLLIN)) {
throw ErrnoException("Could not poll (timeout?)");
}
if (fds[1].revents & POLLIN) {
char temp;
Logger::log(
Logger::RemoteFs, [&](std::ostream& os) { os << "Received message on pipe"; }, Logger::DEBUG);
read(_to_send_notif_pipe[0], &temp, 1);
}
// Logger::log(
// Logger::Server,
// [&](std::ostream& os) {
// os << "Sent " << (*next)->id << " with " << (*next)->data.size() << " bytes: ";
// if (Logger::en_level(Logger::Server, Logger::TRACE))
// for (unsigned char i: (*next)->data)
// os << (int) i << " ";
// },
// Logger::DEBUG);
// Helpers::SSL_send_msg(_ssl, _fd, **next);
}
} catch (std::exception& e) {
_failed = true;
Logger::log(Logger::RemoteFs, e.what(), Logger::ERROR);
handle_fail();
}
SSL_shutdown(_ssl.get());
OPENSSL_thread_stop();
}
// void SSLMessagePump::receiver_entry() {
// try {
// before_receiver();
//
// while (!_stopped) {
// auto msg = std::make_shared<MsgWrapper>(Helpers::SSL_read_msg(_ssl, _fd));
// Logger::log(
// Logger::Server,
// [&](std::ostream& os) {
// os << "Received " << msg->id << " with " << msg->data.size() << " bytes: ";
// if (Logger::en_level(Logger::Server, Logger::TRACE))
// for (unsigned char i: msg->data)
// os << (int) i << " ";
// },
// Logger::DEBUG);
// handle_message(msg);
// }
// } catch (std::exception& e) {
// _failed = true;
// Logger::log(Logger::Server, e.what(), Logger::ERROR);
// handle_fail();
// }
// }
void AsyncSslTransport::run() {
_thread = std::thread([&]() { thread_entry(); });
}

72
networking/src/Client.cpp Normal file
View File

@@ -0,0 +1,72 @@
//
// Created by stepus53 on 11.12.24.
//
#include "Client.hpp"
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <signal.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <openssl/err.h>
#include "Exception.h"
#include "Helpers.hpp"
#include "Logger.h"
// From https://wiki.openssl.org/index.php/Simple_TLS_Server
static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> create_context() {
const SSL_METHOD* method;
SSL_CTX* ctx;
method = TLS_client_method();
ctx = SSL_CTX_new(method);
if (!ctx) {
throw OpenSSLException("Unable to create SSL context");
}
return {ctx, &SSL_CTX_free};
}
Client::Client(uint16_t port, std::string ip, std::string cert_path, std::string key_path) :
_port(port), _ip(ip), _cert_path(cert_path), _key_path(key_path), _ssl_ctx(create_context()) {
SSL_CTX_set_verify(_ssl_ctx.get(), SSL_VERIFY_PEER, nullptr);
if (SSL_CTX_load_verify_locations(_ssl_ctx.get(), cert_path.c_str(), nullptr) <= 0) {
throw OpenSSLException("Unable to read certificate file");
}
}
void Client::run() {
protoent* proto = getprotobyname("tcp");
if (proto == NULL) {
throw ErrnoException("Could not get TCP protocol info");
}
_sock = socket(AF_INET, SOCK_STREAM, proto->p_proto);
sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(_port);
if (inet_pton(AF_INET, _ip.c_str(), &addr.sin_addr) <= 0)
throw ErrnoException("inet_pton()");
memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
#ifdef APPLE
addr.sin_len = sizeof(struct sockaddr_in),
#endif
if (connect(_sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) throw ErrnoException("connect()");
_transport.emplace(_ssl_ctx.get(), _sock);
_transport->run();
}

View File

@@ -0,0 +1,93 @@
//
// Created by Stepan Usatiuk on 09.12.2024.
//
#include "Helpers.hpp"
#include <fcntl.h>
#include <poll.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "Exception.h"
#include "Options.h"
#include "stuff.hpp"
void Helpers::poll_wait(int fd, bool write, int timeout) {
pollfd p;
p.fd = fd;
p.events = write ? POLLOUT : POLLIN;
p.revents = 0;
if (poll(&p, 1, timeout) < 0) {
if (errno == EINTR)
return;
throw ErrnoException("Could not poll");
}
if (!(p.revents & (write ? POLLOUT : POLLIN))) {
throw ErrnoException("Could not poll (timeout?)");
}
}
void Helpers::init_nonblock(int fd) {
int r = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
if (r < 0) {
throw ErrnoException("Could not set nonblocking mode");
}
}
bool Helpers::SSL_write(SSL* ctx, int fd, const std::vector<uint8_t>& buf) {
for (size_t written = 0; written < buf.size();) {
size_t written_now = 0;
int ret;
while ((ret = SSL_write_ex(ctx, buf.data() + written, buf.size() - written, &written_now)) <= 0) {
int err = SSL_get_error(ctx, ret);
if (err == SSL_ERROR_WANT_WRITE) {
poll_wait(fd, true);
continue;
}
throw OpenSSLException("write failed");
}
written += written_now;
}
return true;
}
std::vector<uint8_t> Helpers::SSL_read_n(SSL* ctx, int fd, size_t n) {
size_t nread = 0;
size_t read_now = 0;
std::vector<uint8_t> buf(n);
int ret;
while (nread < n) {
while ((ret = SSL_read_ex(ctx, buf.data() + nread, n, &read_now)) <= 0) {
int err = SSL_get_error(ctx, ret);
if (err == SSL_ERROR_WANT_READ) {
Helpers::poll_wait(fd, false);
continue;
}
throw OpenSSLException("read failed");
}
nread += read_now;
}
return buf;
}
MsgWrapper Helpers::SSL_read_msg(SSL* ctx, int fd) {
auto hdrBuf = SSL_read_n(ctx, fd, sizeof(MsgHeader));
MsgHeader hdr;
memcpy(&hdr, hdrBuf.data(), sizeof(hdr));
auto data = SSL_read_n(ctx, fd, hdr.len);
return {hdr.id, std::move(data)};
}
void Helpers::SSL_send_msg(SSL* ctx, int fd, const MsgWrapper& buf) {
MsgHeader header{};
header.id = buf.id;
header.len = checked_cast<uint32_t>(buf.data.size());
std::vector<uint8_t> hdrBuf(sizeof(MsgHeader));
memcpy(hdrBuf.data(), &header, sizeof(MsgHeader));
SSL_write(ctx, fd, hdrBuf);
SSL_write(ctx, fd, buf.data);
}

171
networking/src/Server.cpp Normal file
View File

@@ -0,0 +1,171 @@
//
// Created by Stepan Usatiuk on 09.12.2024.
//
#ifndef TCPSERVER_IPP
#define TCPSERVER_IPP
#include "Server.hpp"
#include <memory>
#include <string>
#include <thread>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <signal.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "Exception.h"
#include "Helpers.hpp"
#include "Logger.h"
// From https://wiki.openssl.org/index.php/Simple_TLS_Server
static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> create_context() {
const SSL_METHOD* method;
SSL_CTX* ctx;
method = TLS_server_method();
ctx = SSL_CTX_new(method);
if (!ctx) {
throw OpenSSLException("Unable to create SSL context");
}
return {ctx, &SSL_CTX_free};
}
static void configure_context(SSL_CTX* ctx, const std::string& cert_path, const std::string& key_path) {
/* Set the key and cert */
if (SSL_CTX_use_certificate_file(ctx, cert_path.c_str(), SSL_FILETYPE_PEM) <= 0) {
throw OpenSSLException("Unable to read certificate file");
}
if (SSL_CTX_use_PrivateKey_file(ctx, key_path.c_str(), SSL_FILETYPE_PEM) <= 0) {
throw OpenSSLException("Unable to read private key file");
}
}
Server::Server(uint16_t port, uint32_t ip, std::string cert_path, std::string key_path) :
_port(port), _ip(ip), _cert_path(std::move(cert_path)), _key_path(std::move(key_path)), _ssl_ctx(create_context()) {
configure_context(_ssl_ctx.get(), _cert_path, _key_path);
}
void Server::process_req(int conn_fd) {
_req_in_progress.fetch_add(1);
std::thread proc([=, this] {
int id = _total_req.fetch_add(1);
Logger::log(Logger::RemoteFs, "Client " + std::to_string(id) + " connecting\n", Logger::INFO);
ClientCtx context{{}, {_ssl_ctx.get(), conn_fd, id}, {}};
try {
Helpers::init_nonblock(conn_fd);
context.transport.run();
for (;;) {
auto msg = context.transport.get_msg();
if (!msg)
break;
std::thread msg_proc([&context, msg, this] {
auto ret = this->handle_message(context, std::move(msg->data));
context.transport.send_message(std::make_shared<MsgWrapper>(msg->id, std::move(ret)));
});
msg_proc.detach();
}
} catch (std::exception& e) {
Logger::log(Logger::RemoteFs, std::string("Error: ") + e.what(), Logger::ERROR);
}
close(conn_fd);
_req_in_progress.fetch_sub(1);
std::lock_guard<std::mutex> lock(_req_in_progress_mutex);
Logger::log(Logger::RemoteFs, "Client " + std::to_string(id) + " finished\n", Logger::INFO);
_req_in_progress_cond.notify_all();
});
proc.detach();
}
void Server::run() {
protoent* proto = getprotobyname("tcp");
if (proto == NULL) {
throw ErrnoException("Could not get TCP protocol info");
}
int sock = socket(AF_INET, SOCK_STREAM, proto->p_proto);
sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(_port);
addr.sin_addr = {_ip};
memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
#ifdef APPLE
addr.sin_len = sizeof(struct sockaddr_in),
#endif
Logger::log(
Logger::RemoteFs,
[&](std::ostream& os) {
os << "Listening on ";
for (int i = 0; i < 4; i++) {
os << static_cast<int>(reinterpret_cast<uint8_t*>(&_ip)[i]);
if (i != 3)
os << ".";
}
os << ":" << _port;
},
Logger::INFO);
if (bind(sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
throw ErrnoException("Could not bind");
}
if (listen(sock, 1) < 0) {
throw ErrnoException("Could not listen");
}
Helpers::init_nonblock(sock);
try {
// while (!Signals::is_stopped()) {
while (true) {
try {
Helpers::poll_wait(sock, false, -1);
int conn = accept(sock, nullptr, nullptr);
if (conn == -1) {
throw ErrnoException("accept");
continue;
}
process_req(conn);
} catch (std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
}
} catch (std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
}
Logger::log(Logger::RemoteFs, "Exiting", Logger::INFO);
{
std::unique_lock lock(_req_in_progress_mutex);
_req_in_progress_cond.wait(lock, [&] { return _req_in_progress == 0; });
}
close(sock);
}
#endif // TCPSERVER_IPP