diff --git a/include/CLUEstering/utils/detail/scores.hpp b/include/CLUEstering/utils/detail/scores.hpp index 6a50842c..6a279a28 100644 --- a/include/CLUEstering/utils/detail/scores.hpp +++ b/include/CLUEstering/utils/detail/scores.hpp @@ -1,6 +1,7 @@ #pragma once +#include "CLUEstering/core/DistanceMetrics.hpp" #include "CLUEstering/data_structures/PointsHost.hpp" #include "CLUEstering/data_structures/AssociationMap.hpp" #include @@ -20,32 +21,24 @@ namespace clue { template using Point = typename clue::PointsHost::Point; - template - inline auto distance(const Point& lhs, const Point& rhs) { - auto dist = 0.f; - for (auto dim = 0u; dim < Ndim; ++dim) { - dist += (lhs[dim] - rhs[dim]) * (lhs[dim] - rhs[dim]); - } - return std::sqrt(dist); - } - - template + template DistanceMetric = clue::EuclideanMetric> inline auto silhouette(const clue::host_associator& clusters, const clue::PointsHost& points, - int point) { + int point, + const DistanceMetric& metric = clue::EuclideanMetric{}) { auto a = 0.f; std::vector b_values; b_values.reserve(clusters.size() - 1); - a += - std::accumulate(clusters.lower_bound(points[point].cluster_index()), - clusters.upper_bound(points[point].cluster_index()), - 0.f, - [&](float acc, int other_point) { - if (other_point == point) - return acc; - return acc + detail::distance(points[point], points[other_point]); - }); + a += std::accumulate(clusters.lower_bound(points[point].cluster_index()), + clusters.upper_bound(points[point].cluster_index()), + 0.f, + [&](float acc, int other_point) { + if (other_point == point) + return acc; + return acc + metric(points[point], points[other_point]); + }); a /= static_cast(clusters.count(points[point].cluster_index()) - 1); for (auto cluster_idx = 0; cluster_idx < static_cast(clusters.size()); ++cluster_idx) { @@ -56,8 +49,7 @@ namespace clue { clusters.upper_bound(cluster_idx), 0.f, [&](float acc, int other_point) { - return acc + - detail::distance(points[point], points[other_point]); + return acc + metric(points[point], points[other_point]); }); b /= static_cast(clusters.count(cluster_idx)); b_values.push_back(b); @@ -72,15 +64,17 @@ namespace clue { } // namespace detail - template - inline auto silhouette(const clue::PointsHost& points, int point) { + template DistanceMetric> + inline auto silhouette(const clue::PointsHost& points, + int point, + const DistanceMetric& metric) { const auto clusters = clue::get_clusters(points); - return detail::silhouette(clusters, points, point); + return detail::silhouette(clusters, points, point, metric); } - template - inline auto silhouette(const clue::PointsHost& points) { + template DistanceMetric> + inline auto silhouette(const clue::PointsHost& points, const DistanceMetric& metric) { const auto clusters = clue::get_clusters(points); std::vector scores; auto valid_point = [&](int point) -> bool { return points[point].cluster_index() != -1; }; @@ -88,7 +82,7 @@ namespace clue { return clusters.count(points[point].cluster_index()) >= 2; }; auto compute_silhouette = [&](std::size_t point) -> float { - return detail::silhouette(clusters, points, point); + return detail::silhouette(clusters, points, point, metric); }; std::ranges::copy(std::views::iota(0) | std::views::take(points.size()) | std::views::filter(valid_point) | std::views::filter(valid_cluster) | diff --git a/include/CLUEstering/utils/scores.hpp b/include/CLUEstering/utils/scores.hpp index 49cf2042..9828f333 100644 --- a/include/CLUEstering/utils/scores.hpp +++ b/include/CLUEstering/utils/scores.hpp @@ -4,6 +4,7 @@ #pragma once +#include "CLUEstering/core/DistanceMetrics.hpp" #include "CLUEstering/data_structures/PointsHost.hpp" namespace clue { @@ -15,8 +16,11 @@ namespace clue { /// @param point The index of the point for which to compute the silhouette score /// @return The silhouette score of the specified point /// @note This function currently only works for points with non-periodic coordinates. - template - auto silhouette(const clue::PointsHost& points, std::size_t point); + template DistanceMetric = clue::EuclideanMetric> + auto silhouette(const clue::PointsHost& points, + std::size_t point, + const DistanceMetric& metric = clue::EuclideanMetric{}); /// @brief Compute the average silhouette score for the entire dataset. /// @@ -24,8 +28,10 @@ namespace clue { /// @param points The dataset containing the points /// @return The average silhouette score of the dataset /// @note This function currently only works for points with non-periodic coordinates. - template - auto silhouette(const clue::PointsHost& points); + template DistanceMetric = clue::EuclideanMetric> + auto silhouette(const clue::PointsHost& points, + const DistanceMetric& metric = clue::EuclideanMetric{}); } // namespace clue