Skip to content

Commit

Permalink
sse_within -> sse_within_total
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Nov 1, 2022
1 parent 5845b55 commit 7188282
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 74 deletions.
8 changes: 4 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ S3method(sse_ratio,cluster_fit)
S3method(sse_ratio,workflow)
S3method(sse_total,cluster_fit)
S3method(sse_total,workflow)
S3method(sse_within,cluster_fit)
S3method(sse_within,workflow)
S3method(sse_within_total,cluster_fit)
S3method(sse_within_total,workflow)
S3method(tidy,cluster_fit)
S3method(translate_tidyclust,default)
S3method(translate_tidyclust,hier_clust)
Expand Down Expand Up @@ -99,8 +99,8 @@ export(sse_ratio)
export(sse_ratio_vec)
export(sse_total)
export(sse_total_vec)
export(sse_within)
export(sse_within_vec)
export(sse_within_total)
export(sse_within_total_vec)
export(tidy)
export(translate_tidyclust)
export(translate_tidyclust.default)
Expand Down
8 changes: 4 additions & 4 deletions R/extract_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extract_fit_summary.kmeans <- function(object, ...) {
cluster_names = names,
centroids = centroids,
n_members = object$size[reorder_clusts],
sse_within_total = object$withinss[reorder_clusts],
sse_within_total_total = object$withinss[reorder_clusts],
sse_total = object$totss,
orig_labels = unname(object$cluster),
cluster_assignments = cluster_asignments
Expand All @@ -75,7 +75,7 @@ extract_fit_summary.KMeansCluster <- function(object, ...) {
cluster_names = names,
centroids = centroids,
n_members = object$obs_per_cluster[reorder_clusts],
sse_within_total = object$WCSS_per_cluster[reorder_clusts],
sse_within_total_total = object$WCSS_per_cluster[reorder_clusts],
sse_total = object$total_SSE,
orig_labels = object$clusters,
cluster_assignments = cluster_asignments
Expand Down Expand Up @@ -103,7 +103,7 @@ extract_fit_summary.hclust <- function(object, ...) {
map(dplyr::summarize_all, mean) %>%
dplyr::bind_rows()

sse_within_total <- map2_dbl(
sse_within_total_total <- map2_dbl(
by_clust$data,
seq_len(n_clust),
~ sum(Rfast::dista(centroids[.y, ], .x))
Expand All @@ -113,7 +113,7 @@ extract_fit_summary.hclust <- function(object, ...) {
cluster_names = unique(clusts),
centroids = centroids,
n_members = unname(table(clusts)),
sse_within_total = sse_within_total,
sse_within_total_total = sse_within_total_total,
sse_total = sum(Rfast::dista(t(overall_centroid), training_data)),
orig_labels = NULL,
cluster_assignments = clusts
Expand Down
36 changes: 18 additions & 18 deletions R/metric-sse.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ within_cluster_sse <- function(object, new_data = NULL,
if (is.null(new_data)) {
res <- tibble::tibble(
.cluster = factor(summ$cluster_names),
wss = summ$sse_within_total,
wss = summ$sse_within_total_total,
n_members = summ$n_members
)
} else {
Expand Down Expand Up @@ -69,48 +69,48 @@ within_cluster_sse <- function(object, new_data = NULL,
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' sse_within(kmeans_fit)
#' sse_within_total(kmeans_fit)
#'
#' sse_within_vec(kmeans_fit)
#' sse_within_total_vec(kmeans_fit)
#' @export
sse_within <- function(object, ...) {
UseMethod("sse_within")
sse_within_total <- function(object, ...) {
UseMethod("sse_within_total")
}

sse_within <- new_cluster_metric(
sse_within,
sse_within_total <- new_cluster_metric(
sse_within_total,
direction = "zero"
)

#' @export
#' @rdname sse_within
sse_within.cluster_fit <- function(object, new_data = NULL,
#' @rdname sse_within_total
sse_within_total.cluster_fit <- function(object, new_data = NULL,
dist_fun = NULL, ...) {
if (is.null(dist_fun)) {
dist_fun <- Rfast::dista
}

res <- sse_within_impl(object, new_data, dist_fun, ...)
res <- sse_within_total_impl(object, new_data, dist_fun, ...)

tibble::tibble(
.metric = "sse_within",
.metric = "sse_within_total",
.estimator = "standard",
.estimate = res
)
}

#' @export
#' @rdname sse_within
sse_within.workflow <- sse_within.cluster_fit
#' @rdname sse_within_total
sse_within_total.workflow <- sse_within_total.cluster_fit

#' @export
#' @rdname sse_within
sse_within_vec <- function(object, new_data = NULL,
#' @rdname sse_within_total
sse_within_total_vec <- function(object, new_data = NULL,
dist_fun = Rfast::dista, ...) {
sse_within_impl(object, new_data, dist_fun, ...)
sse_within_total_impl(object, new_data, dist_fun, ...)
}

sse_within_impl <- function(object, new_data = NULL,
sse_within_total_impl <- function(object, new_data = NULL,
dist_fun = Rfast::dista, ...) {
sum(within_cluster_sse(object, new_data, dist_fun, ...)$wss, na.rm = TRUE)
}
Expand Down Expand Up @@ -253,6 +253,6 @@ sse_ratio_impl <- function(object,
new_data = NULL,
dist_fun = Rfast::dista,
...) {
sse_within_vec(object, new_data, dist_fun) /
sse_within_total_vec(object, new_data, dist_fun) /
sse_total_vec(object, new_data, dist_fun)
}
2 changes: 1 addition & 1 deletion R/tune_cluster.R
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ check_metrics <- function(x, object) {
if (is.null(x)) {
switch(mode,
partition = {
x <- cluster_metric_set(sse_within, sse_total)
x <- cluster_metric_set(sse_within_total, sse_total)
},
unknown = {
rlang::abort(
Expand Down
2 changes: 1 addition & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ reference:
- silhouette_avg
- sse_ratio
- sse_total
- sse_within
- sse_within_total
- title: Tuning
desc: >
Functions to allow multiple cluster specifications to be fit at once.
Expand Down
2 changes: 1 addition & 1 deletion dev/cross_val_kmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ for (k in 2:10) {
km_fit <- km %>% fit(~., data = tmp_train)

wss <- km_fit %>%
sse_within_total(tmp_test)
sse_within_total_total(tmp_test)

wss_2 <- km_fit$fit$tot.withinss

Expand Down
22 changes: 11 additions & 11 deletions man/sse_within.Rd → man/sse_within_total.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/tune_cluster.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@
tmp <- tune::show_best(res)
Condition
Warning:
No value of `metric` was given; metric 'sse_within' will be used.
No value of `metric` was given; metric 'sse_within_total' will be used.

---

Code
tmp <- tune::select_best(res)
Condition
Warning:
No value of `metric` was given; metric 'sse_within' will be used.
No value of `metric` was given; metric 'sse_within_total' will be used.

8 changes: 4 additions & 4 deletions tests/testthat/test-cluster_metric_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ test_that("cluster_metric_set works", {

kmeans_fit <- fit(kmeans_spec, ~., mtcars)

my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within, silhouette_avg)
my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within_total, silhouette_avg)

exp_res <- tibble::tibble(
.metric = c("sse_ratio", "sse_total", "sse_within", "silhouette_avg"),
.metric = c("sse_ratio", "sse_total", "sse_within_total", "silhouette_avg"),
.estimator = "standard",
.estimate = vapply(
list(sse_ratio_vec, sse_total_vec, sse_within_vec, silhouette_avg_vec),
list(sse_ratio_vec, sse_total_vec, sse_within_total_vec, silhouette_avg_vec),
function(x) x(kmeans_fit, new_data = mtcars),
FUN.VALUE = numeric(1)
)
Expand All @@ -23,7 +23,7 @@ test_that("cluster_metric_set works", {

expect_snapshot(error = TRUE, my_metrics(kmeans_fit))

my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within)
my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within_total)

expect_equal(
my_metrics(kmeans_fit),
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-k_means_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ test_that("kmeans sse metrics work", {
)

expect_equal(
sse_within_vec(kmeans_fit_stats),
sse_within_total_vec(kmeans_fit_stats),
km_orig$tot.withinss,
tolerance = 0.005
)
Expand All @@ -41,7 +41,7 @@ test_that("kmeans sse metrics work", {
)

expect_equal(
sse_within_vec(kmeans_fit_ClusterR),
sse_within_total_vec(kmeans_fit_ClusterR),
sum(km_orig_2$WCSS_per_cluster),
tolerance = 0.005
)
Expand Down Expand Up @@ -70,7 +70,7 @@ test_that("kmeans sse metrics work on new data", {
)

expect_equal(
sse_within_vec(kmeans_fit_stats, new_data),
sse_within_total_vec(kmeans_fit_stats, new_data),
15654.38,
tolerance = 0.005
)
Expand Down
Loading

0 comments on commit 7188282

Please sign in to comment.