mirror of
https://github.com/usatiuk/remotefs.git
synced 2025-10-28 15:37:48 +01:00
dump
This commit is contained in:
24
networking/CMakeLists.txt
Normal file
24
networking/CMakeLists.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
|
||||
add_library(networking
|
||||
include/Server.hpp
|
||||
src/Server.cpp
|
||||
include/Client.hpp
|
||||
src/Client.cpp
|
||||
include/Helpers.hpp
|
||||
src/Helpers.cpp
|
||||
src/AsyncSslTransport.cpp
|
||||
include/AsyncSslTransport.hpp
|
||||
src/AsyncSslClientTransport.cpp
|
||||
include/AsyncSslClientTransport.hpp
|
||||
include/AsyncSslServerTransport.hpp
|
||||
src/AsyncSslServerTransport.cpp
|
||||
)
|
||||
|
||||
target_include_directories(networking PUBLIC include)
|
||||
|
||||
find_package(OpenSSL REQUIRED)
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(FUSE REQUIRED IMPORTED_TARGET fuse)
|
||||
|
||||
target_link_libraries(networking PRIVATE utils OpenSSL::SSL OpenSSL::Crypto PkgConfig::FUSE)
|
||||
34
networking/include/AsyncSslClientTransport.hpp
Normal file
34
networking/include/AsyncSslClientTransport.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 14.12.2024.
|
||||
//
|
||||
|
||||
#ifndef ASYNCMESSAGECLIENT_HPP
|
||||
#define ASYNCMESSAGECLIENT_HPP
|
||||
|
||||
#include "AsyncSslTransport.hpp"
|
||||
|
||||
class AsyncSslClientTransport : public AsyncSslTransport {
|
||||
public:
|
||||
AsyncSslClientTransport(SSL_CTX* ssl_ctx, int fd) : AsyncSslTransport(ssl_ctx, fd) {}
|
||||
|
||||
using SharedMsgPromiseT = std::shared_ptr<std::promise<std::shared_ptr<MsgWrapper>>>;
|
||||
|
||||
std::future<std::shared_ptr<MsgWrapper>> send_msg(std::vector<uint8_t> message);
|
||||
std::vector<uint8_t> send_msg_and_wait(std::vector<uint8_t> message);
|
||||
|
||||
protected:
|
||||
void handle_message(std::shared_ptr<MsgWrapper> msg) override;
|
||||
void handle_fail() override {
|
||||
stop();
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
void before_entry() override;
|
||||
|
||||
private:
|
||||
std::unordered_map<decltype(MsgWrapper::id), SharedMsgPromiseT> _promises;
|
||||
uint64_t _msg_id = 0;
|
||||
std::mutex _promises_mutex;
|
||||
};
|
||||
|
||||
#endif // ASYNCMESSAGECLIENT_HPP
|
||||
32
networking/include/AsyncSslServerTransport.hpp
Normal file
32
networking/include/AsyncSslServerTransport.hpp
Normal file
@@ -0,0 +1,32 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 15.12.2024.
|
||||
//
|
||||
|
||||
#ifndef SERVERMESSAGEPUMP_HPP
|
||||
#define SERVERMESSAGEPUMP_HPP
|
||||
|
||||
#include "AsyncSslTransport.hpp"
|
||||
|
||||
class AsyncSslServerTransport : public AsyncSslTransport {
|
||||
public:
|
||||
AsyncSslServerTransport(SSL_CTX* ssl_ctx, int fd, int client_id) : AsyncSslTransport(ssl_ctx, fd), _client_id(client_id) {}
|
||||
|
||||
// Null if finished
|
||||
std::shared_ptr<MsgWrapper> get_msg();
|
||||
|
||||
protected:
|
||||
void handle_message(std::shared_ptr<MsgWrapper> msg) override;
|
||||
void handle_fail() override;
|
||||
|
||||
void before_entry() override;
|
||||
|
||||
private:
|
||||
int _client_id;
|
||||
|
||||
std::deque<std::shared_ptr<MsgWrapper>> _msgs;
|
||||
std::mutex _msgs_mutex;
|
||||
std::condition_variable _msgs_condition;
|
||||
};
|
||||
|
||||
|
||||
#endif // SERVERMESSAGEPUMP_HPP
|
||||
62
networking/include/AsyncSslTransport.hpp
Normal file
62
networking/include/AsyncSslTransport.hpp
Normal file
@@ -0,0 +1,62 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 14.12.2024.
|
||||
//
|
||||
|
||||
#ifndef MESSAGEPUMP_HPP
|
||||
#define MESSAGEPUMP_HPP
|
||||
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <future>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "Helpers.hpp"
|
||||
|
||||
class AsyncSslTransport {
|
||||
public:
|
||||
AsyncSslTransport(SSL_CTX* ssl_ctx, int fd);
|
||||
virtual ~AsyncSslTransport() = 0;
|
||||
void run();
|
||||
|
||||
void send_message(std::shared_ptr<MsgWrapper> msg);
|
||||
|
||||
bool is_failed() const { return _failed; }
|
||||
bool is_stopped() const { return _stopped; }
|
||||
|
||||
void stop();
|
||||
|
||||
protected:
|
||||
virtual void handle_message(std::shared_ptr<MsgWrapper> msg) = 0;
|
||||
virtual void handle_fail() = 0;
|
||||
|
||||
virtual void before_entry() {}
|
||||
|
||||
std::unique_ptr<SSL, decltype(&SSL_free)> _ssl{nullptr, &SSL_free};
|
||||
int _fd;
|
||||
|
||||
private:
|
||||
void thread_entry();
|
||||
|
||||
std::atomic<bool> _stopped = 0;
|
||||
std::mutex _stopped_mutex;
|
||||
std::condition_variable _stopped_condition;
|
||||
|
||||
std::thread _thread;
|
||||
|
||||
std::mutex _to_send_mutex;
|
||||
int _to_send_notif_pipe[2];
|
||||
std::deque<std::shared_ptr<MsgWrapper>> _to_send;
|
||||
|
||||
std::atomic<bool> _failed;
|
||||
|
||||
AsyncSslTransport(const AsyncSslTransport& other) = delete;
|
||||
AsyncSslTransport(AsyncSslTransport&& other) noexcept = delete;
|
||||
AsyncSslTransport& operator=(const AsyncSslTransport& other) = delete;
|
||||
AsyncSslTransport& operator=(AsyncSslTransport&& other) noexcept = delete;
|
||||
};
|
||||
|
||||
#endif // MESSAGEPUMP_HPP
|
||||
43
networking/include/Client.hpp
Normal file
43
networking/include/Client.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
//
|
||||
// Created by stepus53 on 11.12.24.
|
||||
//
|
||||
|
||||
#ifndef TCPCLIENT_HPP
|
||||
#define TCPCLIENT_HPP
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "AsyncSslClientTransport.hpp"
|
||||
|
||||
class Client {
|
||||
public:
|
||||
Client(uint16_t port, std::string ip, std::string cert_path, std::string key_path);
|
||||
|
||||
void run();
|
||||
|
||||
AsyncSslClientTransport& transport() { return *_transport; }
|
||||
|
||||
protected:
|
||||
uint16_t _port;
|
||||
std::string _ip;
|
||||
|
||||
std::string _cert_path;
|
||||
std::string _key_path;
|
||||
|
||||
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> _ssl_ctx;
|
||||
std::unique_ptr<SSL, decltype(&SSL_free)> _ssl{nullptr, &SSL_free};
|
||||
|
||||
int _sock;
|
||||
size_t _msg_id = 0;
|
||||
|
||||
private:
|
||||
std::optional<AsyncSslClientTransport> _transport;
|
||||
};
|
||||
|
||||
#endif // TCPCLIENT_HPP
|
||||
37
networking/include/Helpers.hpp
Normal file
37
networking/include/Helpers.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
|
||||
|
||||
#ifndef NETWORKING_HELPERS_H
|
||||
#define NETWORKING_HELPERS_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "Options.h"
|
||||
#include "stuff.hpp"
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
using MsgIdType = uint64_t;
|
||||
|
||||
struct MsgWrapper {
|
||||
MsgIdType id;
|
||||
std::vector<uint8_t> data;
|
||||
};
|
||||
|
||||
struct MsgHeader {
|
||||
uint64_t id;
|
||||
uint64_t len;
|
||||
} __attribute__((packed));
|
||||
|
||||
namespace Helpers {
|
||||
void poll_wait(int fd, bool write, int timeout = checked_cast<int>(Options::get<size_t>("timeout")) * 1000);
|
||||
bool SSL_write(SSL* ctx, int fd, const std::vector<uint8_t>& buf);
|
||||
void init_nonblock(int fd);
|
||||
std::vector<uint8_t> SSL_read_n(SSL* ctx, int fd, size_t n);
|
||||
MsgWrapper SSL_read_msg(SSL* ctx, int fd);
|
||||
void SSL_send_msg(SSL* ctx, int fd, const MsgWrapper& buf);
|
||||
} // namespace Helpers
|
||||
|
||||
#endif
|
||||
53
networking/include/Server.hpp
Normal file
53
networking/include/Server.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 09.12.2024.
|
||||
//
|
||||
|
||||
#ifndef TCPSERVER_HPP
|
||||
#define TCPSERVER_HPP
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "AsyncSslServerTransport.hpp"
|
||||
#include "Helpers.hpp"
|
||||
|
||||
struct ClientCtx {
|
||||
std::optional<std::string> client_name;
|
||||
AsyncSslServerTransport transport;
|
||||
std::mutex ctx_mutex;
|
||||
};
|
||||
|
||||
class Server {
|
||||
public:
|
||||
Server(uint16_t port, uint32_t ip, std::string cert_path, std::string key_path);
|
||||
|
||||
void run();
|
||||
|
||||
protected:
|
||||
uint16_t _port;
|
||||
uint32_t _ip;
|
||||
|
||||
std::string _cert_path;
|
||||
std::string _key_path;
|
||||
|
||||
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> _ssl_ctx;
|
||||
|
||||
void process_req(int conn_fd);
|
||||
|
||||
virtual std::vector<uint8_t> handle_message(ClientCtx& client, std::vector<uint8_t> data) = 0;
|
||||
|
||||
private:
|
||||
std::atomic<int> _total_req{0};
|
||||
std::atomic<int> _req_in_progress{0};
|
||||
std::mutex _req_in_progress_mutex;
|
||||
std::condition_variable _req_in_progress_cond;
|
||||
};
|
||||
|
||||
|
||||
#endif // TCPSERVER_HPP
|
||||
60
networking/src/AsyncSslClientTransport.cpp
Normal file
60
networking/src/AsyncSslClientTransport.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 14.12.2024.
|
||||
//
|
||||
|
||||
#include "AsyncSslClientTransport.hpp"
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "Logger.h"
|
||||
|
||||
std::future<std::shared_ptr<MsgWrapper>> AsyncSslClientTransport::send_msg(std::vector<uint8_t> message) {
|
||||
auto promise = std::make_shared<std::promise<std::shared_ptr<MsgWrapper>>>();
|
||||
|
||||
decltype(_msg_id) id;
|
||||
{
|
||||
std::lock_guard lock(_promises_mutex);
|
||||
id = _msg_id++;
|
||||
_promises.emplace(id, promise);
|
||||
}
|
||||
|
||||
send_message(std::make_shared<MsgWrapper>(id, std::move(message)));
|
||||
|
||||
return promise->get_future();
|
||||
}
|
||||
|
||||
std::vector<uint8_t> AsyncSslClientTransport::send_msg_and_wait(std::vector<uint8_t> message) {
|
||||
auto future = send_msg(std::move(message));
|
||||
return future.get()->data;
|
||||
}
|
||||
|
||||
void AsyncSslClientTransport::before_entry() {
|
||||
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Connecting"; }, Logger::INFO);
|
||||
|
||||
int r;
|
||||
while ((r = SSL_connect(_ssl.get())) <= 0) {
|
||||
ERR_print_errors_fp(stderr);
|
||||
int err = SSL_get_error(_ssl.get(), r);
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||
Helpers::poll_wait(_fd, err == SSL_ERROR_WANT_WRITE);
|
||||
continue;
|
||||
}
|
||||
throw OpenSSLException("SSL_connect() failed");
|
||||
}
|
||||
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Connected"; }, Logger::INFO);
|
||||
}
|
||||
|
||||
void AsyncSslClientTransport::handle_message(std::shared_ptr<MsgWrapper> msg) {
|
||||
std::lock_guard lock(_promises_mutex);
|
||||
auto future_it = _promises.find(msg->id);
|
||||
|
||||
if (future_it == _promises.end()) {
|
||||
Logger::log(Logger::RemoteFs, "Could not find future for msg with id " + std::to_string(msg->id),
|
||||
Logger::ERROR);
|
||||
return;
|
||||
}
|
||||
|
||||
future_it->second->set_value(msg);
|
||||
_promises.erase(future_it);
|
||||
}
|
||||
51
networking/src/AsyncSslServerTransport.cpp
Normal file
51
networking/src/AsyncSslServerTransport.cpp
Normal file
@@ -0,0 +1,51 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 15.12.2024.
|
||||
//
|
||||
|
||||
#include "AsyncSslServerTransport.hpp"
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "Logger.h"
|
||||
|
||||
void AsyncSslServerTransport::handle_message(std::shared_ptr<MsgWrapper> msg) {
|
||||
std::lock_guard lock(_msgs_mutex);
|
||||
_msgs.emplace_back(msg);
|
||||
_msgs_condition.notify_all();
|
||||
}
|
||||
|
||||
|
||||
void AsyncSslServerTransport::handle_fail() {
|
||||
_msgs_condition.notify_all();
|
||||
}
|
||||
|
||||
void AsyncSslServerTransport::before_entry() {
|
||||
int r;
|
||||
while ((r = SSL_accept(_ssl.get())) <= 0) {
|
||||
ERR_print_errors_fp(stderr);
|
||||
int err = SSL_get_error(_ssl.get(), r);
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||
Helpers::poll_wait(_fd, err == SSL_ERROR_WANT_WRITE);
|
||||
continue;
|
||||
}
|
||||
throw OpenSSLException("SSL_accept() failed");
|
||||
}
|
||||
Logger::log(Logger::RemoteFs, "Client " + std::to_string(_client_id) + " connected\n", Logger::INFO);
|
||||
}
|
||||
|
||||
std::shared_ptr<MsgWrapper> AsyncSslServerTransport::get_msg() {
|
||||
std::unique_lock lock(_msgs_mutex);
|
||||
_msgs_condition.wait(lock, [&] { return !_msgs.empty() || is_failed() || is_stopped(); });
|
||||
|
||||
if (is_failed())
|
||||
return nullptr;
|
||||
|
||||
auto ret = _msgs.begin();
|
||||
|
||||
auto real_ret = *ret;
|
||||
|
||||
_msgs.erase(_msgs.begin());
|
||||
|
||||
return real_ret;
|
||||
}
|
||||
249
networking/src/AsyncSslTransport.cpp
Normal file
249
networking/src/AsyncSslTransport.cpp
Normal file
@@ -0,0 +1,249 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 14.12.2024.
|
||||
//
|
||||
|
||||
#include "AsyncSslTransport.hpp"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <poll.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "Logger.h"
|
||||
#include "Serialize.hpp"
|
||||
#include "stuff.hpp"
|
||||
|
||||
AsyncSslTransport::AsyncSslTransport(SSL_CTX* ssl_ctx, int fd) : _ssl(SSL_new(ssl_ctx), &SSL_free), _fd(fd) {
|
||||
SSL_set_fd(_ssl.get(), _fd);
|
||||
pipe(_to_send_notif_pipe);
|
||||
Helpers::init_nonblock(_fd);
|
||||
}
|
||||
|
||||
AsyncSslTransport::~AsyncSslTransport() {
|
||||
stop();
|
||||
_thread.join();
|
||||
}
|
||||
|
||||
void AsyncSslTransport::stop() {
|
||||
std::unique_lock lock(_stopped_mutex);
|
||||
_stopped_condition.notify_all();
|
||||
_stopped = true;
|
||||
write(_to_send_notif_pipe[1], "1", 1);
|
||||
}
|
||||
|
||||
void AsyncSslTransport::send_message(std::shared_ptr<MsgWrapper> msg) {
|
||||
std::unique_lock lock(_to_send_mutex);
|
||||
_to_send.push_back(std::move(msg));
|
||||
write(_to_send_notif_pipe[1], "1", 1);
|
||||
}
|
||||
|
||||
void AsyncSslTransport::thread_entry() {
|
||||
try {
|
||||
before_entry();
|
||||
|
||||
bool sending = false;
|
||||
std::vector<uint8_t> to_send_buf{};
|
||||
size_t cur_sent = 0;
|
||||
|
||||
bool reading_msg = false; // False if reading header, true if message
|
||||
std::vector<uint8_t> read_buf;
|
||||
size_t cur_read{};
|
||||
uint64_t msg_id = 0;
|
||||
size_t msg_len = sizeof(MsgHeader); // Message length to read if reading message, or header len
|
||||
read_buf.resize(msg_len);
|
||||
|
||||
while (!_stopped) {
|
||||
if (!sending) {
|
||||
std::shared_ptr<MsgWrapper> to_send_now{};
|
||||
|
||||
std::unique_lock lock(_to_send_mutex);
|
||||
auto next = _to_send.begin();
|
||||
if (next != _to_send.end()) {
|
||||
to_send_now = *next;
|
||||
_to_send.erase(next);
|
||||
}
|
||||
|
||||
if (to_send_now) {
|
||||
MsgHeader header{};
|
||||
header.id = htobe64(to_send_now->id);
|
||||
header.len = htobe64(to_send_now->data.size());
|
||||
to_send_buf.resize(sizeof(MsgHeader) + to_send_now->data.size());
|
||||
memcpy(to_send_buf.data(), &header, sizeof(MsgHeader));
|
||||
memcpy(to_send_buf.data() + sizeof(MsgHeader), to_send_now->data.data(), to_send_now->data.size());
|
||||
sending = true;
|
||||
Logger::log(
|
||||
Logger::RemoteFs, [&](std::ostream& os) { os << "Started sending message " << to_send_now->id; },
|
||||
Logger::DEBUG);
|
||||
}
|
||||
}
|
||||
|
||||
while (sending) {
|
||||
size_t written_now = 0;
|
||||
int ret;
|
||||
if ((ret = SSL_write_ex(_ssl.get(), to_send_buf.data() + cur_sent, to_send_buf.size() - cur_sent,
|
||||
&written_now)) <= 0) {
|
||||
int err = SSL_get_error(_ssl.get(), ret);
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||
break;
|
||||
}
|
||||
|
||||
throw OpenSSLException("write failed");
|
||||
}
|
||||
Logger::log(
|
||||
Logger::RemoteFs,
|
||||
[&](std::ostream& os) {
|
||||
os << "Written " << written_now;
|
||||
if (Logger::en_level(Logger::RemoteFs, Logger::TRACE)) {
|
||||
os << ": ";
|
||||
for (size_t i = 0; i < written_now; i++) {
|
||||
os << std::setw(2) << std::setfill('0') << std::hex
|
||||
<< (int) to_send_buf[i + cur_sent] << " ";
|
||||
}
|
||||
}
|
||||
},
|
||||
Logger::DEBUG);
|
||||
cur_sent += written_now;
|
||||
|
||||
if (cur_sent == to_send_buf.size()) {
|
||||
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Finished sending"; }, Logger::DEBUG);
|
||||
to_send_buf.resize(0);
|
||||
cur_sent = 0;
|
||||
sending = false;
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
int ret;
|
||||
size_t read_now = 0;
|
||||
if ((ret = SSL_read_ex(_ssl.get(), read_buf.data() + cur_read, msg_len, &read_now)) <= 0) {
|
||||
int err = SSL_get_error(_ssl.get(), ret);
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
||||
break;
|
||||
}
|
||||
|
||||
throw OpenSSLException("read failed");
|
||||
}
|
||||
Logger::log(
|
||||
Logger::RemoteFs,
|
||||
[&](std::ostream& os) {
|
||||
os << "Read " << read_now;
|
||||
if (Logger::en_level(Logger::RemoteFs, Logger::TRACE)) {
|
||||
os << ": ";
|
||||
for (size_t i = 0; i < read_now; i++) {
|
||||
os << std::setw(2) << std::setfill('0') << std::hex << (int) read_buf[i + cur_read]
|
||||
<< " ";
|
||||
}
|
||||
}
|
||||
},
|
||||
Logger::DEBUG);
|
||||
cur_read += read_now;
|
||||
|
||||
if (cur_read == msg_len) {
|
||||
if (reading_msg) {
|
||||
handle_message(std::make_shared<MsgWrapper>(msg_id, std::move(read_buf)));
|
||||
reading_msg = false;
|
||||
msg_len = sizeof(MsgHeader);
|
||||
Logger::log(
|
||||
Logger::RemoteFs,
|
||||
[&](std::ostream& os) { os << "Finished receiving message " << msg_id; },
|
||||
Logger::DEBUG);
|
||||
} else {
|
||||
MsgHeader hdr;
|
||||
memcpy(&hdr, read_buf.data(), sizeof(hdr));
|
||||
reading_msg = true;
|
||||
msg_len = be64toh(hdr.len);
|
||||
msg_id = be64toh(hdr.id);
|
||||
Logger::log(
|
||||
Logger::RemoteFs,
|
||||
[&](std::ostream& os) { os << "Started receiving message " << msg_id; }, Logger::DEBUG);
|
||||
}
|
||||
|
||||
read_buf.clear();
|
||||
read_buf.resize(msg_len);
|
||||
cur_read = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pollfd fds[2];
|
||||
|
||||
fds[0].fd = _fd;
|
||||
fds[0].events = POLLIN;
|
||||
if (sending)
|
||||
fds[0].events |= POLLOUT;
|
||||
fds[0].revents = 0;
|
||||
|
||||
fds[1].fd = _to_send_notif_pipe[0];
|
||||
fds[1].events = POLLIN;
|
||||
fds[1].revents = 0;
|
||||
|
||||
Logger::log(Logger::RemoteFs, [&](std::ostream& os) { os << "Waiting"; }, Logger::DEBUG);
|
||||
|
||||
|
||||
if (poll(fds, 2, checked_cast<int>(Options::get<size_t>("timeout")) * 1000) < 0) {
|
||||
if (errno == EINTR)
|
||||
return;
|
||||
throw ErrnoException("Could not poll");
|
||||
}
|
||||
|
||||
if (!(fds[0].revents & POLLOUT) && !(fds[1].revents & POLLIN) && !(fds[0].revents & POLLIN)) {
|
||||
throw ErrnoException("Could not poll (timeout?)");
|
||||
}
|
||||
|
||||
if (fds[1].revents & POLLIN) {
|
||||
char temp;
|
||||
Logger::log(
|
||||
Logger::RemoteFs, [&](std::ostream& os) { os << "Received message on pipe"; }, Logger::DEBUG);
|
||||
read(_to_send_notif_pipe[0], &temp, 1);
|
||||
}
|
||||
|
||||
// Logger::log(
|
||||
// Logger::Server,
|
||||
// [&](std::ostream& os) {
|
||||
// os << "Sent " << (*next)->id << " with " << (*next)->data.size() << " bytes: ";
|
||||
// if (Logger::en_level(Logger::Server, Logger::TRACE))
|
||||
// for (unsigned char i: (*next)->data)
|
||||
// os << (int) i << " ";
|
||||
// },
|
||||
// Logger::DEBUG);
|
||||
// Helpers::SSL_send_msg(_ssl, _fd, **next);
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
_failed = true;
|
||||
Logger::log(Logger::RemoteFs, e.what(), Logger::ERROR);
|
||||
handle_fail();
|
||||
}
|
||||
SSL_shutdown(_ssl.get());
|
||||
OPENSSL_thread_stop();
|
||||
}
|
||||
|
||||
// void SSLMessagePump::receiver_entry() {
|
||||
// try {
|
||||
// before_receiver();
|
||||
//
|
||||
// while (!_stopped) {
|
||||
// auto msg = std::make_shared<MsgWrapper>(Helpers::SSL_read_msg(_ssl, _fd));
|
||||
// Logger::log(
|
||||
// Logger::Server,
|
||||
// [&](std::ostream& os) {
|
||||
// os << "Received " << msg->id << " with " << msg->data.size() << " bytes: ";
|
||||
// if (Logger::en_level(Logger::Server, Logger::TRACE))
|
||||
// for (unsigned char i: msg->data)
|
||||
// os << (int) i << " ";
|
||||
// },
|
||||
// Logger::DEBUG);
|
||||
// handle_message(msg);
|
||||
// }
|
||||
// } catch (std::exception& e) {
|
||||
// _failed = true;
|
||||
// Logger::log(Logger::Server, e.what(), Logger::ERROR);
|
||||
// handle_fail();
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
void AsyncSslTransport::run() {
|
||||
_thread = std::thread([&]() { thread_entry(); });
|
||||
}
|
||||
72
networking/src/Client.cpp
Normal file
72
networking/src/Client.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
//
|
||||
// Created by stepus53 on 11.12.24.
|
||||
//
|
||||
|
||||
#include "Client.hpp"
|
||||
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <signal.h>
|
||||
#include <sys/select.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <openssl/err.h>
|
||||
|
||||
#include "Exception.h"
|
||||
#include "Helpers.hpp"
|
||||
#include "Logger.h"
|
||||
|
||||
// From https://wiki.openssl.org/index.php/Simple_TLS_Server
|
||||
static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> create_context() {
|
||||
const SSL_METHOD* method;
|
||||
SSL_CTX* ctx;
|
||||
|
||||
method = TLS_client_method();
|
||||
|
||||
ctx = SSL_CTX_new(method);
|
||||
if (!ctx) {
|
||||
throw OpenSSLException("Unable to create SSL context");
|
||||
}
|
||||
|
||||
return {ctx, &SSL_CTX_free};
|
||||
}
|
||||
|
||||
Client::Client(uint16_t port, std::string ip, std::string cert_path, std::string key_path) :
|
||||
_port(port), _ip(ip), _cert_path(cert_path), _key_path(key_path), _ssl_ctx(create_context()) {
|
||||
SSL_CTX_set_verify(_ssl_ctx.get(), SSL_VERIFY_PEER, nullptr);
|
||||
if (SSL_CTX_load_verify_locations(_ssl_ctx.get(), cert_path.c_str(), nullptr) <= 0) {
|
||||
throw OpenSSLException("Unable to read certificate file");
|
||||
}
|
||||
}
|
||||
|
||||
void Client::run() {
|
||||
protoent* proto = getprotobyname("tcp");
|
||||
if (proto == NULL) {
|
||||
throw ErrnoException("Could not get TCP protocol info");
|
||||
}
|
||||
|
||||
_sock = socket(AF_INET, SOCK_STREAM, proto->p_proto);
|
||||
sockaddr_in addr;
|
||||
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(_port);
|
||||
|
||||
if (inet_pton(AF_INET, _ip.c_str(), &addr.sin_addr) <= 0)
|
||||
throw ErrnoException("inet_pton()");
|
||||
|
||||
memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
|
||||
#ifdef APPLE
|
||||
addr.sin_len = sizeof(struct sockaddr_in),
|
||||
#endif
|
||||
|
||||
if (connect(_sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) throw ErrnoException("connect()");
|
||||
|
||||
_transport.emplace(_ssl_ctx.get(), _sock);
|
||||
_transport->run();
|
||||
}
|
||||
93
networking/src/Helpers.cpp
Normal file
93
networking/src/Helpers.cpp
Normal file
@@ -0,0 +1,93 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 09.12.2024.
|
||||
//
|
||||
|
||||
#include "Helpers.hpp"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <poll.h>
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "Exception.h"
|
||||
#include "Options.h"
|
||||
#include "stuff.hpp"
|
||||
|
||||
void Helpers::poll_wait(int fd, bool write, int timeout) {
|
||||
pollfd p;
|
||||
p.fd = fd;
|
||||
p.events = write ? POLLOUT : POLLIN;
|
||||
p.revents = 0;
|
||||
if (poll(&p, 1, timeout) < 0) {
|
||||
if (errno == EINTR)
|
||||
return;
|
||||
throw ErrnoException("Could not poll");
|
||||
}
|
||||
if (!(p.revents & (write ? POLLOUT : POLLIN))) {
|
||||
throw ErrnoException("Could not poll (timeout?)");
|
||||
}
|
||||
}
|
||||
|
||||
void Helpers::init_nonblock(int fd) {
|
||||
int r = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
|
||||
if (r < 0) {
|
||||
throw ErrnoException("Could not set nonblocking mode");
|
||||
}
|
||||
}
|
||||
|
||||
bool Helpers::SSL_write(SSL* ctx, int fd, const std::vector<uint8_t>& buf) {
|
||||
for (size_t written = 0; written < buf.size();) {
|
||||
size_t written_now = 0;
|
||||
int ret;
|
||||
while ((ret = SSL_write_ex(ctx, buf.data() + written, buf.size() - written, &written_now)) <= 0) {
|
||||
int err = SSL_get_error(ctx, ret);
|
||||
if (err == SSL_ERROR_WANT_WRITE) {
|
||||
poll_wait(fd, true);
|
||||
continue;
|
||||
}
|
||||
|
||||
throw OpenSSLException("write failed");
|
||||
}
|
||||
written += written_now;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> Helpers::SSL_read_n(SSL* ctx, int fd, size_t n) {
|
||||
size_t nread = 0;
|
||||
size_t read_now = 0;
|
||||
std::vector<uint8_t> buf(n);
|
||||
int ret;
|
||||
|
||||
while (nread < n) {
|
||||
while ((ret = SSL_read_ex(ctx, buf.data() + nread, n, &read_now)) <= 0) {
|
||||
int err = SSL_get_error(ctx, ret);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
Helpers::poll_wait(fd, false);
|
||||
continue;
|
||||
}
|
||||
throw OpenSSLException("read failed");
|
||||
}
|
||||
nread += read_now;
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
MsgWrapper Helpers::SSL_read_msg(SSL* ctx, int fd) {
|
||||
auto hdrBuf = SSL_read_n(ctx, fd, sizeof(MsgHeader));
|
||||
MsgHeader hdr;
|
||||
memcpy(&hdr, hdrBuf.data(), sizeof(hdr));
|
||||
auto data = SSL_read_n(ctx, fd, hdr.len);
|
||||
return {hdr.id, std::move(data)};
|
||||
}
|
||||
|
||||
void Helpers::SSL_send_msg(SSL* ctx, int fd, const MsgWrapper& buf) {
|
||||
MsgHeader header{};
|
||||
header.id = buf.id;
|
||||
header.len = checked_cast<uint32_t>(buf.data.size());
|
||||
std::vector<uint8_t> hdrBuf(sizeof(MsgHeader));
|
||||
memcpy(hdrBuf.data(), &header, sizeof(MsgHeader));
|
||||
SSL_write(ctx, fd, hdrBuf);
|
||||
SSL_write(ctx, fd, buf.data);
|
||||
}
|
||||
171
networking/src/Server.cpp
Normal file
171
networking/src/Server.cpp
Normal file
@@ -0,0 +1,171 @@
|
||||
//
|
||||
// Created by Stepan Usatiuk on 09.12.2024.
|
||||
//
|
||||
|
||||
#ifndef TCPSERVER_IPP
|
||||
#define TCPSERVER_IPP
|
||||
|
||||
#include "Server.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <signal.h>
|
||||
#include <sys/select.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "Exception.h"
|
||||
#include "Helpers.hpp"
|
||||
#include "Logger.h"
|
||||
|
||||
// From https://wiki.openssl.org/index.php/Simple_TLS_Server
|
||||
static std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> create_context() {
|
||||
const SSL_METHOD* method;
|
||||
SSL_CTX* ctx;
|
||||
|
||||
method = TLS_server_method();
|
||||
|
||||
ctx = SSL_CTX_new(method);
|
||||
if (!ctx) {
|
||||
throw OpenSSLException("Unable to create SSL context");
|
||||
}
|
||||
|
||||
return {ctx, &SSL_CTX_free};
|
||||
}
|
||||
|
||||
static void configure_context(SSL_CTX* ctx, const std::string& cert_path, const std::string& key_path) {
|
||||
/* Set the key and cert */
|
||||
if (SSL_CTX_use_certificate_file(ctx, cert_path.c_str(), SSL_FILETYPE_PEM) <= 0) {
|
||||
throw OpenSSLException("Unable to read certificate file");
|
||||
}
|
||||
|
||||
if (SSL_CTX_use_PrivateKey_file(ctx, key_path.c_str(), SSL_FILETYPE_PEM) <= 0) {
|
||||
throw OpenSSLException("Unable to read private key file");
|
||||
}
|
||||
}
|
||||
|
||||
Server::Server(uint16_t port, uint32_t ip, std::string cert_path, std::string key_path) :
|
||||
_port(port), _ip(ip), _cert_path(std::move(cert_path)), _key_path(std::move(key_path)), _ssl_ctx(create_context()) {
|
||||
configure_context(_ssl_ctx.get(), _cert_path, _key_path);
|
||||
}
|
||||
|
||||
void Server::process_req(int conn_fd) {
|
||||
_req_in_progress.fetch_add(1);
|
||||
std::thread proc([=, this] {
|
||||
int id = _total_req.fetch_add(1);
|
||||
|
||||
Logger::log(Logger::RemoteFs, "Client " + std::to_string(id) + " connecting\n", Logger::INFO);
|
||||
|
||||
ClientCtx context{{}, {_ssl_ctx.get(), conn_fd, id}, {}};
|
||||
|
||||
try {
|
||||
Helpers::init_nonblock(conn_fd);
|
||||
|
||||
context.transport.run();
|
||||
|
||||
for (;;) {
|
||||
auto msg = context.transport.get_msg();
|
||||
|
||||
if (!msg)
|
||||
break;
|
||||
|
||||
std::thread msg_proc([&context, msg, this] {
|
||||
auto ret = this->handle_message(context, std::move(msg->data));
|
||||
context.transport.send_message(std::make_shared<MsgWrapper>(msg->id, std::move(ret)));
|
||||
});
|
||||
msg_proc.detach();
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
Logger::log(Logger::RemoteFs, std::string("Error: ") + e.what(), Logger::ERROR);
|
||||
}
|
||||
|
||||
close(conn_fd);
|
||||
_req_in_progress.fetch_sub(1);
|
||||
std::lock_guard<std::mutex> lock(_req_in_progress_mutex);
|
||||
Logger::log(Logger::RemoteFs, "Client " + std::to_string(id) + " finished\n", Logger::INFO);
|
||||
_req_in_progress_cond.notify_all();
|
||||
});
|
||||
proc.detach();
|
||||
}
|
||||
|
||||
void Server::run() {
|
||||
protoent* proto = getprotobyname("tcp");
|
||||
if (proto == NULL) {
|
||||
throw ErrnoException("Could not get TCP protocol info");
|
||||
}
|
||||
|
||||
int sock = socket(AF_INET, SOCK_STREAM, proto->p_proto);
|
||||
sockaddr_in addr;
|
||||
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(_port);
|
||||
addr.sin_addr = {_ip};
|
||||
memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
|
||||
#ifdef APPLE
|
||||
addr.sin_len = sizeof(struct sockaddr_in),
|
||||
#endif
|
||||
|
||||
Logger::log(
|
||||
Logger::RemoteFs,
|
||||
[&](std::ostream& os) {
|
||||
os << "Listening on ";
|
||||
for (int i = 0; i < 4; i++) {
|
||||
os << static_cast<int>(reinterpret_cast<uint8_t*>(&_ip)[i]);
|
||||
if (i != 3)
|
||||
os << ".";
|
||||
}
|
||||
os << ":" << _port;
|
||||
},
|
||||
Logger::INFO);
|
||||
|
||||
if (bind(sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
|
||||
throw ErrnoException("Could not bind");
|
||||
}
|
||||
|
||||
if (listen(sock, 1) < 0) {
|
||||
throw ErrnoException("Could not listen");
|
||||
}
|
||||
|
||||
Helpers::init_nonblock(sock);
|
||||
|
||||
try {
|
||||
// while (!Signals::is_stopped()) {
|
||||
while (true) {
|
||||
try {
|
||||
Helpers::poll_wait(sock, false, -1);
|
||||
int conn = accept(sock, nullptr, nullptr);
|
||||
if (conn == -1) {
|
||||
throw ErrnoException("accept");
|
||||
continue;
|
||||
}
|
||||
process_req(conn);
|
||||
} catch (std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
}
|
||||
|
||||
Logger::log(Logger::RemoteFs, "Exiting", Logger::INFO);
|
||||
{
|
||||
std::unique_lock lock(_req_in_progress_mutex);
|
||||
_req_in_progress_cond.wait(lock, [&] { return _req_in_progress == 0; });
|
||||
}
|
||||
|
||||
close(sock);
|
||||
}
|
||||
|
||||
|
||||
#endif // TCPSERVER_IPP
|
||||
Reference in New Issue
Block a user