Skip to content

Commit

Permalink
avg_silhouette -> silhouette_avg
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Nov 1, 2022
1 parent e9a88a7 commit 5845b55
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 88 deletions.
8 changes: 4 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

S3method(as_tibble,cluster_metric_set)
S3method(augment,cluster_fit)
S3method(avg_silhouette,cluster_fit)
S3method(avg_silhouette,workflow)
S3method(extract_cluster_assignment,KMeansCluster)
S3method(extract_cluster_assignment,cluster_fit)
S3method(extract_cluster_assignment,hclust)
Expand Down Expand Up @@ -32,6 +30,8 @@ S3method(print,k_means)
S3method(set_args,cluster_spec)
S3method(set_engine,cluster_spec)
S3method(set_mode,cluster_spec)
S3method(silhouette_avg,cluster_fit)
S3method(silhouette_avg,workflow)
S3method(sse_ratio,cluster_fit)
S3method(sse_ratio,workflow)
S3method(sse_total,cluster_fit)
Expand All @@ -54,8 +54,6 @@ export(.convert_x_to_form_fit)
export(.convert_x_to_form_new)
export(ClusterR_kmeans_fit)
export(augment)
export(avg_silhouette)
export(avg_silhouette_vec)
export(check_empty_ellipse_tidyclust)
export(cluster_metric_set)
export(control_cluster)
Expand Down Expand Up @@ -94,6 +92,8 @@ export(required_pkgs)
export(set_args)
export(set_engine)
export(set_mode)
export(silhouette_avg)
export(silhouette_avg_vec)
export(silhouettes)
export(sse_ratio)
export(sse_ratio_vec)
Expand Down
32 changes: 16 additions & 16 deletions R/metric-silhouette.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,48 +59,48 @@ silhouettes <- function(object, new_data = NULL, dists = NULL,
#' as.matrix() %>%
#' dist()
#'
#' avg_silhouette(kmeans_fit, dists = dists)
#' silhouette_avg(kmeans_fit, dists = dists)
#'
#' avg_silhouette_vec(kmeans_fit, dists = dists)
#' silhouette_avg_vec(kmeans_fit, dists = dists)
#' @export
avg_silhouette <- function(object, ...) {
UseMethod("avg_silhouette")
silhouette_avg <- function(object, ...) {
UseMethod("silhouette_avg")
}

avg_silhouette <- new_cluster_metric(
avg_silhouette,
silhouette_avg <- new_cluster_metric(
silhouette_avg,
direction = "zero"
)

#' @export
#' @rdname avg_silhouette
avg_silhouette.cluster_fit <- function(object, new_data = NULL, dists = NULL,
#' @rdname silhouette_avg
silhouette_avg.cluster_fit <- function(object, new_data = NULL, dists = NULL,
dist_fun = NULL, ...) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::Dist
}

res <- avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
res <- silhouette_avg_impl(object, new_data, dists, dist_fun, ...)

tibble::tibble(
.metric = "avg_silhouette",
.metric = "silhouette_avg",
.estimator = "standard",
.estimate = res
)
}

#' @export
#' @rdname avg_silhouette
avg_silhouette.workflow <- avg_silhouette.cluster_fit
#' @rdname silhouette_avg
silhouette_avg.workflow <- silhouette_avg.cluster_fit

#' @export
#' @rdname avg_silhouette
avg_silhouette_vec <- function(object, new_data = NULL, dists = NULL,
#' @rdname silhouette_avg
silhouette_avg_vec <- function(object, new_data = NULL, dists = NULL,
dist_fun = Rfast::Dist, ...) {
avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
silhouette_avg_impl(object, new_data, dists, dist_fun, ...)
}

avg_silhouette_impl <- function(object, new_data = NULL, dists = NULL,
silhouette_avg_impl <- function(object, new_data = NULL, dists = NULL,
dist_fun = Rfast::Dist, ...) {
mean(silhouettes(object, new_data, dists, dist_fun, ...)$sil_width)
}
2 changes: 1 addition & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ reference:
well the model works.
contents:
- cluster_metric_set
- avg_silhouette
- silhouette_avg
- sse_ratio
- sse_total
- sse_within
Expand Down
2 changes: 1 addition & 1 deletion dev/cross_val_kmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ for (k in 2:10) {
wss_2 <- km_fit$fit$tot.withinss

sil <- km_fit %>%
avg_silhouette(tmp_test)
silhouette_avg(tmp_test)

res <- rbind(res,
c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2))
Expand Down
2 changes: 1 addition & 1 deletion dev/test_hclust_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ my_mod <- hier_clust(k = 3, linkage_method = "ward.D") %>% fit(~., mtcars)
tidyclust:::stats_hier_clust_predict(my_mod$fit, mtcars)
predict(my_mod, mtcars)

avg_silhouette(my_mod)
silhouette_avg(my_mod)
110 changes: 55 additions & 55 deletions man/avg_silhouette.Rd → man/silhouette_avg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/cluster_metric_set.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
my_metrics(kmeans_fit)
Condition
Error in `value[[3L]]()`:
! In metric: `avg_silhouette`
! In metric: `silhouette_avg`
Must supply either a dataset or distance matrix to compute silhouettes.

# cluster_metric_set error with wrong input
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-cluster_metric_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ test_that("cluster_metric_set works", {

kmeans_fit <- fit(kmeans_spec, ~., mtcars)

my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within, avg_silhouette)
my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within, silhouette_avg)

exp_res <- tibble::tibble(
.metric = c("sse_ratio", "sse_total", "sse_within", "avg_silhouette"),
.metric = c("sse_ratio", "sse_total", "sse_within", "silhouette_avg"),
.estimator = "standard",
.estimate = vapply(
list(sse_ratio_vec, sse_total_vec, sse_within_vec, avg_silhouette_vec),
list(sse_ratio_vec, sse_total_vec, sse_within_vec, silhouette_avg_vec),
function(x) x(kmeans_fit, new_data = mtcars),
FUN.VALUE = numeric(1)
)
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test-k_means_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ test_that("kmeans sihouette metrics work", {
)

expect_equal(
avg_silhouette_vec(kmeans_fit_stats, dists = dists),
silhouette_avg_vec(kmeans_fit_stats, dists = dists),
0.4993742,
tolerance = 0.005
)
expect_equal(
avg_silhouette_vec(kmeans_fit_ClusterR, dists = dists),
silhouette_avg_vec(kmeans_fit_ClusterR, dists = dists),
0.5473414,
tolerance = 0.005
)
Expand All @@ -135,12 +135,12 @@ test_that("kmeans sihouette metrics work with new data", {
)

expect_equal(
avg_silhouette_vec(kmeans_fit_stats, new_data = new_data),
silhouette_avg_vec(kmeans_fit_stats, new_data = new_data),
0.5176315,
tolerance = 0.005
)
expect_equal(
avg_silhouette_vec(kmeans_fit_ClusterR, new_data = new_data),
silhouette_avg_vec(kmeans_fit_ClusterR, new_data = new_data),
0.5176315,
tolerance = 0.005
)
Expand Down
4 changes: 2 additions & 2 deletions vignettes/articles/k_means.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ values, a dataset must also be supplied to the function.

```{r}
kmeans_fit %>%
avg_silhouette(penguins)
silhouette_avg(penguins)
```

### Changing distance measures
Expand All @@ -297,7 +297,7 @@ my_dist_2 <- function(x, y) {
kmeans_fit %>% sse_ratio(dist_fun = my_dist_2)
kmeans_fit %>% avg_silhouette(penguins, dist_fun = my_dist_1)
kmeans_fit %>% silhouette_avg(penguins, dist_fun = my_dist_1)
```

For more on using metrics for cluster model selection, see the Tuning vignette.
Expand Down

0 comments on commit 5845b55

Please sign in to comment.