diff --git a/NEWS.md b/NEWS.md
index fcd1a6d4..33d98428 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
 # tidyclust (development version)
 
+* `silhouette()` and `silhouette_avg()` now return NAs instead of erroring when applied to a clustering object with 1 cluster. (#104)
+
 * Fixed bug where `extract_cluster_assignment()` doesn't work for `hier_clust()` models in workflows where `num_clusters` is specified in `extract_cluster_assignment()`.
 
 # tidyclust 0.1.0
diff --git a/R/metric-silhouette.R b/R/metric-silhouette.R
index 9015bdc2..b885e39b 100644
--- a/R/metric-silhouette.R
+++ b/R/metric-silhouette.R
@@ -31,6 +31,17 @@ silhouette <- function(object, new_data = NULL, dists = NULL,
 
   sil <- cluster::silhouette(clust_int, preproc$dists)
 
+  # No silhouette object came back (e.g. only 1 cluster): return an all-NA tibble.
+  if (!inherits(sil, "silhouette")) {
+    res <- tibble::tibble(
+      cluster = preproc$clusters,
+      neighbor = factor(rep(NA_character_, length(preproc$clusters)),
+                        levels = levels(preproc$clusters)),
+      sil_width = NA_real_
+    )
+    return(res)
+  }
+
   sil %>%
     unclass() %>%
     tibble::as_tibble() %>%
diff --git a/tests/testthat/test-metric-silhouette.R b/tests/testthat/test-metric-silhouette.R
new file mode 100644
index 00000000..04049cf4
--- /dev/null
+++ b/tests/testthat/test-metric-silhouette.R
@@ -0,0 +1,18 @@
+test_that("silhouette() returns NAs for a clustering with 1 cluster", {
+  kmeans_spec <- k_means(num_clusters = 1) %>%
+    set_engine("stats")
+
+  kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+  dists <- mtcars %>%
+    as.matrix() %>%
+    dist()
+
+  res <- silhouette(kmeans_fit, dists = dists)
+  exp_res <- tibble::tibble(
+    cluster = rep(factor("Cluster_1"), 32),
+    neighbor = rep(factor(NA, levels = "Cluster_1"), 32),
+    sil_width = rep(NA_real_, 32)
+  )
+  expect_identical(res, exp_res)
+})