diff --git a/dbscan.cpp b/dbscan.cpp index 4dd0da5..a24a5fa 100644 --- a/dbscan.cpp +++ b/dbscan.cpp @@ -1,36 +1,31 @@ #include "dbscan.hpp" - #include #include - #include #include +#include -// And this is the "dataset to kd-tree" adaptor class: - -inline auto get_pt(const point2& p, std::size_t dim) +// Helper function to get the coordinate of a 2D point given the dimension (0 for x, 1 for y) +inline float get_pt(const Point2& p, std::size_t dim) { - if(dim == 0) return p.x; + if (dim == 0) return p.x; return p.y; } - -inline auto get_pt(const point3& p, std::size_t dim) +// Helper function to get the coordinate of a 3D point given the dimension (0 for x, 1 for y, 2 for z) +inline float get_pt(const Point3& p, std::size_t dim) { - if(dim == 0) return p.x; - if(dim == 1) return p.y; + if (dim == 0) return p.x; + if (dim == 1) return p.y; return p.z; } - +// Adaptor class for interfacing with the KD-tree implementation template -struct adaptor +struct Adaptor { - const std::span& points; - adaptor(const std::span& points) : points(points) { } - - /// CRTP helper method - //inline const Derived& derived() const { return obj; } + const std::span& points; + Adaptor(const std::span& points) : points(points) {} // Must return the number of data points inline std::size_t kdtree_get_point_count() const { return points.size(); } @@ -49,56 +44,61 @@ struct adaptor template bool kdtree_get_bbox(BBOX& /*bb*/) const { return false; } + // Return a pointer to the x coordinate of the idx'th point auto const * elem_ptr(const std::size_t idx) const { return &points[idx].x; } }; - - -auto sort_clusters(std::vector>& clusters) +// Function to sort clusters by their point indices +void sort_clusters(std::vector>& clusters) { - for(auto& cluster: clusters) + for (auto& cluster : clusters) { std::sort(cluster.begin(), cluster.end()); } } - template -auto dbscan(const Adaptor& adapt, float eps, int min_pts) +std::vector> dbscan_impl(const Adaptor& adapt, float eps, int min_pts) { + // Squaring epsilon for distance comparison eps *= eps; + using namespace nanoflann; - using my_kd_tree_t = KDTreeSingleIndexAdaptor, decltype(adapt), n_cols>; + using my_kd_tree_t = KDTreeSingleIndexAdaptor, Adaptor, n_cols>; + // Building the KD-tree index auto index = my_kd_tree_t(n_cols, adapt, KDTreeSingleIndexAdaptorParams(10)); index.buildIndex(); const auto n_points = adapt.kdtree_get_point_count(); - auto visited = std::vector(n_points); + auto visited = std::vector(n_points); auto clusters = std::vector>(); - auto matches = std::vector>(); + auto matches = std::vector>(); auto sub_matches = std::vector>(); - for(size_t i = 0; i < n_points; i++) + for (size_t i = 0; i < n_points; i++) { if (visited[i]) continue; + // Radius search for neighbors within epsilon distance index.radiusSearch(adapt.elem_ptr(i), eps, matches, SearchParams(32, 0.f, false)); if (matches.size() < static_cast(min_pts)) continue; visited[i] = true; + // Creating a new cluster and adding the core point auto cluster = std::vector({i}); - while (matches.empty() == false) + while (!matches.empty()) { auto nb_idx = matches.back().first; matches.pop_back(); if (visited[nb_idx]) continue; visited[nb_idx] = true; + // Radius search for neighbors of the neighbor index.radiusSearch(adapt.elem_ptr(nb_idx), eps, sub_matches, SearchParams(32, 0.f, false)); if (sub_matches.size() >= static_cast(min_pts)) @@ -113,18 +113,21 @@ auto dbscan(const Adaptor& adapt, float eps, int min_pts) return clusters; } +// DBSCAN class constructor +DBSCAN::DBSCAN(float eps, int min_pts) + : eps_(eps), min_pts_(min_pts) +{} -auto dbscan(const std::span& data, float eps, int min_pts) -> std::vector> +// DBSCAN run method for 2D points +std::vector> DBSCAN::run(const std::span& data) { - const auto adapt = adaptor(data); - - return dbscan<2>(adapt, eps, min_pts); + const auto adapt = Adaptor(data); + return dbscan_impl<2>(adapt, eps_, min_pts_); } - -auto dbscan(const std::span& data, float eps, int min_pts) -> std::vector> +// DBSCAN run method for 3D points +std::vector> DBSCAN::run(const std::span& data) { - const auto adapt = adaptor(data); - - return dbscan<3>(adapt, eps, min_pts); -} \ No newline at end of file + const auto adapt = Adaptor(data); + return dbscan_impl<3>(adapt, eps_, min_pts_); +} diff --git a/dbscan.hpp b/dbscan.hpp index 29fb46f..fb6fb7a 100644 --- a/dbscan.hpp +++ b/dbscan.hpp @@ -1,34 +1,41 @@ #pragma once - #include #include #include #include #include -struct point2 +// Structure to represent a 2D point +struct Point2 { float x, y; }; -struct point3 +// Structure to represent a 3D point +struct Point3 { float x, y, z; }; -auto dbscan(const std::span& data, float eps, int min_pts) -> std::vector>; -auto dbscan(const std::span& data, float eps, int min_pts) -> std::vector>; - -// template -// auto dbscan(const std::span& data, float eps, int min_pts) -// { -// static_assert(dim == 2 or dim == 3, "This only supports either 2D or 3D points"); -// assert(data.size() % dim == 0); - -// if(dim == 2) -// { -// auto * const ptr = reinterpret_cast (data.data()); -// auto points = std::span -// } -// } \ No newline at end of file +// DBSCAN class definition +class DBSCAN +{ +public: + // Constructor to initialize the DBSCAN parameters + DBSCAN(float eps, int min_pts); + + // Method to run DBSCAN on 2D data + std::vector> run(const std::span& data); + + // Method to run DBSCAN on 3D data + std::vector> run(const std::span& data); + +private: + float eps_; // Epsilon value for neighborhood radius + int min_pts_; // Minimum number of points to form a cluster + + // Template method to run DBSCAN on generic data points (2D or 3D) + template + std::vector> run_impl(const std::span& data); +}; diff --git a/example.cpp b/example.cpp index e33a03c..ccaf232 100644 --- a/example.cpp +++ b/example.cpp @@ -10,78 +10,71 @@ #include #include - auto check_from_chars_error(std::errc err, const std::string_view& line, int line_counter) { - if(err == std::errc()) + if (err == std::errc()) return; - - if(err == std::errc::invalid_argument) + + if (err == std::errc::invalid_argument) { - std::cerr << "Error: Invalid value \"" << line - << "\" at line " << line_counter << "\n"; + std::cerr << "Error: Invalid value \"" << line << "\" at line " << line_counter << "\n"; std::exit(1); } - if(err == std::errc::result_out_of_range) + if (err == std::errc::result_out_of_range) { - std::cerr << "Error: Value \"" << line << "\"out of range at line " - << line_counter << "\n"; + std::cerr << "Error: Value \"" << line << "\" out of range at line " << line_counter << "\n"; std::exit(1); } - } - auto push_values(std::vector& store, const std::string_view& line, int line_counter) { auto ptr = line.data(); - auto ec = std::errc(); + auto ec = std::errc(); auto n_pushed = 0; do { float value; - auto [p, ec] = std::from_chars(ptr, line.data() + line.size(), value); + auto [p, ec] = std::from_chars(ptr, line.data() + line.size(), value); ptr = p + 1; check_from_chars_error(ec, line, line_counter); n_pushed++; store.push_back(value); - }while(ptr < line.data() + line.size()); + } while (ptr < line.data() + line.size()); return n_pushed; } - auto read_values(const std::string& filename) { std::ifstream file(filename); - if(not file.good()) + if (not file.good()) { std::perror(filename.c_str()); std::exit(2); } auto count = 0; - auto points = std::vector(); - auto dim = 0; + auto dim = 0; - while(not file.eof()) + while (not file.eof()) { count++; auto line = std::string(); std::getline(file, line); - if(not line.empty()) + if (not line.empty()) { auto n_pushed = push_values(points, line, count); - if(count != 1) + if (count != 1) { - if(n_pushed != dim) + if (n_pushed != dim) { std::cerr << "Inconsistent number of dimensions at line '" << count << "'\n"; std::exit(1); @@ -94,14 +87,13 @@ auto read_values(const std::string& filename) return std::tuple(points, dim); } - template auto to_num(const std::string& str) { T value = 0; auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), value); - if(ec != std::errc()) + if (ec != std::errc()) { std::cerr << "Error converting value '" << str << "'\n"; std::exit(1); @@ -109,15 +101,14 @@ auto to_num(const std::string& str) return value; } - // noise will be labelled as 0 auto label(const std::vector>& clusters, size_t n) { auto flat_clusters = std::vector(n); - for(size_t i = 0; i < clusters.size(); i++) + for (size_t i = 0; i < clusters.size(); i++) { - for(auto p: clusters[i]) + for (auto p : clusters[i]) { flat_clusters[p] = i + 1; } @@ -126,53 +117,49 @@ auto label(const std::vector>& clusters, size_t n) return flat_clusters; } - -auto dbscan2d(const std::span& data, float eps, int min_pts) +void dbscan2d(const std::span& data, float eps, int min_pts) { - auto points = std::vector(data.size() / 2); - + auto points = std::vector(data.size() / 2); std::memcpy(points.data(), data.data(), sizeof(float) * data.size()); - auto clusters = dbscan(points, eps, min_pts); - auto flat = label (clusters, points.size()); + DBSCAN dbscan(eps, min_pts); + auto clusters = dbscan.run(points); + auto flat = label(clusters, points.size()); - for(size_t i = 0; i < points.size(); i++) + for (size_t i = 0; i < points.size(); i++) { - std::cout << points[i].x << ',' << points[i].y << ',' << flat[i] << '\n'; + std::cout << points[i].x << ',' << points[i].y << ',' << flat[i] << '\n'; } } - -auto dbscan3d(const std::span& data, float eps, int min_pts) +void dbscan3d(const std::span& data, float eps, int min_pts) { - auto points = std::vector(data.size() / 3); - + auto points = std::vector(data.size() / 3); std::memcpy(points.data(), data.data(), sizeof(float) * data.size()); - auto clusters = dbscan(points, eps, min_pts); - auto flat = label (clusters, points.size()); + DBSCAN dbscan(eps, min_pts); + auto clusters = dbscan.run(points); + auto flat = label(clusters, points.size()); - for(size_t i = 0; i < points.size(); i++) + for (size_t i = 0; i < points.size(); i++) { std::cout << points[i].x << ',' << points[i].y << ',' << points[i].z << ',' << flat[i] << '\n'; } } - - int main(int argc, char** argv) { - if(argc != 4) + if (argc != 4) { std::cerr << "usage: example \n"; return 1; } - auto epsilon = to_num(argv[2]); - auto min_pts = to_num (argv[3]); + auto epsilon = to_num(argv[2]); + auto min_pts = to_num(argv[3]); auto [values, dim] = read_values(argv[1]); - if(dim == 2) + if (dim == 2) { dbscan2d(values, epsilon, min_pts); } @@ -180,4 +167,4 @@ int main(int argc, char** argv) { dbscan3d(values, epsilon, min_pts); } -} \ No newline at end of file +}