Skiplist fix destructors not being called

This commit is contained in:
2024-04-12 23:31:49 +02:00
parent 570826b1f8
commit 073d6c3af6
9 changed files with 170 additions and 96 deletions

View File

@@ -41,6 +41,9 @@ if (TEST_MODE)
add_compile_options(-fsanitize=address -fsanitize=undefined -fno-sanitize-recover)
add_link_options(-fsanitize=address -fsanitize=undefined -fno-sanitize-recover)
add_compile_options(-rdynamic)
add_link_options(-rdynamic)
set_target_properties(kernel.elf PROPERTIES EXCLUDE_FROM_ALL 1 EXCLUDE_FROM_DEFAULT_BUILD 1)
set_target_properties(syscalls_interface PROPERTIES EXCLUDE_FROM_ALL 1 EXCLUDE_FROM_DEFAULT_BUILD 1)
add_subdirectory(./unit-tests/)

View File

@@ -307,6 +307,8 @@ void kfree(void *addr) {
void *krealloc(void *addr, size_t newsize) {
assert(initialized);
if (addr == nullptr) return kmalloc(newsize);
struct HeapEntry *info = (struct HeapEntry *) (addr - (sizeof(struct HeapEntry)));
assert2(info->magic == KERN_HeapMagicTaken, "Bad realloc!");

View File

@@ -11,6 +11,7 @@
#include <utility>
#include "assert.h"
#include "kmem.hpp"
extern "C" int rand(void);
@@ -19,22 +20,27 @@ class SkipListBase {
protected:
static constexpr size_t maxL{31};
class NodeAllocator;
struct Node {
friend NodeAllocator;
Node *next[maxL + 1] = {nullptr};
Node *before = nullptr;
bool end = false;
Data &get() {
assert(!end);
assert(!end());
return *std::launder(reinterpret_cast<Data *>(&_data[0]));
}
const Data &get() const {
assert(!end);
assert(!end());
return *std::launder(reinterpret_cast<const Data *>(&_data[0]));
}
bool end() const { return _end; }
private:
alignas(Data) std::array<unsigned char, sizeof(Data)> _data;
bool _end;
};
class NodeAllocator {
@@ -42,12 +48,26 @@ protected:
Node *nodes[size];
int top = -1;
Node *get() {
Node *node;
if (top == -1)
node = static_cast<Node *>(kmalloc(sizeof(Node)));
else
node = nodes[top--];
node->_end = false;
node->before = nullptr;
node->next[0] = nullptr;
return node;
}
public:
NodeAllocator() noexcept = default;
~NodeAllocator() noexcept {
for (int i = top; i >= 0; i--) {
delete nodes[i];
kfree(nodes[i]);
}
}
@@ -58,26 +78,25 @@ protected:
//
void push(Node *&e) {
if (!e->end()) std::destroy_at(&e->get());
if (top >= size - 1) {
delete e;
kfree(e);
return;
}
nodes[++top] = e;
}
template<class... Args>
Node *get(Args &&...args) {
Node *ret = get();
new (&ret->get()) Data(std::forward<Args>(args)...);
return ret;
}
Node *get() {
if (top == -1) {
return new Node;
}
Node *node = nodes[top--];
node->end = false;
node->before = nullptr;
node->next[0] = nullptr;
return node;
Node *get_end() {
Node *ret = get();
ret->_end = true;
return ret;
}
};
@@ -96,10 +115,8 @@ protected:
size_t curL = 0;
SkipListBase() noexcept {
root = (Node *) nodeAllocator.get();
root->end = true;
endnode = (Node *) nodeAllocator.get();
endnode->end = true;
root = (Node *) nodeAllocator.get_end();
endnode = (Node *) nodeAllocator.get_end();
endnode->before = root;
for (size_t i = 0; i <= maxL; i++) {
@@ -112,9 +129,6 @@ public:
~SkipListBase() noexcept {
auto cur = root;
while (cur != nullptr) {
if (!cur->end)
std::destroy_at(&cur->get());
auto prev = cur;
cur = cur->next[0];
nodeAllocator.push(prev);
@@ -124,7 +138,7 @@ public:
SkipListBase(SkipListBase const &l) noexcept : SkipListBase() {
toUpdate[0] = root;
for (auto n = l.root->next[0]; n != nullptr && !n->end; n = n->next[0]) {
for (auto n = l.root->next[0]; n != nullptr && !n->end(); n = n->next[0]) {
size_t newLevel = randomL();
if (newLevel > curL) {
@@ -133,8 +147,7 @@ public:
curL = newLevel;
}
auto newNode = (Node *) nodeAllocator.get();
new (&newNode->get()) Data(n->get());
auto newNode = (Node *) nodeAllocator.get(std::move(n->get()));
newNode->before = toUpdate[0];
if (toUpdate[0]->next[0] != nullptr) toUpdate[0]->next[0]->before = newNode;
@@ -246,13 +259,13 @@ protected:
public:
template<typename L, typename R>
friend bool operator==(const SkipListBaseIteratorBase<L> &a, const SkipListBaseIteratorBase<R> &b) {
if (a.n->end && b.n->end) return true;
if (a.n->end() && b.n->end()) return true;
return a.n == b.n;
};
template<typename L, typename R>
friend bool operator!=(const SkipListBaseIteratorBase<L> &a, const SkipListBaseIteratorBase<R> &b) {
if (a.n->end && (a.n->end == b.n->end)) return false;
if (a.n->end() && (a.n->end() == b.n->end())) return false;
return a.n != b.n;
};
@@ -271,7 +284,7 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--) {
while (!cur->next[i]->end && Comparator()(cur->next[i]->get(), *begin))
while (!cur->next[i]->end() && Comparator()(cur->next[i]->get(), *begin))
cur = cur->next[i];
toUpdate[i] = cur;
}
@@ -307,13 +320,13 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--) {
while (!cur->next[i]->end && cur->next[i] != k.n)
while (!cur->next[i]->end() && cur->next[i] != k.n)
cur = cur->next[i];
toUpdate[i] = cur;
}
cur = cur->next[0];
if (cur->end || cur != k.n) return {cur, false};
if (cur->end() || cur != k.n) return {cur, false};
cur->next[0]->before = toUpdate[0];
@@ -328,7 +341,6 @@ protected:
root->next[curL] == nullptr)
curL--;
std::destroy_at(&cur->get());
auto ret = std::make_pair(cur->next[0], true);
nodeAllocator.push(cur);
return ret;
@@ -341,13 +353,13 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--) {
while (!cur->next[i]->end && Comparator()(cur->next[i]->get(), newNode->get()))
while (!cur->next[i]->end() && Comparator()(cur->next[i]->get(), newNode->get()))
cur = cur->next[i];
toUpdate[i] = cur;
}
cur = cur->next[0];
if constexpr (!Duplicate)
if (!cur->end && Comparator().eq(cur->get(), newNode->get())) return {cur, false};
if (!cur->end() && Comparator().eq(cur->get(), newNode->get())) return {cur, false};
size_t newLevel = randomL();
@@ -371,16 +383,18 @@ protected:
template<class Comparator, bool Duplicate>
std::pair<Node *, bool> insert(Data d) {
Node *n = nodeAllocator.get();
new (n->data) Data(std::move(d));
return insert<Comparator, Duplicate>(n);
Node *n = nodeAllocator.get(std::move(d));
auto ret = insert<Comparator, Duplicate>(n);
if (!ret.second) nodeAllocator.push(n);
return ret;
}
template<class Comparator, bool Duplicate, class... Args>
std::pair<Node *, bool> emplace(Args &&...args) {
Node *n = nodeAllocator.get();
new (&n->get()) Data(std::forward<Args>(args)...);
return insert<Comparator, Duplicate>(n);
Node *n = nodeAllocator.get(std::forward<Args>(args)...);
auto ret = insert<Comparator, Duplicate>(n);
if (!ret.second) nodeAllocator.push(n);
return ret;
}
// Comparator true if less than, 0 if equal
@@ -389,13 +403,13 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--) {
while (!cur->next[i]->end && Comparator()(cur->next[i]->get(), k))
while (!cur->next[i]->end() && Comparator()(cur->next[i]->get(), k))
cur = cur->next[i];
toUpdate[i] = cur;
}
cur = cur->next[0];
if (cur->end || !Comparator().eq(cur->get(), k)) return {cur, false};
if (cur->end() || !Comparator().eq(cur->get(), k)) return {cur, false};
cur->next[0]->before = toUpdate[0];
@@ -410,7 +424,6 @@ protected:
root->next[curL] == nullptr)
curL--;
std::destroy_at(&cur->get());
auto ret = std::make_pair(cur->next[0], true);
nodeAllocator.push(cur);
return ret;
@@ -421,9 +434,9 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--)
while (!cur->next[i]->end && (Comparator()(cur->next[i]->get(), k) || Comparator().eq(cur->next[i]->get(), k)))
while (!cur->next[i]->end() && (Comparator()(cur->next[i]->get(), k) || Comparator().eq(cur->next[i]->get(), k)))
cur = cur->next[i];
if (!cur->end && (Comparator()(cur->get(), k) || Comparator().eq(cur->get(), k)))
if (!cur->end() && (Comparator()(cur->get(), k) || Comparator().eq(cur->get(), k)))
cur = cur->next[0];
return cur;
}
@@ -433,7 +446,7 @@ protected:
Node *cur = root;
for (int i = curL; i >= 0; i--)
while (!cur->next[i]->end && Comparator()(cur->next[i]->get(), k))
while (!cur->next[i]->end() && Comparator()(cur->next[i]->get(), k))
cur = cur->next[i];
return cur->next[0];
@@ -444,13 +457,13 @@ public:
auto n = root->next[0];
auto n2 = r.root->next[0];
while (!n->end && !n2->end) {
while (!n->end() && !n2->end()) {
if (!(n->get() == n2->get())) return false;
n = n->next[0];
n2 = n2->next[0];
}
if ((n->end || n2->end) && n->end != n2->end) return false;
if ((n->end() || n2->end()) && n->end() != n2->end()) return false;
return true;
}
@@ -529,43 +542,43 @@ public:
const_iterator find(const K &k) const {
typename BaseT::Node *n = BaseT::template lower_bound<Cmp>(k);
if (n->end || n->get().first != k) return end();
if (n->end() || n->get().first != k) return end();
return const_iterator(n);
}
iterator find(const K &k) {
typename BaseT::Node *n = BaseT::template lower_bound<Cmp>(k);
if (n->end || n->get().first != k) return end();
if (n->end() || n->get().first != k) return end();
return iterator(n);
}
const_iterator upper_bound(const K &k) const {
typename BaseT::Node *n = BaseT::template upper_bound<Cmp>(k);
if (n->end) return end();
if (n->end()) return end();
return const_iterator(n);
}
iterator upper_bound(const K &k) {
typename BaseT::Node *n = BaseT::template upper_bound<Cmp>(k);
if (n->end) return end();
if (n->end()) return end();
return iterator(n);
}
iterator erase(const K &k) {
std::pair<typename BaseT::Node *, bool> n = BaseT::template erase<Cmp>(k);
if (n.first->end) return end();
if (n.first->end()) return end();
return iterator(n.first);
}
iterator erase(const_iterator first, const_iterator last) {
std::pair<typename BaseT::Node *, bool> n = BaseT::template erase<Cmp>(first, last);
if (n.first->end) return end();
if (n.first->end()) return end();
return iterator(n.first);
}
iterator erase(const_iterator el) {
std::pair<typename BaseT::Node *, bool> n = BaseT::erase(el);
if (n.first->end) return end();
if (n.first->end()) return end();
return iterator(n.first);
}
};

View File

@@ -8,66 +8,57 @@
#include "kmem.hpp"
#include "string.h"
// Null terminated string
class String {
public:
String() noexcept {
_data = static_cast<char *>(kmalloc(1 * sizeof(char)));
curLen = 0;
_data[0] = '\0';
String() {
//FIXME:
_data = static_cast<char *>(kmalloc(_cur_len + 1));
}
String(const char *in) noexcept {
curLen = strlen(in);
_data = static_cast<char *>(kmalloc((curLen + 1) * sizeof(char)));
_data[0] = '\0';
strcat(_data, in);
_cur_len = strlen(in);
_data = static_cast<char *>(kmalloc(_cur_len + 1));
memcpy(_data, in, _cur_len + 1);
}
String(String const &str) noexcept {
curLen = str.curLen;
_data = static_cast<char *>(kmalloc((curLen + 1) * sizeof(char)));
_data[0] = '\0';
strcat(_data, str._data);
_cur_len = str._cur_len;
_data = static_cast<char *>(kmalloc(_cur_len + 1));
memcpy(_data, str._data, _cur_len + 1);
}
String(String &&str) noexcept {
_data = str._data;
curLen = str.curLen;
_cur_len = str._cur_len;
str._data = static_cast<char *>(kmalloc(1 * sizeof(char)));
str.curLen = 0;
str._data[0] = '\0';
str._cur_len = 0;
str._data = nullptr;
}
String &operator=(String str) noexcept {
std::swap(_data, str._data);
std::swap(curLen, str.curLen);
std::swap(_cur_len, str._cur_len);
return *this;
}
~String() noexcept {
if (_data == nullptr) return;
kfree(_data);
_data = nullptr;
curLen = 0;
}
String &operator+=(String const &rhs) {
_data = static_cast<char *>(krealloc(_data, sizeof(char) * (curLen + rhs.curLen + 1)));
_data = static_cast<char *>(krealloc(_data, _cur_len + rhs._cur_len + 1));
assert(_data != nullptr);
strcat(_data, rhs._data);
curLen += rhs.curLen;
_cur_len += rhs._cur_len;
return *this;
}
String &operator+=(unsigned long value) {
char buf[20];
char buf[32];
itoa(value, buf, 10);
*this += buf;
@@ -75,7 +66,7 @@ public:
return *this;
}
String &operator+=(unsigned long long value) {
char buf[32];
char buf[64];
itoa(value, buf, 10);
*this += buf;
@@ -84,12 +75,12 @@ public:
}
String &operator+=(char c) {
_data = static_cast<char *>(krealloc(_data, sizeof(char) * (curLen + 2)));
_data = static_cast<char *>(krealloc(_data, _cur_len + 2));
assert(_data != nullptr);
_data[curLen] = c;
_data[curLen + 1] = '\0';
curLen++;
_data[_cur_len] = c;
_data[_cur_len + 1] = '\0';
_cur_len++;
return *this;
}
@@ -102,11 +93,11 @@ public:
}
size_t length() const {
return curLen;
return _cur_len;
}
bool empty() const {
return curLen == 0;
return _cur_len == 0;
}
bool operator==(String const &rhs) const {
@@ -130,8 +121,8 @@ public:
}
private:
size_t curLen = 0;
char *_data;
size_t _cur_len = 0;
char *_data = nullptr;
};
#endif

View File

@@ -1,5 +1,5 @@
add_library(KApi
INTERFACE
kmem.cpp
)
target_include_directories(KApi INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})

View File

@@ -0,0 +1,63 @@
//
// Created by Stepan Usatiuk on 12.04.2024.
//
#include "kmem.hpp"
#include <cassert>
#include <iostream>
#include <map>
#include <sstream>
#include <execinfo.h>
std::map<uintptr_t, std::pair<size_t, std::string>> track;
// Based on: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html
static std::string get_stacktrace() {
std::vector<void *> functions(50);
char **strings;
int n;
n = backtrace(functions.data(), 50);
strings = backtrace_symbols(functions.data(), n);
std::stringstream out;
if (strings != nullptr) {
out << "Stacktrace:" << std::endl;
for (int i = 0; i < n; i++) out << strings[i] << std::endl;
}
free(strings);
return out.str();
}
void *kmalloc(size_t n) {
void *ret = malloc(n);
track.emplace((uintptr_t) ret, std::make_pair(n, get_stacktrace()));
return ret;
}
void kfree(void *addr) {
if (addr == nullptr) return;
assert(track.contains((uintptr_t) addr));
track.erase((uintptr_t) addr);
free(addr);
}
void *krealloc(void *addr, size_t newsize) {
if (addr != nullptr) {
assert(track.contains((uintptr_t) addr));
track.erase((uintptr_t) addr);
}
void *ret = realloc(addr, newsize);
track.emplace((uintptr_t) ret, std::make_pair(newsize, get_stacktrace()));
return ret;
}
__attribute__((destructor)) void check() {
for (const auto &t: track) {
std::cerr << "Leaked " << t.second.first << " bytes" << '\n';
std::cerr << t.second.second << std::endl;
}
assert(track.empty());
}

View File

@@ -9,9 +9,9 @@
#include <cstdint>
#include <cstdlib>
void *kmalloc(size_t n) { return malloc(n); }
void kfree(void *addr) { return free(addr); }
void *krealloc(void *addr, size_t newsize) { return realloc(addr, newsize); }
void *kmalloc(size_t n);
void kfree(void *addr);
void *krealloc(void *addr, size_t newsize);
#endif //KMEM_HPP

View File

@@ -8,6 +8,7 @@ add_executable(
target_link_libraries(
SkipListTest
templates
KApi
GTest::gtest_main
)
add_executable(
@@ -18,6 +19,7 @@ add_executable(
target_link_libraries(
SkipListDetailedTest
templates
KApi
GTest::gtest_main
)

View File

@@ -517,7 +517,7 @@ TYPED_TEST(SkipListDetailedTestFixture, ItWorks) {
assert(oss.str() == "-100\n<-9..20>\n<48..93>\n<150..200>\n400=======");
// FIXME: I'll leave it a "stress test" for now
for (long long x = 25; x <= 250; x *= 10) {
for (long long x = 10; x <= 10; x *= 10) {
CRangeListTested t;
std::cout << x << "====" << std::endl;
auto start = std::chrono::high_resolution_clock::now();