diff --git a/src/Data_Structures/Trie.hpp b/src/Data_Structures/Trie.hpp new file mode 100644 index 0000000..aea902d --- /dev/null +++ b/src/Data_Structures/Trie.hpp @@ -0,0 +1,115 @@ +#include +#include +#include + +namespace mystd +{ +class Trie +{ + public: + Trie() : root(std::make_shared(char())) {} + + bool find(const std::string &toFind) const + { + auto currNode = root; + for (int i = 0; i < toFind.size(); i++) + { + auto e = toFind[i]; + bool found = false; + for (const auto &child : currNode->children) + { + if (child->value == e) + { + found = true; + currNode = child; + break; + } + } + if (!found || (i == toFind.size() - 1 && found && !currNode->isLast)) + return false; + } + return true; + } + + void add_item(const std::string &toAdd) + { + auto currNode = root; + + for (int i = 0; i < toAdd.size(); i++) + { + auto e = toAdd[i]; + bool found = false; + for (const auto &child : currNode->children) + { + if (child->value == e) + { + found = true; + currNode = child; + break; + } + } + if (!found) + { + auto newNode = std::make_shared(e); + newNode->parent = currNode; + currNode->children.push_back(newNode); + currNode = newNode; + } + if (i == toAdd.size() - 1) + currNode->isLast = true; + } + } + + bool delete_item(const std::string &toDelete) + { + if (!find(toDelete) || toDelete.size() == 0) + return false; + std::stack> removes; + auto currNode = root; + for (int i = 0; i < toDelete.size(); i++) + { + for (const auto &child : currNode->children) + { + if (child->value == toDelete[i]) + { + removes.push(child); + currNode = child; + } + } + } + removes.top()->isLast = false; + while (!removes.empty()) + { + auto currNode = removes.top(); + removes.pop(); + if (currNode->children.size() > 0 || currNode->isLast == true) + { + break; + } + auto par = currNode->parent.lock(); + for (int i = 0; i < par->children.size(); i++) + { + if (par->children[i] == currNode) + { + par->children.erase(par->children.begin() + i); + } + } + currNode->parent.reset(); + } + return true; + } + + private: + struct Node + { + Node(const char &init) : value(init), isLast(false) {} + + std::weak_ptr parent; + std::vector> children; + char value; + bool isLast; + }; + + std::shared_ptr root; +}; +} // namespace mystd \ No newline at end of file diff --git a/tst/test_trie.cpp b/tst/test_trie.cpp new file mode 100644 index 0000000..8ef3e34 --- /dev/null +++ b/tst/test_trie.cpp @@ -0,0 +1,83 @@ +#include +#include + +using namespace mystd; + +TEST(TrieTest, BasicTrie) +{ + auto t = Trie(); + t.add_item("hello"); + ASSERT_EQ(true, t.find("hello")); + ASSERT_EQ(false, t.find("hell")); + t.add_item("cat"); + t.add_item("cate"); + ASSERT_EQ(false, t.find("ca")); + ASSERT_EQ(true, t.find("cat")); + ASSERT_EQ(true, t.find("cate")); + t.delete_item("cat"); + ASSERT_EQ(false, t.find("cat")); + ASSERT_EQ(true, t.find("cate")); + t.delete_item("cate"); + ASSERT_EQ(false, t.find("cat")); + ASSERT_EQ(false, t.find("cate")); +} + +TEST(TrieTest, OverlappingPrefixes) +{ + Trie t; + t.add_item("apple"); + t.add_item("app"); + t.add_item("apricot"); + + ASSERT_TRUE(t.find("apple")); + ASSERT_TRUE(t.find("app")); + ASSERT_TRUE(t.find("apricot")); + ASSERT_FALSE(t.find("appl")); +} + +TEST(TrieTest, DeletionEdgeCases) +{ + Trie t; + t.add_item("a"); + t.add_item("ab"); + t.add_item("abc"); + + t.delete_item("ab"); + ASSERT_TRUE(t.find("a")); + ASSERT_FALSE(t.find("ab")); + ASSERT_TRUE(t.find("abc")); + + t.delete_item("abc"); + ASSERT_TRUE(t.find("a")); + ASSERT_FALSE(t.find("abc")); + + t.delete_item("a"); + ASSERT_FALSE(t.find("a")); +} + +TEST(TrieTest, CaseSensitivity) +{ + Trie t; + t.add_item("Test"); + t.add_item("test"); + + ASSERT_TRUE(t.find("Test")); + ASSERT_TRUE(t.find("test")); + ASSERT_FALSE(t.find("Tes")); + ASSERT_FALSE(t.find("TEST")); + + t.delete_item("test"); + ASSERT_FALSE(t.find("test")); + ASSERT_TRUE(t.find("Test")); +} + +TEST(TrieTest, LongWordStressTest) +{ + Trie t; + std::string long_word(10000, 'a'); + t.add_item(long_word); + ASSERT_TRUE(t.find(long_word)); + + long_word[9999] = 'b'; + ASSERT_FALSE(t.find(long_word)); +}