Skip to content

Commit

Permalink
Merge pull request #199 from tidymodels/use-philentropy
Browse files Browse the repository at this point in the history
EmilHvitfeldt authored Jan 27, 2025
2 parents c9edff9 + 8542615 commit 2f9313a
Showing 18 changed files with 158 additions and 41 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -26,8 +26,8 @@ Imports:
hardhat (>= 1.0.0),
modelenv (>= 0.2.0.9000),
parsnip (>= 1.0.2),
philentropy (>= 0.9.0),
prettyunits (>= 1.1.0),
Rfast (>= 2.0.6),
rlang (>= 1.0.6),
rsample (>= 1.0.0),
stats,
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# tidyclust (development version)

* The philentropy package is now used to calculate distances rather than Rfast. (#199)

# tidyclust 0.2.3

* Update to fix revdep issue for clustMixType. (#190)
16 changes: 14 additions & 2 deletions R/extract_fit_summary.R
Original file line number Diff line number Diff line change
@@ -167,15 +167,27 @@ extract_fit_summary.hclust <- function(object, ...) {
sse_within_total_total <- map2_dbl(
by_clust$data,
seq_len(n_clust),
~sum(Rfast::dista(centroids[.y, ], .x))
~sum(
philentropy::dist_many_many(
as.matrix(centroids[.y, ]),
as.matrix(.x),
method = "euclidean"
)
)
)

list(
cluster_names = unique(clusts),
centroids = centroids,
n_members = unname(as.integer(table(clusts))),
sse_within_total_total = sse_within_total_total,
sse_total = sum(Rfast::dista(t(overall_centroid), training_data)),
sse_total = sum(
philentropy::dist_many_many(
t(overall_centroid),
as.matrix(training_data),
method = "euclidean"
)
),
orig_labels = NULL,
cluster_assignments = clusts
)
6 changes: 4 additions & 2 deletions R/hier_clust.R
Original file line number Diff line number Diff line change
@@ -193,9 +193,11 @@ translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) {
num_clusters = NULL,
cut_height = NULL,
linkage_method = NULL,
dist_fun = Rfast::Dist
dist_fun = philentropy::distance
) {
dmat <- dist_fun(x)
suppressMessages(
dmat <- dist_fun(x)
)
res <- stats::hclust(stats::as.dist(dmat), method = linkage_method)
attr(res, "num_clusters") <- num_clusters
attr(res, "cut_height") <- cut_height
21 changes: 16 additions & 5 deletions R/metric-helpers.R
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ prep_data_dist <- function(
object,
new_data = NULL,
dists = NULL,
dist_fun = Rfast::Dist
dist_fun = philentropy::distance
) {
# Sihouettes requires a distance matrix
if (is.null(new_data) && is.null(dists)) {
@@ -46,7 +46,9 @@ prep_data_dist <- function(

# Calculate distances including optionally supplied params
if (is.null(dists)) {
dists <- dist_fun(new_data)
suppressMessages(
dists <- dist_fun(new_data)
)
}

return(
@@ -63,11 +65,20 @@ prep_data_dist <- function(
#' @param new_data A data frame
#' @param centroids A data frame where each row is a centroid.
#' @param dist_fun A function for computing matrix-to-matrix distances. Defaults
#' to `Rfast::dista()`
get_centroid_dists <- function(new_data, centroids, dist_fun = Rfast::dista) {
#' to
#' `function(x, y) philentropy::dist_many_many(x, y, method = "euclidean")`.
get_centroid_dists <- function(
new_data,
centroids,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
}
) {
if (ncol(new_data) != ncol(centroids)) {
rlang::abort("Centroids must have same columns as data.")
}

dist_fun(centroids, new_data)
suppressMessages(
dist_fun(as.matrix(centroids), as.matrix(new_data))
)
}
8 changes: 4 additions & 4 deletions R/metric-silhouette.R
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ silhouette <- function(
object,
new_data = NULL,
dists = NULL,
dist_fun = Rfast::Dist
dist_fun = philentropy::distance
) {
if (inherits(object, "cluster_spec")) {
rlang::abort(
@@ -126,7 +126,7 @@ silhouette_avg.cluster_fit <- function(
...
) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::Dist
dist_fun <- philentropy::distance
}

res <- silhouette_avg_impl(object, new_data, dists, dist_fun, ...)
@@ -148,7 +148,7 @@ silhouette_avg_vec <- function(
object,
new_data = NULL,
dists = NULL,
dist_fun = Rfast::Dist,
dist_fun = philentropy::distance,
...
) {
silhouette_avg_impl(object, new_data, dists, dist_fun, ...)
@@ -158,7 +158,7 @@ silhouette_avg_impl <- function(
object,
new_data = NULL,
dists = NULL,
dist_fun = Rfast::Dist,
dist_fun = philentropy::distance,
...
) {
mean(silhouette(object, new_data, dists, dist_fun, ...)$sil_width)
56 changes: 44 additions & 12 deletions R/metric-sse.R
Original file line number Diff line number Diff line change
@@ -19,7 +19,13 @@
#'
#' sse_within(kmeans_fit)
#' @export
sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
sse_within <- function(
object,
new_data = NULL,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
}
) {
if (inherits(object, "cluster_spec")) {
rlang::abort(
paste(
@@ -43,7 +49,12 @@ sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
n_members = summ$n_members
)
} else {
dist_to_centroids <- dist_fun(summ$centroids, new_data)
suppressMessages(
dist_to_centroids <- dist_fun(
as.matrix(summ$centroids),
as.matrix(new_data)
)
)

res <- dist_to_centroids %>%
tibble::as_tibble(.name_repair = "minimal") %>%
@@ -121,7 +132,9 @@ sse_within_total.cluster_fit <- function(
...
) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::dista
dist_fun <- function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
}
}

res <- sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -142,7 +155,9 @@ sse_within_total.workflow <- sse_within_total.cluster_fit
sse_within_total_vec <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
sse_within_total_impl(object, new_data, dist_fun, ...)
@@ -151,7 +166,9 @@ sse_within_total_vec <- function(
sse_within_total_impl <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
sum(sse_within(object, new_data, dist_fun, ...)$wss, na.rm = TRUE)
@@ -208,7 +225,9 @@ sse_total.cluster_fit <- function(
...
) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::dista
dist_fun <- function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
}
}

res <- sse_total_impl(object, new_data, dist_fun, ...)
@@ -229,7 +248,9 @@ sse_total.workflow <- sse_total.cluster_fit
sse_total_vec <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
sse_total_impl(object, new_data, dist_fun, ...)
@@ -238,7 +259,9 @@ sse_total_vec <- function(
sse_total_impl <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
# Preprocess data before computing distances if appropriate
@@ -253,7 +276,10 @@ sse_total_impl <- function(
} else {
overall_mean <- colSums(summ$centroids * summ$n_members) /
sum(summ$n_members)
tot <- dist_fun(t(as.matrix(overall_mean)), new_data)^2 %>% sum()
suppressMessages(
tot <- dist_fun(t(as.matrix(overall_mean)), as.matrix(new_data))^2 %>%
sum()
)
}

return(tot)
@@ -310,7 +336,9 @@ sse_ratio.cluster_fit <- function(
...
) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::dista
dist_fun <- function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
}
}
res <- sse_ratio_impl(object, new_data, dist_fun, ...)

@@ -330,7 +358,9 @@ sse_ratio.workflow <- sse_ratio.cluster_fit
sse_ratio_vec <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
sse_ratio_impl(object, new_data, dist_fun, ...)
@@ -339,7 +369,9 @@ sse_ratio_vec <- function(
sse_ratio_impl <- function(
object,
new_data = NULL,
dist_fun = Rfast::dista,
dist_fun = function(x, y) {
philentropy::dist_many_many(x, y, method = "euclidean")
},
...
) {
sse_within_total_vec(object, new_data, dist_fun) /
13 changes: 11 additions & 2 deletions R/predict_helpers.R
Original file line number Diff line number Diff line change
@@ -96,7 +96,11 @@ make_predictions <- function(x, prefix, n_clusters) {
)

# need this to be obs on rows, dist to new data on cols
dists_new <- Rfast::dista(xnew = training_data, x = new_data, trans = TRUE)
dists_new <- philentropy::dist_many_many(
training_data,
new_data,
method = "euclidean"
)

cluster_dists <- dplyr::bind_cols(data.frame(dists_new), clusters) %>%
dplyr::group_by(.cluster) %>%
@@ -109,7 +113,12 @@ make_predictions <- function(x, prefix, n_clusters) {
## Centroid linkage_method, dist to center

cluster_centers <- extract_centroids(object) %>% dplyr::select(-.cluster)
dists_means <- Rfast::dista(new_data, cluster_centers)

dists_means <- philentropy::dist_many_many(
new_data,
cluster_centers,
method = "euclidean"
)

pred_clusts_num <- apply(dists_means, 1, which.min)
} else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) {
2 changes: 1 addition & 1 deletion man/dot-hier_clust_fit_stats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions man/get_centroid_dists.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/prep_data_dist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/silhouette.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/silhouette_avg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/sse_ratio.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/sse_total.Rd
9 changes: 8 additions & 1 deletion man/sse_within.Rd
10 changes: 9 additions & 1 deletion man/sse_within_total.Rd
6 changes: 3 additions & 3 deletions vignettes/articles/k_means.Rmd
Original file line number Diff line number Diff line change
@@ -291,11 +291,11 @@ matrix (i.e., all pairwise distances between observations).

```{r}
my_dist_1 <- function(x) {
Rfast::Dist(x, method = "manhattan")
philentropy::distance(x, method = "manhattan")
}
my_dist_2 <- function(x, y) {
Rfast::dista(x, y, method = "manhattan")
philentropy::dist_many_many(x, y, method = "manhattan")
}
kmeans_fit %>% sse_ratio(dist_fun = my_dist_2)
@@ -404,7 +404,7 @@ pens %>%

```{r, echo = FALSE}
#| fig-alt: "scatter chart. bill_length_mm along the x-axis, bill_depth_mm along the y-axis. 3 vague cluster appears in the point cloud. Point are colored according to how close they were to the color points."
closest_center <- Rfast::dista(as.matrix(pens), as.matrix(pens[init, ])) %>%
closest_center <- philentropy::dist_many_many(as.matrix(pens), as.matrix(pens[init, ]), method = "euclidean") %>%
apply(1, which.min)
pens %>%

0 comments on commit 2f9313a

Please sign in to comment.