From e388e00ac18df7bbdd43a9544c481b2b7ae6bb4a Mon Sep 17 00:00:00 2001 From: "Kelly N. Bodwin" Date: Tue, 26 Jul 2022 03:21:13 -0400 Subject: [PATCH] merge main stuff --- .gitignore | 14 +- DESCRIPTION | 114 +- NAMESPACE | 346 +++--- R/arguments.R | 214 ++-- R/augment.R | 66 +- R/dials.R | 32 +- R/engines.R | 172 +-- R/extract_assignment.R | 122 +- R/extract_characterization.R | 40 +- R/extract_summary.R | 196 ++-- R/finalize.R | 124 +- R/fit.R | 668 +++++------ R/hier_clust.R | 192 +-- R/hier_clust_data.R | 264 ++--- R/k_means.R | 306 ++--- R/k_means_data.R | 228 ++-- R/metric-silhouette.R | 214 ++-- R/predict.R | 288 ++--- R/predict_helpers.R | 202 ++-- R/reexports.R | 144 +-- R/translate.R | 302 ++--- R/tunable.R | 162 +-- R/update.R | 56 +- README.Rmd | 148 +-- README.md | 290 ++--- _pkgdown.yml | 134 +-- dev/cross_val_kmeans.R | 216 ++-- dev/kmeans.Rmd | 264 ++--- dev/test_hc.R | 66 +- dev/test_hclust_predict.R | 78 +- dev/to do | 42 +- man/augment.Rd | 62 +- man/avg_silhouette.Rd | 110 +- man/extract_centroids.Rd | 52 +- man/extract_cluster_assignment.Rd | 50 +- man/extract_fit_summary.Rd | 56 +- man/figures/logo.svg | 1290 ++++++++++----------- man/finalize_model_tidyclust.Rd | 70 +- man/fit.Rd | 210 ++-- man/hclust_fit.Rd | 72 +- man/hier_clust.Rd | 86 +- man/k_means.Rd | 54 +- man/num_clusters.Rd | 48 +- man/predict.cluster_fit.Rd | 136 +-- man/reexports.Rd | 82 +- man/silhouettes.Rd | 72 +- man/sse_ratio.Rd | 78 +- man/tidyclust_update.Rd | 84 +- man/tot_sse.Rd | 84 +- man/tot_wss.Rd | 84 +- man/translate_tidyclust.Rd | 86 +- man/within_cluster_sse.Rd | 66 +- tests/testthat/_snaps/arguments.md | 70 +- tests/testthat/_snaps/hier_clust.md | 118 +- tests/testthat/_snaps/hier_clust.new.md | 118 +- tests/testthat/_snaps/k_means.md | 130 +-- tests/testthat/_snaps/registration.md | 678 +++++------ tests/testthat/helper-tidyclust-package.R | 54 +- tests/testthat/test-arguments.R | 90 +- tests/testthat/test-augment.R | 56 +- tests/testthat/test-cluster_metric_set.R | 84 +- tests/testthat/test-control.R | 28 +- tests/testthat/test-extract_summary.R | 80 +- tests/testthat/test-hier_clust.R | 170 +-- tests/testthat/test-k_means.R | 202 ++-- tests/testthat/test-k_means_diagnostics.R | 206 ++-- tests/testthat/test-predict_formats.R | 20 +- tests/testthat/test-tune_cluster.R | 766 ++++++------ 68 files changed, 5753 insertions(+), 5753 deletions(-) diff --git a/.gitignore b/.gitignore index ce37a082..114b666f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ -.Rproj.user -.Rhistory -.Rdata -.httr-oauth -.DS_Store -docs -hex sticker/ +.Rproj.user +.Rhistory +.Rdata +.httr-oauth +.DS_Store +docs +hex sticker/ diff --git a/DESCRIPTION b/DESCRIPTION index 1211cb41..11338e59 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,57 +1,57 @@ -Package: tidyclust -Title: What the Package Does (One Line, Title Case) -Version: 0.0.0.9000 -Authors@R: c( - person("Emil", "Hvitfeldt", , "emilhhvitfeldt@gmail.com", role = c("aut", "cre"), - comment = c(ORCID = "0000-0002-0679-1945")), - person("Kelly", "Bodwin", , "kelly@bodwin.us", role = "aut"), - person("RStudio", role = c("cph", "fnd")) - ) -Description: What the package does (one paragraph). -License: MIT + file LICENSE -URL: https://github.com/EmilHvitfeldt/tidyclust -BugReports: https://github.com/EmilHvitfeldt/tidyclust/issues -Imports: - cli, - dials, - dplyr, - forcats, - foreach, - generics, - glue, - hardhat (>= 0.1.6.9001), - magrittr, - parsnip, - prettyunits, - RcppHungarian, - rlang, - rsample, - stats, - tibble, - tidyr, - tune, - utils, - vctrs -Suggests: - cluster, - ClusterR, - covr, - flexclust, - janitor, - knitr, - modeldata, - recipes, - rmarkdown, - Rfast, - testthat (>= 3.0.0), - workflows -VignetteBuilder: - knitr -Remotes: - tidymodels/parsnip, - tidymodels/workflows@celery -Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins -Config/testthat/edition: 3 -Encoding: UTF-8 -Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.0.9000 +Package: tidyclust +Title: What the Package Does (One Line, Title Case) +Version: 0.0.0.9000 +Authors@R: c( + person("Emil", "Hvitfeldt", , "emilhhvitfeldt@gmail.com", role = c("aut", "cre"), + comment = c(ORCID = "0000-0002-0679-1945")), + person("Kelly", "Bodwin", , "kelly@bodwin.us", role = "aut"), + person("RStudio", role = c("cph", "fnd")) + ) +Description: What the package does (one paragraph). +License: MIT + file LICENSE +URL: https://github.com/EmilHvitfeldt/tidyclust +BugReports: https://github.com/EmilHvitfeldt/tidyclust/issues +Imports: + cli, + dials, + dplyr, + forcats, + foreach, + generics, + glue, + hardhat (>= 0.1.6.9001), + magrittr, + parsnip, + prettyunits, + RcppHungarian, + rlang, + rsample, + stats, + tibble, + tidyr, + tune, + utils, + vctrs +Suggests: + cluster, + ClusterR, + covr, + flexclust, + janitor, + knitr, + modeldata, + recipes, + rmarkdown, + Rfast, + testthat (>= 3.0.0), + workflows +VignetteBuilder: + knitr +Remotes: + tidymodels/parsnip, + tidymodels/workflows@celery +Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins +Config/testthat/edition: 3 +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +RoxygenNote: 7.2.0.9000 diff --git a/NAMESPACE b/NAMESPACE index 77354f66..8dca7bb0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,173 +1,173 @@ -# Generated by roxygen2: do not edit by hand - -S3method(as_tibble,cluster_metric_set) -S3method(augment,cluster_fit) -S3method(avg_silhouette,cluster_fit) -S3method(avg_silhouette,workflow) -S3method(extract_cluster_assignment,KMeansCluster) -S3method(extract_cluster_assignment,cluster_fit) -S3method(extract_cluster_assignment,hclust) -S3method(extract_cluster_assignment,kmeans) -S3method(extract_cluster_assignment,workflow) -S3method(extract_fit_summary,KMeansCluster) -S3method(extract_fit_summary,cluster_fit) -S3method(extract_fit_summary,hclust) -S3method(extract_fit_summary,kmeans) -S3method(extract_fit_summary,workflow) -S3method(extract_parameter_set_dials,cluster_spec) -S3method(fit,cluster_spec) -S3method(fit_xy,cluster_spec) -S3method(glance,cluster_fit) -S3method(load_pkgs,cluster_spec) -S3method(min_grid,cluster_spec) -S3method(predict,cluster_fit) -S3method(predict_cluster,cluster_fit) -S3method(predict_raw,cluster_fit) -S3method(print,cluster_fit) -S3method(print,cluster_metric_set) -S3method(print,cluster_spec) -S3method(print,control_cluster) -S3method(print,hier_clust) -S3method(print,k_means) -S3method(sse_ratio,cluster_fit) -S3method(sse_ratio,workflow) -S3method(tidy,cluster_fit) -S3method(tot_sse,cluster_fit) -S3method(tot_sse,workflow) -S3method(tot_wss,cluster_fit) -S3method(tot_wss,workflow) -S3method(translate_tidyclust,default) -S3method(translate_tidyclust,hier_clust) -S3method(translate_tidyclust,k_means) -S3method(tune_cluster,cluster_spec) -S3method(tune_cluster,default) -S3method(tune_cluster,workflow) -S3method(update,k_means) -export("%>%") -export(.convert_form_to_x_fit) -export(.convert_form_to_x_new) -export(.convert_x_to_form_fit) -export(.convert_x_to_form_new) -export(ClusterR_kmeans_fit) -export(augment) -export(avg_silhouette) -export(avg_silhouette_vec) -export(check_empty_ellipse_tidyclust) -export(check_model_doesnt_exist_tidyclust) -export(check_model_exists_tidyclust) -export(cluster_metric_set) -export(control_cluster) -export(enrichment) -export(extract_centroids) -export(extract_cluster_assignment) -export(extract_fit_parsnip) -export(extract_fit_summary) -export(extract_parameter_set_dials) -export(extract_preprocessor) -export(extract_spec_parsnip) -export(finalize_model_tidyclust) -export(finalize_workflow_tidyclust) -export(fit) -export(fit.cluster_spec) -export(fit_xy) -export(fit_xy.cluster_spec) -export(get_dependency_tidyclust) -export(get_encoding_tidyclust) -export(get_fit_tidyclust) -export(get_from_env_tidyclust) -export(get_model_env_tidyclust) -export(get_pred_type_tidyclust) -export(glance) -export(hclust_fit) -export(hier_clust) -export(k_means) -export(load_pkgs) -export(make_classes_tidyclust) -export(min_grid) -export(new_cluster_metric) -export(new_cluster_spec) -export(num_clusters) -export(predict.cluster_fit) -export(predict_cluster) -export(predict_cluster.cluster_fit) -export(predict_raw) -export(predict_raw.cluster_fit) -export(prepare_data) -export(reconcile_clusterings) -export(required_pkgs) -export(set_args) -export(set_args.cluster_spec) -export(set_dependency_tidyclust) -export(set_encoding_tidyclust) -export(set_engine) -export(set_engine.cluster_spec) -export(set_env_val_tidyclust) -export(set_fit_tidyclust) -export(set_mode) -export(set_mode.cluster_spec) -export(set_model_arg_tidyclust) -export(set_model_engine_tidyclust) -export(set_model_mode_tidyclust) -export(set_new_model_tidyclust) -export(set_pred_tidyclust) -export(show_model_info_tidyclust) -export(silhouettes) -export(sse_ratio) -export(sse_ratio_vec) -export(tidy) -export(tot_sse) -export(tot_sse_vec) -export(tot_wss) -export(tot_wss_vec) -export(translate_tidyclust) -export(translate_tidyclust.default) -export(tune) -export(tune_cluster) -export(within_cluster_sse) -importFrom(dplyr,bind_cols) -importFrom(generics,augment) -importFrom(generics,fit) -importFrom(generics,fit_xy) -importFrom(generics,glance) -importFrom(generics,min_grid) -importFrom(generics,required_pkgs) -importFrom(generics,tidy) -importFrom(hardhat,extract_fit_parsnip) -importFrom(hardhat,extract_parameter_set_dials) -importFrom(hardhat,extract_preprocessor) -importFrom(hardhat,extract_spec_parsnip) -importFrom(hardhat,tune) -importFrom(magrittr,"%>%") -importFrom(parsnip,make_call) -importFrom(parsnip,maybe_data_frame) -importFrom(parsnip,maybe_matrix) -importFrom(parsnip,model_printer) -importFrom(parsnip,null_value) -importFrom(parsnip,predict_raw) -importFrom(parsnip,set_args) -importFrom(parsnip,set_engine) -importFrom(parsnip,set_mode) -importFrom(parsnip,show_call) -importFrom(rlang,"%||%") -importFrom(rlang,abort) -importFrom(rlang,as_function) -importFrom(rlang,enquo) -importFrom(rlang,enquos) -importFrom(rlang,get_expr) -importFrom(rlang,global_env) -importFrom(rlang,is_logical) -importFrom(rlang,is_true) -importFrom(rlang,missing_arg) -importFrom(rlang,quos) -importFrom(rlang,set_names) -importFrom(rlang,sym) -importFrom(stats,.getXlevels) -importFrom(stats,as.formula) -importFrom(stats,model.frame) -importFrom(stats,model.matrix) -importFrom(stats,model.offset) -importFrom(stats,model.weights) -importFrom(stats,na.omit) -importFrom(tibble,as_tibble) -importFrom(tune,load_pkgs) -importFrom(utils,capture.output) +# Generated by roxygen2: do not edit by hand + +S3method(as_tibble,cluster_metric_set) +S3method(augment,cluster_fit) +S3method(avg_silhouette,cluster_fit) +S3method(avg_silhouette,workflow) +S3method(extract_cluster_assignment,KMeansCluster) +S3method(extract_cluster_assignment,cluster_fit) +S3method(extract_cluster_assignment,hclust) +S3method(extract_cluster_assignment,kmeans) +S3method(extract_cluster_assignment,workflow) +S3method(extract_fit_summary,KMeansCluster) +S3method(extract_fit_summary,cluster_fit) +S3method(extract_fit_summary,hclust) +S3method(extract_fit_summary,kmeans) +S3method(extract_fit_summary,workflow) +S3method(extract_parameter_set_dials,cluster_spec) +S3method(fit,cluster_spec) +S3method(fit_xy,cluster_spec) +S3method(glance,cluster_fit) +S3method(load_pkgs,cluster_spec) +S3method(min_grid,cluster_spec) +S3method(predict,cluster_fit) +S3method(predict_cluster,cluster_fit) +S3method(predict_raw,cluster_fit) +S3method(print,cluster_fit) +S3method(print,cluster_metric_set) +S3method(print,cluster_spec) +S3method(print,control_cluster) +S3method(print,hier_clust) +S3method(print,k_means) +S3method(set_args,cluster_spec) +S3method(set_engine,cluster_spec) +S3method(set_mode,cluster_spec) +S3method(sse_ratio,cluster_fit) +S3method(sse_ratio,workflow) +S3method(tidy,cluster_fit) +S3method(tot_sse,cluster_fit) +S3method(tot_sse,workflow) +S3method(tot_wss,cluster_fit) +S3method(tot_wss,workflow) +S3method(translate_tidyclust,default) +S3method(translate_tidyclust,hier_clust) +S3method(translate_tidyclust,k_means) +S3method(tune_cluster,cluster_spec) +S3method(tune_cluster,default) +S3method(tune_cluster,workflow) +S3method(update,k_means) +export("%>%") +export(.convert_form_to_x_fit) +export(.convert_form_to_x_new) +export(.convert_x_to_form_fit) +export(.convert_x_to_form_new) +export(ClusterR_kmeans_fit) +export(augment) +export(avg_silhouette) +export(avg_silhouette_vec) +export(check_empty_ellipse_tidyclust) +export(check_model_doesnt_exist_tidyclust) +export(check_model_exists_tidyclust) +export(cluster_metric_set) +export(control_cluster) +export(enrichment) +export(extract_centroids) +export(extract_cluster_assignment) +export(extract_fit_parsnip) +export(extract_fit_summary) +export(extract_parameter_set_dials) +export(extract_preprocessor) +export(extract_spec_parsnip) +export(finalize_model_tidyclust) +export(finalize_workflow_tidyclust) +export(fit) +export(fit.cluster_spec) +export(fit_xy) +export(fit_xy.cluster_spec) +export(get_dependency_tidyclust) +export(get_encoding_tidyclust) +export(get_fit_tidyclust) +export(get_from_env_tidyclust) +export(get_model_env_tidyclust) +export(get_pred_type_tidyclust) +export(glance) +export(hclust_fit) +export(hier_clust) +export(k_means) +export(load_pkgs) +export(make_classes_tidyclust) +export(min_grid) +export(new_cluster_metric) +export(new_cluster_spec) +export(num_clusters) +export(predict.cluster_fit) +export(predict_cluster) +export(predict_cluster.cluster_fit) +export(predict_raw) +export(predict_raw.cluster_fit) +export(prepare_data) +export(reconcile_clusterings) +export(required_pkgs) +export(set_args) +export(set_dependency_tidyclust) +export(set_encoding_tidyclust) +export(set_engine) +export(set_env_val_tidyclust) +export(set_fit_tidyclust) +export(set_mode) +export(set_model_arg_tidyclust) +export(set_model_engine_tidyclust) +export(set_model_mode_tidyclust) +export(set_new_model_tidyclust) +export(set_pred_tidyclust) +export(show_model_info_tidyclust) +export(silhouettes) +export(sse_ratio) +export(sse_ratio_vec) +export(tidy) +export(tot_sse) +export(tot_sse_vec) +export(tot_wss) +export(tot_wss_vec) +export(translate_tidyclust) +export(translate_tidyclust.default) +export(tune) +export(tune_cluster) +export(within_cluster_sse) +importFrom(dplyr,bind_cols) +importFrom(generics,augment) +importFrom(generics,fit) +importFrom(generics,fit_xy) +importFrom(generics,glance) +importFrom(generics,min_grid) +importFrom(generics,required_pkgs) +importFrom(generics,tidy) +importFrom(hardhat,extract_fit_parsnip) +importFrom(hardhat,extract_parameter_set_dials) +importFrom(hardhat,extract_preprocessor) +importFrom(hardhat,extract_spec_parsnip) +importFrom(hardhat,tune) +importFrom(magrittr,"%>%") +importFrom(parsnip,make_call) +importFrom(parsnip,maybe_data_frame) +importFrom(parsnip,maybe_matrix) +importFrom(parsnip,model_printer) +importFrom(parsnip,null_value) +importFrom(parsnip,predict_raw) +importFrom(parsnip,set_args) +importFrom(parsnip,set_engine) +importFrom(parsnip,set_mode) +importFrom(parsnip,show_call) +importFrom(rlang,"%||%") +importFrom(rlang,abort) +importFrom(rlang,as_function) +importFrom(rlang,enquo) +importFrom(rlang,enquos) +importFrom(rlang,get_expr) +importFrom(rlang,global_env) +importFrom(rlang,is_logical) +importFrom(rlang,is_true) +importFrom(rlang,missing_arg) +importFrom(rlang,quos) +importFrom(rlang,set_names) +importFrom(rlang,sym) +importFrom(stats,.getXlevels) +importFrom(stats,as.formula) +importFrom(stats,model.frame) +importFrom(stats,model.matrix) +importFrom(stats,model.offset) +importFrom(stats,model.weights) +importFrom(stats,na.omit) +importFrom(tibble,as_tibble) +importFrom(tune,load_pkgs) +importFrom(utils,capture.output) diff --git a/R/arguments.R b/R/arguments.R index 52279b47..62e9bef2 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -1,107 +1,107 @@ -check_eng_args <- function(args, obj, core_args) { - # Make sure that we are not trying to modify an argument that - # is explicitly protected in the method metadata or arg_key - protected_args <- unique(c(obj$protect, core_args)) - common_args <- intersect(protected_args, names(args)) - if (length(common_args) > 0) { - args <- args[!(names(args) %in% common_args)] - common_args <- paste0(common_args, collapse = ", ") - rlang::warn(glue::glue( - "The following arguments cannot be manually modified ", - "and were removed: {common_args}." - )) - } - args -} - -make_x_call <- function(object, target) { - fit_args <- object$method$fit$args - - # Get the arguments related to data: - if (is.null(object$method$fit$data)) { - data_args <- c(x = "x") - } else { - data_args <- object$method$fit$data - } - - object$method$fit$args[[unname(data_args["x"])]] <- - switch(target, - none = rlang::expr(x), - data.frame = rlang::expr(maybe_data_frame(x)), - matrix = rlang::expr(maybe_matrix(x)), - rlang::abort(glue::glue("Invalid data type target: {target}.")) - ) - - fit_call <- make_call( - fun = object$method$fit$func["fun"], - ns = object$method$fit$func["pkg"], - object$method$fit$args - ) - - fit_call -} - -make_form_call <- function(object, env = NULL) { - fit_args <- object$method$fit$args - - # Get the arguments related to data: - if (is.null(object$method$fit$data)) { - data_args <- c(formula = "formula", data = "data") - } else { - data_args <- object$method$fit$data - } - - # add data arguments - for (i in seq_along(data_args)) { - fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i]) - } - - # sub in actual formula - fit_args[[unname(data_args["formula"])]] <- env$formula - - fit_call <- make_call( - fun = object$method$fit$func["fun"], - ns = object$method$fit$func["pkg"], - fit_args - ) - fit_call -} - -#' @export -set_args.cluster_spec <- function(object, ...) { - the_dots <- enquos(...) - if (length(the_dots) == 0) - rlang::abort("Please pass at least one named argument.") - main_args <- names(object$args) - new_args <- names(the_dots) - for (i in new_args) { - if (any(main_args == i)) { - object$args[[i]] <- the_dots[[i]] - } else { - object$eng_args[[i]] <- the_dots[[i]] - } - } - new_cluster_spec( - cls = class(object)[1], - args = object$args, - eng_args = object$eng_args, - mode = object$mode, - method = NULL, - engine = object$engine - ) -} - -#' @export -set_mode.cluster_spec <- function(object, mode) { - cls <- class(object)[1] - if (rlang::is_missing(mode)) { - spec_modes <- rlang::env_get( - get_model_env_tidyclust(), - paste0(cls, "_modes") - ) - stop_incompatible_mode(spec_modes, cls = cls) - } - check_spec_mode_engine_val(cls, object$engine, mode) - object$mode <- mode - object -} +check_eng_args <- function(args, obj, core_args) { + # Make sure that we are not trying to modify an argument that + # is explicitly protected in the method metadata or arg_key + protected_args <- unique(c(obj$protect, core_args)) + common_args <- intersect(protected_args, names(args)) + if (length(common_args) > 0) { + args <- args[!(names(args) %in% common_args)] + common_args <- paste0(common_args, collapse = ", ") + rlang::warn(glue::glue( + "The following arguments cannot be manually modified ", + "and were removed: {common_args}." + )) + } + args +} + +make_x_call <- function(object, target) { + fit_args <- object$method$fit$args + + # Get the arguments related to data: + if (is.null(object$method$fit$data)) { + data_args <- c(x = "x") + } else { + data_args <- object$method$fit$data + } + + object$method$fit$args[[unname(data_args["x"])]] <- + switch(target, + none = rlang::expr(x), + data.frame = rlang::expr(maybe_data_frame(x)), + matrix = rlang::expr(maybe_matrix(x)), + rlang::abort(glue::glue("Invalid data type target: {target}.")) + ) + + fit_call <- make_call( + fun = object$method$fit$func["fun"], + ns = object$method$fit$func["pkg"], + object$method$fit$args + ) + + fit_call +} + +make_form_call <- function(object, env = NULL) { + fit_args <- object$method$fit$args + + # Get the arguments related to data: + if (is.null(object$method$fit$data)) { + data_args <- c(formula = "formula", data = "data") + } else { + data_args <- object$method$fit$data + } + + # add data arguments + for (i in seq_along(data_args)) { + fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i]) + } + + # sub in actual formula + fit_args[[unname(data_args["formula"])]] <- env$formula + + fit_call <- make_call( + fun = object$method$fit$func["fun"], + ns = object$method$fit$func["pkg"], + fit_args + ) + fit_call +} + +#' @export +set_args.cluster_spec <- function(object, ...) { + the_dots <- enquos(...) + if (length(the_dots) == 0) + rlang::abort("Please pass at least one named argument.") + main_args <- names(object$args) + new_args <- names(the_dots) + for (i in new_args) { + if (any(main_args == i)) { + object$args[[i]] <- the_dots[[i]] + } else { + object$eng_args[[i]] <- the_dots[[i]] + } + } + new_cluster_spec( + cls = class(object)[1], + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + +#' @export +set_mode.cluster_spec <- function(object, mode) { + cls <- class(object)[1] + if (rlang::is_missing(mode)) { + spec_modes <- rlang::env_get( + get_model_env_tidyclust(), + paste0(cls, "_modes") + ) + stop_incompatible_mode(spec_modes, cls = cls) + } + check_spec_mode_engine_val(cls, object$engine, mode) + object$mode <- mode + object +} diff --git a/R/augment.R b/R/augment.R index c8d7299d..05b422e4 100644 --- a/R/augment.R +++ b/R/augment.R @@ -1,33 +1,33 @@ -#' Augment data with predictions -#' -#' `augment()` will add column(s) for predictions to the given data. -#' -#' For partition models, a `.pred_cluster` column is added. -#' -#' @param x A `cluster_fit` object produced by [fit.cluster_spec()] or -#' [fit_xy.cluster_spec()] . -#' @param new_data A data frame or matrix. -#' @param ... Not currently used. -#' @rdname augment -#' @export -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' kmeans_fit %>% -#' augment(new_data = mtcars) -augment.cluster_fit <- function(x, new_data, ...) { - ret <- new_data - if (x$spec$mode == "partition") { - check_spec_pred_type(x, "cluster") - ret <- dplyr::bind_cols( - ret, - stats::predict(x, new_data = new_data) - ) - } else { - rlang::abort(paste("Unknown mode:", x$spec$mode)) - } - as_tibble(ret) -} +#' Augment data with predictions +#' +#' `augment()` will add column(s) for predictions to the given data. +#' +#' For partition models, a `.pred_cluster` column is added. +#' +#' @param x A `cluster_fit` object produced by [fit.cluster_spec()] or +#' [fit_xy.cluster_spec()] . +#' @param new_data A data frame or matrix. +#' @param ... Not currently used. +#' @rdname augment +#' @export +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' kmeans_fit %>% +#' augment(new_data = mtcars) +augment.cluster_fit <- function(x, new_data, ...) { + ret <- new_data + if (x$spec$mode == "partition") { + check_spec_pred_type(x, "cluster") + ret <- dplyr::bind_cols( + ret, + stats::predict(x, new_data = new_data) + ) + } else { + rlang::abort(paste("Unknown mode:", x$spec$mode)) + } + as_tibble(ret) +} diff --git a/R/dials.R b/R/dials.R index 9b4baf42..d4f91ffd 100644 --- a/R/dials.R +++ b/R/dials.R @@ -1,16 +1,16 @@ -#' Number of Clusters -#' -#' @inheritParams dials::Laplace -#' @examples -#' num_clusters() -#' @export -num_clusters <- function(range = c(1L, 10L), trans = NULL) { - dials::new_quant_param( - type = "integer", - range = range, - inclusive = c(TRUE, TRUE), - trans = trans, - label = c(num_clusters = "# Clusters"), - finalize = NULL - ) -} +#' Number of Clusters +#' +#' @inheritParams dials::Laplace +#' @examples +#' num_clusters() +#' @export +num_clusters <- function(range = c(1L, 10L), trans = NULL) { + dials::new_quant_param( + type = "integer", + range = range, + inclusive = c(TRUE, TRUE), + trans = trans, + label = c(num_clusters = "# Clusters"), + finalize = NULL + ) +} diff --git a/R/engines.R b/R/engines.R index bfb10c00..5d6d4cab 100644 --- a/R/engines.R +++ b/R/engines.R @@ -1,86 +1,86 @@ -#' @export -set_engine.cluster_spec <- function(object, engine, ...) { - mod_type <- class(object)[1] - - if (rlang::is_missing(engine)) { - stop_missing_engine(mod_type) - } - object$engine <- engine - check_spec_mode_engine_val(mod_type, object$engine, object$mode) - - new_cluster_spec( - cls = mod_type, - args = object$args, - eng_args = enquos(...), - mode = object$mode, - method = NULL, - engine = object$engine - ) -} - -stop_missing_engine <- function(cls) { - info <- - get_from_env_tidyclust(cls) %>% - dplyr::group_by(mode) %>% - dplyr::summarize( - msg = paste0( - unique(mode), " {", - paste0(unique(engine), collapse = ", "), - "}" - ), - .groups = "drop" - ) - if (nrow(info) == 0) { - rlang::abort(paste0("No known engines for `", cls, "()`.")) - } - msg <- paste0(info$msg, collapse = ", ") - msg <- paste("Missing engine. Possible mode/engine combinations are:", msg) - rlang::abort(msg) -} - -load_libs <- function(x, quiet, attach = FALSE) { - for (pkg in x$method$libs) { - if (!attach) { - suppressPackageStartupMessages(requireNamespace(pkg, quietly = quiet)) - } else { - library(pkg, character.only = TRUE) - } - } - invisible(x) -} - -specific_model <- function(x) { - cls <- class(x) - cls[cls != "cluster_spec"] -} - -possible_engines <- function(object, ...) { - m_env <- get_model_env_tidyclust() - engs <- rlang::env_get(m_env, specific_model(object)) - unique(engs$engine) -} - -shhhh <- function(x) { - suppressPackageStartupMessages(requireNamespace(x, quietly = TRUE)) -} - -is_installed <- function(pkg) { - res <- try(shhhh(pkg), silent = TRUE) - res -} - -check_installs <- function(x) { - if (length(x$method$libs) > 0) { - is_inst <- map_lgl(x$method$libs, is_installed) - if (any(!is_inst)) { - missing_pkg <- x$method$libs[!is_inst] - missing_pkg <- paste0(missing_pkg, collapse = ", ") - rlang::abort( - glue::glue( - "This engine requires some package installs: ", - glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ") - ) - ) - } - } -} +#' @export +set_engine.cluster_spec <- function(object, engine, ...) { + mod_type <- class(object)[1] + + if (rlang::is_missing(engine)) { + stop_missing_engine(mod_type) + } + object$engine <- engine + check_spec_mode_engine_val(mod_type, object$engine, object$mode) + + new_cluster_spec( + cls = mod_type, + args = object$args, + eng_args = enquos(...), + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + +stop_missing_engine <- function(cls) { + info <- + get_from_env_tidyclust(cls) %>% + dplyr::group_by(mode) %>% + dplyr::summarize( + msg = paste0( + unique(mode), " {", + paste0(unique(engine), collapse = ", "), + "}" + ), + .groups = "drop" + ) + if (nrow(info) == 0) { + rlang::abort(paste0("No known engines for `", cls, "()`.")) + } + msg <- paste0(info$msg, collapse = ", ") + msg <- paste("Missing engine. Possible mode/engine combinations are:", msg) + rlang::abort(msg) +} + +load_libs <- function(x, quiet, attach = FALSE) { + for (pkg in x$method$libs) { + if (!attach) { + suppressPackageStartupMessages(requireNamespace(pkg, quietly = quiet)) + } else { + library(pkg, character.only = TRUE) + } + } + invisible(x) +} + +specific_model <- function(x) { + cls <- class(x) + cls[cls != "cluster_spec"] +} + +possible_engines <- function(object, ...) { + m_env <- get_model_env_tidyclust() + engs <- rlang::env_get(m_env, specific_model(object)) + unique(engs$engine) +} + +shhhh <- function(x) { + suppressPackageStartupMessages(requireNamespace(x, quietly = TRUE)) +} + +is_installed <- function(pkg) { + res <- try(shhhh(pkg), silent = TRUE) + res +} + +check_installs <- function(x) { + if (length(x$method$libs) > 0) { + is_inst <- map_lgl(x$method$libs, is_installed) + if (any(!is_inst)) { + missing_pkg <- x$method$libs[!is_inst] + missing_pkg <- paste0(missing_pkg, collapse = ", ") + rlang::abort( + glue::glue( + "This engine requires some package installs: ", + glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ") + ) + ) + } + } +} diff --git a/R/extract_assignment.R b/R/extract_assignment.R index 5050532e..7f0655c1 100644 --- a/R/extract_assignment.R +++ b/R/extract_assignment.R @@ -1,61 +1,61 @@ -#' Extract cluster assignments from model -#' -#' @param object An cluster_spec object. -#' @param ... Other arguments passed to methods. -#' -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' kmeans_fit %>% -#' extract_cluster_assignment() -#' @export -extract_cluster_assignment <- function(object, ...) { - UseMethod("extract_cluster_assignment") -} - -#' @export -extract_cluster_assignment.cluster_fit <- function(object, ...) { - extract_cluster_assignment(object$fit, ...) -} - -#' @export -extract_cluster_assignment.workflow <- function(object, ...) { - extract_cluster_assignment(object$fit$fit$fit) -} - -#' @export -extract_cluster_assignment.kmeans <- function(object, ...) { - cluster_assignment_tibble(object$cluster, length(object$size)) -} - -#' @export -extract_cluster_assignment.KMeansCluster <- function(object, ...) { - cluster_assignment_tibble(object$clusters, length(object$obs_per_cluster)) -} - -#' @export -extract_cluster_assignment.hclust <- function(object, ...) { - - # if k or h is passed in the dots, use those. Otherwise, use attributes - # from original model specification - args <- list(...) - if (!("k" %in% names(args) | "cut_height" %in% names(args))) { - k <- attr(object, "k") - cut_height <- attr(object, "cut_height") - } - clusters <- stats::cutree(object, k, h = cut_height) - cluster_assignment_tibble(clusters, length(unique(clusters))) -} - -# ------------------------------------------------------------------------------ - -cluster_assignment_tibble <- function(clusters, n_clusters) { - reorder_clusts <- order(unique(clusters)) - names <- paste0("Cluster_", 1:n_clusters) - res <- names[reorder_clusts][clusters] - - tibble::tibble(.cluster = factor(res)) -} +#' Extract cluster assignments from model +#' +#' @param object An cluster_spec object. +#' @param ... Other arguments passed to methods. +#' +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' kmeans_fit %>% +#' extract_cluster_assignment() +#' @export +extract_cluster_assignment <- function(object, ...) { + UseMethod("extract_cluster_assignment") +} + +#' @export +extract_cluster_assignment.cluster_fit <- function(object, ...) { + extract_cluster_assignment(object$fit, ...) +} + +#' @export +extract_cluster_assignment.workflow <- function(object, ...) { + extract_cluster_assignment(object$fit$fit$fit) +} + +#' @export +extract_cluster_assignment.kmeans <- function(object, ...) { + cluster_assignment_tibble(object$cluster, length(object$size)) +} + +#' @export +extract_cluster_assignment.KMeansCluster <- function(object, ...) { + cluster_assignment_tibble(object$clusters, length(object$obs_per_cluster)) +} + +#' @export +extract_cluster_assignment.hclust <- function(object, ...) { + + # if k or h is passed in the dots, use those. Otherwise, use attributes + # from original model specification + args <- list(...) + if (!("k" %in% names(args) | "cut_height" %in% names(args))) { + k <- attr(object, "k") + cut_height <- attr(object, "cut_height") + } + clusters <- stats::cutree(object, k, h = cut_height) + cluster_assignment_tibble(clusters, length(unique(clusters))) +} + +# ------------------------------------------------------------------------------ + +cluster_assignment_tibble <- function(clusters, n_clusters) { + reorder_clusts <- order(unique(clusters)) + names <- paste0("Cluster_", 1:n_clusters) + res <- names[reorder_clusts][clusters] + + tibble::tibble(.cluster = factor(res)) +} diff --git a/R/extract_characterization.R b/R/extract_characterization.R index 20a216f6..b4c61477 100644 --- a/R/extract_characterization.R +++ b/R/extract_characterization.R @@ -1,20 +1,20 @@ -#' Extract clusters from model -#' -#' @param object An cluster_spec object. -#' @param ... Other arguments passed to methods. -#' -#' @examples -#' set.seed(1234) -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' kmeans_fit %>% -#' extract_centroids() -#' @export -extract_centroids <- function(object, ...) { - summ <- extract_fit_summary(object) - clusters <- tibble::tibble(.cluster = summ$cluster_names) - bind_cols(clusters, summ$centroids) -} +#' Extract clusters from model +#' +#' @param object An cluster_spec object. +#' @param ... Other arguments passed to methods. +#' +#' @examples +#' set.seed(1234) +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' kmeans_fit %>% +#' extract_centroids() +#' @export +extract_centroids <- function(object, ...) { + summ <- extract_fit_summary(object) + clusters <- tibble::tibble(.cluster = summ$cluster_names) + bind_cols(clusters, summ$centroids) +} diff --git a/R/extract_summary.R b/R/extract_summary.R index 96d2b70a..feaac179 100644 --- a/R/extract_summary.R +++ b/R/extract_summary.R @@ -1,98 +1,98 @@ - -#' S3 method to get fitted model summary info depending on engine -#' -#' @param object a fitted cluster_spec object -#' @param ... other arguments passed to methods -#' -#' @return A list with various summary elements -#' -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' kmeans_fit %>% -#' extract_fit_summary() -#' @export -extract_fit_summary <- function(object, ...) { - UseMethod("extract_fit_summary") -} - -#' @export -extract_fit_summary.cluster_fit <- function(object, ...) { - extract_fit_summary(object$fit) -} - -#' @export -extract_fit_summary.workflow <- function(object, ...) { - extract_fit_summary(object$fit$fit$fit) -} - -#' @export -extract_fit_summary.kmeans <- function(object, ...) { - reorder_clusts <- order(unique(object$cluster)) - names <- paste0("Cluster_", seq_len(nrow(object$centers))) - - list( - cluster_names = names, - centroids = tibble::as_tibble(object$centers[reorder_clusts, , drop = FALSE]), - n_members = object$size[reorder_clusts], - within_sse = object$withinss[reorder_clusts], - tot_sse = object$totss, - orig_labels = unname(object$cluster), - cluster_assignments = names[reorder_clusts][object$cluster] - ) -} - -#' @export -extract_fit_summary.KMeansCluster <- function(object, ...) { - reorder_clusts <- order(unique(object$cluster)) - names <- paste0("Cluster_", seq_len(nrow(object$centroids))) - - list( - cluster_names = names, - centroids = tibble::as_tibble(object$centroids[reorder_clusts, , drop = FALSE]), - n_members = object$obs_per_cluster[reorder_clusts], - within_sse = object$WCSS_per_cluster[reorder_clusts], - tot_sse = object$total_SSE, - orig_labels = object$clusters, - cluster_assignments = names[reorder_clusts][object$clusters] - ) -} - -#' @export -extract_fit_summary.hclust <- function(object, ...) { - - clusts <- extract_cluster_assignment(object, ...)$.cluster - n_clust <- dplyr::n_distinct(clusts) - - training_data <- attr(object, "training_data") - - overall_centroid <- colMeans(training_data) - - by_clust <- training_data %>% - tibble::as_tibble() %>% - dplyr::mutate( - .cluster = clusts - ) %>% - dplyr::group_by(.cluster) %>% - tidyr::nest() - - centroids <- by_clust$data %>% - purrr::map_dfr(~ .x %>% dplyr::summarize_all(mean)) - - within_sse <- by_clust$data %>% - purrr::map2_dbl(1:n_clust, - ~ sum(Rfast::dista(centroids[.y,], .x))) - - list( - cluster_names = unique(clusts), - centroids = centroids, - n_members = unname(table(clusts)), - within_sse = within_sse, - tot_sse = sum(Rfast::dista(t(overall_centroid), training_data)), - orig_labels = NULL, - cluster_assignments = clusts - ) -} + +#' S3 method to get fitted model summary info depending on engine +#' +#' @param object a fitted cluster_spec object +#' @param ... other arguments passed to methods +#' +#' @return A list with various summary elements +#' +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' kmeans_fit %>% +#' extract_fit_summary() +#' @export +extract_fit_summary <- function(object, ...) { + UseMethod("extract_fit_summary") +} + +#' @export +extract_fit_summary.cluster_fit <- function(object, ...) { + extract_fit_summary(object$fit) +} + +#' @export +extract_fit_summary.workflow <- function(object, ...) { + extract_fit_summary(object$fit$fit$fit) +} + +#' @export +extract_fit_summary.kmeans <- function(object, ...) { + reorder_clusts <- order(unique(object$cluster)) + names <- paste0("Cluster_", seq_len(nrow(object$centers))) + + list( + cluster_names = names, + centroids = tibble::as_tibble(object$centers[reorder_clusts, , drop = FALSE]), + n_members = object$size[reorder_clusts], + within_sse = object$withinss[reorder_clusts], + tot_sse = object$totss, + orig_labels = unname(object$cluster), + cluster_assignments = names[reorder_clusts][object$cluster] + ) +} + +#' @export +extract_fit_summary.KMeansCluster <- function(object, ...) { + reorder_clusts <- order(unique(object$cluster)) + names <- paste0("Cluster_", seq_len(nrow(object$centroids))) + + list( + cluster_names = names, + centroids = tibble::as_tibble(object$centroids[reorder_clusts, , drop = FALSE]), + n_members = object$obs_per_cluster[reorder_clusts], + within_sse = object$WCSS_per_cluster[reorder_clusts], + tot_sse = object$total_SSE, + orig_labels = object$clusters, + cluster_assignments = names[reorder_clusts][object$clusters] + ) +} + +#' @export +extract_fit_summary.hclust <- function(object, ...) { + + clusts <- extract_cluster_assignment(object, ...)$.cluster + n_clust <- dplyr::n_distinct(clusts) + + training_data <- attr(object, "training_data") + + overall_centroid <- colMeans(training_data) + + by_clust <- training_data %>% + tibble::as_tibble() %>% + dplyr::mutate( + .cluster = clusts + ) %>% + dplyr::group_by(.cluster) %>% + tidyr::nest() + + centroids <- by_clust$data %>% + purrr::map_dfr(~ .x %>% dplyr::summarize_all(mean)) + + within_sse <- by_clust$data %>% + purrr::map2_dbl(1:n_clust, + ~ sum(Rfast::dista(centroids[.y,], .x))) + + list( + cluster_names = unique(clusts), + centroids = centroids, + n_members = unname(table(clusts)), + within_sse = within_sse, + tot_sse = sum(Rfast::dista(t(overall_centroid), training_data)), + orig_labels = NULL, + cluster_assignments = clusts + ) +} diff --git a/R/finalize.R b/R/finalize.R index 8da34a4e..12d279f7 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -1,62 +1,62 @@ -#' Splice final parameters into objects -#' -#' The `finalize_*` functions take a list or tibble of tuning parameter values and -#' update objects with those values. -#' -#' @param x A recipe, `parsnip` model specification, or workflow. -#' @param parameters A list or 1-row tibble of parameter values. Note that the -#' column names of the tibble should be the `id` fields attached to `tune()`. -#' For example, in the `Examples` section below, the model has `tune("K")`. In -#' this case, the parameter tibble should be "K" and not "neighbors". -#' @return An updated version of `x`. -#' @export -#' @examples -#' kmeans_spec <- k_means(num_clusters = tune()) -#' -#' best_params <- data.frame(num_clusters = 5) -#' best_params -#' -#' kmeans_spec -#' finalize_model_tidyclust(kmeans_spec, best_params) -finalize_model_tidyclust <- function(x, parameters) { - if (!inherits(x, "cluster_spec")) { - rlang::abort("`x` should be a tidyclust model specification.") - } - parsnip::check_final_param(parameters) - pset <- hardhat::extract_parameter_set_dials(x) - if (tibble::is_tibble(parameters)) { - parameters <- as.list(parameters) - } - - parameters <- parameters[names(parameters) %in% pset$id] - - discordant <- dplyr::filter(pset, id != name & id %in% names(parameters)) - if (nrow(discordant) > 0) { - for (i in 1:nrow(discordant)) { - names(parameters)[names(parameters) == discordant$id[i]] <- - discordant$name[i] - } - } - rlang::exec(stats::update, object = x, !!!parameters) -} - -#' @rdname finalize_model_tidyclust -#' @export -finalize_workflow_tidyclust <- function(x, parameters) { - if (!inherits(x, "workflow")) { - rlang::abort("`x` should be a workflow") - } - parsnip::check_final_param(parameters) - - mod <- extract_spec_parsnip(x) - mod <- finalize_model_tidyclust(mod, parameters) - x <- set_workflow_spec(x, mod) - - if (has_preprocessor_recipe(x)) { - rec <- extract_preprocessor(x) - rec <- tune::finalize_recipe(rec, parameters) - x <- set_workflow_recipe(x, rec) - } - - x -} +#' Splice final parameters into objects +#' +#' The `finalize_*` functions take a list or tibble of tuning parameter values and +#' update objects with those values. +#' +#' @param x A recipe, `parsnip` model specification, or workflow. +#' @param parameters A list or 1-row tibble of parameter values. Note that the +#' column names of the tibble should be the `id` fields attached to `tune()`. +#' For example, in the `Examples` section below, the model has `tune("K")`. In +#' this case, the parameter tibble should be "K" and not "neighbors". +#' @return An updated version of `x`. +#' @export +#' @examples +#' kmeans_spec <- k_means(num_clusters = tune()) +#' +#' best_params <- data.frame(num_clusters = 5) +#' best_params +#' +#' kmeans_spec +#' finalize_model_tidyclust(kmeans_spec, best_params) +finalize_model_tidyclust <- function(x, parameters) { + if (!inherits(x, "cluster_spec")) { + rlang::abort("`x` should be a tidyclust model specification.") + } + parsnip::check_final_param(parameters) + pset <- hardhat::extract_parameter_set_dials(x) + if (tibble::is_tibble(parameters)) { + parameters <- as.list(parameters) + } + + parameters <- parameters[names(parameters) %in% pset$id] + + discordant <- dplyr::filter(pset, id != name & id %in% names(parameters)) + if (nrow(discordant) > 0) { + for (i in 1:nrow(discordant)) { + names(parameters)[names(parameters) == discordant$id[i]] <- + discordant$name[i] + } + } + rlang::exec(stats::update, object = x, !!!parameters) +} + +#' @rdname finalize_model_tidyclust +#' @export +finalize_workflow_tidyclust <- function(x, parameters) { + if (!inherits(x, "workflow")) { + rlang::abort("`x` should be a workflow") + } + parsnip::check_final_param(parameters) + + mod <- extract_spec_parsnip(x) + mod <- finalize_model_tidyclust(mod, parameters) + x <- set_workflow_spec(x, mod) + + if (has_preprocessor_recipe(x)) { + rec <- extract_preprocessor(x) + rec <- tune::finalize_recipe(rec, parameters) + x <- set_workflow_recipe(x, rec) + } + + x +} diff --git a/R/fit.R b/R/fit.R index f31f69f0..c7a71492 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1,334 +1,334 @@ -#' Fit a Model Specification to a Data Set -#' -#' `fit()` and `fit_xy()` take a model specification, translate_tidyclust the -#' required code by substituting arguments, and execute the model fit routine. -#' -#' @param object An object of class `cluster_spec` that has a chosen engine (via -#' [set_engine()]). -#' @param formula An object of class `formula` (or one that can be coerced to -#' that class): a symbolic description of the model to be fitted. -#' @param data Optional, depending on the interface (see Details below). A data -#' frame containing all relevant variables (e.g. predictors, case weights, -#' etc). Note: when needed, a \emph{named argument} should be used. -#' @param control A named list with elements `verbosity` and `catch`. See -#' [control_cluster()]. -#' @param ... Not currently used; values passed here will be ignored. Other -#' options required to fit the model should be passed using -#' `set_engine()`. -#' @details `fit()` and `fit_xy()` substitute the current arguments in the -#' model specification into the computational engine's code, check them for -#' validity, then fit the model using the data and the engine-specific code. -#' Different model functions have different interfaces (e.g. formula or -#' `x`/`y`) and these functions translate_tidyclust between the interface used -#' when `fit()` or `fit_xy()` was invoked and the one required by the -#' underlying model. -#' -#' When possible, these functions attempt to avoid making copies of the data. -#' For example, if the underlying model uses a formula and `fit()` is invoked, -#' the original data are references when the model is fit. However, if the -#' underlying model uses something else, such as `x`/`y`, the formula is -#' evaluated and the data are converted to the required format. In this case, -#' any calls in the resulting model objects reference the temporary objects -#' used to fit the model. -#' -#' If the model engine has not been set, the model's default engine will be -#' used (as discussed on each model page). If the `verbosity` option of -#' [control_cluster()] is greater than zero, a warning will be produced. -#' -#' If you would like to use an alternative method for generating contrasts -#' when supplying a formula to `fit()`, set the global option `contrasts` to -#' your preferred method. For example, you might set it to: `options(contrasts -#' = c(unordered = "contr.helmert", ordered = "contr.poly"))`. See the help -#' page for [stats::contr.treatment()] for more possible contrast types. -#' @examples -#' library(dplyr) -#' -#' kmeans_mod <- k_means(num_clusters = 5) -#' -#' using_formula <- -#' kmeans_mod %>% -#' set_engine("stats") %>% -#' fit(~., data = mtcars) -#' -#' using_x <- -#' kmeans_mod %>% -#' set_engine("stats") %>% -#' fit_xy(x = mtcars) -#' -#' using_formula -#' using_x -#' @return A `cluster_fit` object that contains several elements: -#' \itemize{ -#' \item \code{spec}: The model specification object (\code{object} in the -#' call to \code{fit}) -#' \item \code{fit}: when the model is executed without error, this is the -#' model object. Otherwise, it is a \code{try-error} -#' object with the error message. -#' \item \code{preproc}: any objects needed to convert between a formula and -#' non-formula interface -#' (such as the \code{terms} object) -#' } -#' The return value will also have a class related to the fitted model (e.g. -#' `"_kmeans"`) before the base class of `"cluster_fit"`. -#' -#' @seealso [set_engine()], [control_cluster()], `cluster_spec`, -#' `cluster_fit` -#' @param x A matrix, sparse matrix, or data frame of predictors. Only some -#' models have support for sparse matrix input. See -#' `tidyclust::get_encoding_tidyclust()` for details. `x` should have column names. -#' @param case_weights An optional classed vector of numeric case weights. This -#' must return `TRUE` when [hardhat::is_case_weights()] is run on it. See -#' [hardhat::frequency_weights()] and [hardhat::importance_weights()] for -#' examples. -#' @rdname fit -#' @export -#' @export fit.cluster_spec -fit.cluster_spec <- function(object, - formula, - data, - control = control_cluster(), - ...) { - if (object$mode == "unknown") { - rlang::abort("Please set the mode in the model specification.") - } - # if (!inherits(control, "control_cluster")) { - # rlang::abort("The 'control' argument should have class 'control_cluster'.") - # } - dots <- quos(...) - if (is.null(object$engine)) { - eng_vals <- possible_engines(object) - object$engine <- eng_vals[1] - if (control$verbosity > 0) { - rlang::warn(glue::glue("Engine set to `{object$engine}`.")) - } - } - - if (all(c("x", "y") %in% names(dots))) { - rlang::abort("`fit.cluster_spec()` is for the formula methods. Use `fit_xy()` instead.") - } - cl <- match.call(expand.dots = TRUE) - # Create an environment with the evaluated argument objects. This will be - # used when a model call is made later. - eval_env <- rlang::env() - - eval_env$data <- data - eval_env$formula <- formula - fit_interface <- - check_interface(eval_env$formula, eval_env$data, cl, object) - - # populate `method` with the details for this model type - object <- add_methods(object, engine = object$engine) - - check_installs(object) - - interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") - - # Now call the wrappers that transition between the interface - # called here ("fit" interface) that will direct traffic to - # what the underlying model uses. For example, if a formula is - # used here, `fit_interface_formula` will determine if a - # translation has to be made if the model interface is x/y/ - res <- - switch(interfaces, - # homogeneous combinations: - formula_formula = - form_form( - object = object, - control = control, - env = eval_env - ), - - # heterogenous combinations - formula_matrix = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - formula_data.frame = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - rlang::abort(glue::glue("{interfaces} is unknown.")) - ) - model_classes <- class(res$fit) - class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") - res -} - -check_interface <- function(formula, data, cl, model) { - inher(formula, "formula", cl) - - # Determine the `fit()` interface - form_interface <- !is.null(formula) & !is.null(data) - - if (form_interface) { - return("formula") - } - rlang::abort("Error when checking the interface.") -} - -inher <- function(x, cls, cl) { - if (!is.null(x) && !inherits(x, cls)) { - call <- match.call() - obj <- deparse(call[["x"]]) - if (length(cls) > 1) { - rlang::abort( - glue::glue( - "`{obj}` should be one of the following classes: ", - glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ") - ) - ) - } else { - rlang::abort( - glue::glue("`{obj}` should be a {cls} object") - ) - } - } - invisible(x) -} - -add_methods <- function(x, engine) { - x$engine <- engine - check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) - x$method <- get_cluster_spec(specific_model(x), x$mode, x$engine) - x -} - -# ------------------------------------------------------------------------------ - -eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) { - if (capture) { - if (catch) { - junk <- capture.output(res <- try(rlang::eval_tidy(e, ...), silent = TRUE)) - } else { - junk <- capture.output(res <- rlang::eval_tidy(e, ...)) - } - } else { - if (catch) { - res <- try(rlang::eval_tidy(e, ...), silent = TRUE) - } else { - res <- rlang::eval_tidy(e, ...) - } - } - res -} - -# ------------------------------------------------------------------------------ - -#' @rdname fit -#' @export -#' @export fit_xy.cluster_spec -fit_xy.cluster_spec <- - function(object, x, case_weights = NULL, control = control_cluster(), ...) { - # if (!inherits(control, "control_cluster")) { - # rlang::abort("The 'control' argument should have class 'control_cluster'.") - # } - if (is.null(colnames(x))) { - rlang::abort("'x' should have column names.") - } - - dots <- quos(...) - if (is.null(object$engine)) { - eng_vals <- possible_engines(object) - object$engine <- eng_vals[1] - if (control$verbosity > 0) { - rlang::warn(glue::glue("Engine set to `{object$engine}`.")) - } - } - - cl <- match.call(expand.dots = TRUE) - eval_env <- rlang::env() - eval_env$x <- x - fit_interface <- check_x_interface(eval_env$x, cl, object) - - # populate `method` with the details for this model type - object <- add_methods(object, engine = object$engine) - - check_installs(object) - - interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") - - # Now call the wrappers that transition between the interface - # called here ("fit" interface) that will direct traffic to - # what the underlying model uses. For example, if a formula is - # used here, `fit_interface_formula` will determine if a - # translation has to be made if the model interface is x/y/ - res <- - switch(interfaces, - # homogeneous combinations: - matrix_matrix = , - data.frame_matrix = - x_x( - object = object, - env = eval_env, - control = control, - target = "matrix", - ... - ), - data.frame_data.frame = , - matrix_data.frame = - x_x( - object = object, - env = eval_env, - control = control, - target = "data.frame", - ... - ), - - # heterogenous combinations - matrix_formula = , - data.frame_formula = - x_form( - object = object, - env = eval_env, - control = control, - ... - ), - rlang::abort(glue::glue("{interfaces} is unknown.")) - ) - model_classes <- class(res$fit) - class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") - res - } - -check_x_interface <- function(x, cl, model) { - sparse_ok <- allow_sparse(model) - sparse_x <- inherits(x, "dgCMatrix") - if (!sparse_ok & sparse_x) { - rlang::abort("Sparse matrices not supported by this model/engine combination.") - } - - if (sparse_ok) { - inher(x, c("data.frame", "matrix", "dgCMatrix"), cl) - } else { - inher(x, c("data.frame", "matrix"), cl) - } - - if (sparse_ok) { - matrix_interface <- !is.null(x) && (is.matrix(x) | sparse_x) - } else { - matrix_interface <- !is.null(x) && is.matrix(x) - } - - df_interface <- !is.null(x) && is.data.frame(x) - - if (matrix_interface) { - return("matrix") - } - if (df_interface) { - return("data.frame") - } - rlang::abort("Error when checking the interface") -} - -allow_sparse <- function(x) { - res <- get_from_env_tidyclust(paste0(class(x)[1], "_encoding")) - all(res$allow_sparse_x[res$engine == x$engine]) -} +#' Fit a Model Specification to a Data Set +#' +#' `fit()` and `fit_xy()` take a model specification, translate_tidyclust the +#' required code by substituting arguments, and execute the model fit routine. +#' +#' @param object An object of class `cluster_spec` that has a chosen engine (via +#' [set_engine()]). +#' @param formula An object of class `formula` (or one that can be coerced to +#' that class): a symbolic description of the model to be fitted. +#' @param data Optional, depending on the interface (see Details below). A data +#' frame containing all relevant variables (e.g. predictors, case weights, +#' etc). Note: when needed, a \emph{named argument} should be used. +#' @param control A named list with elements `verbosity` and `catch`. See +#' [control_cluster()]. +#' @param ... Not currently used; values passed here will be ignored. Other +#' options required to fit the model should be passed using +#' `set_engine()`. +#' @details `fit()` and `fit_xy()` substitute the current arguments in the +#' model specification into the computational engine's code, check them for +#' validity, then fit the model using the data and the engine-specific code. +#' Different model functions have different interfaces (e.g. formula or +#' `x`/`y`) and these functions translate_tidyclust between the interface used +#' when `fit()` or `fit_xy()` was invoked and the one required by the +#' underlying model. +#' +#' When possible, these functions attempt to avoid making copies of the data. +#' For example, if the underlying model uses a formula and `fit()` is invoked, +#' the original data are references when the model is fit. However, if the +#' underlying model uses something else, such as `x`/`y`, the formula is +#' evaluated and the data are converted to the required format. In this case, +#' any calls in the resulting model objects reference the temporary objects +#' used to fit the model. +#' +#' If the model engine has not been set, the model's default engine will be +#' used (as discussed on each model page). If the `verbosity` option of +#' [control_cluster()] is greater than zero, a warning will be produced. +#' +#' If you would like to use an alternative method for generating contrasts +#' when supplying a formula to `fit()`, set the global option `contrasts` to +#' your preferred method. For example, you might set it to: `options(contrasts +#' = c(unordered = "contr.helmert", ordered = "contr.poly"))`. See the help +#' page for [stats::contr.treatment()] for more possible contrast types. +#' @examples +#' library(dplyr) +#' +#' kmeans_mod <- k_means(num_clusters = 5) +#' +#' using_formula <- +#' kmeans_mod %>% +#' set_engine("stats") %>% +#' fit(~., data = mtcars) +#' +#' using_x <- +#' kmeans_mod %>% +#' set_engine("stats") %>% +#' fit_xy(x = mtcars) +#' +#' using_formula +#' using_x +#' @return A `cluster_fit` object that contains several elements: +#' \itemize{ +#' \item \code{spec}: The model specification object (\code{object} in the +#' call to \code{fit}) +#' \item \code{fit}: when the model is executed without error, this is the +#' model object. Otherwise, it is a \code{try-error} +#' object with the error message. +#' \item \code{preproc}: any objects needed to convert between a formula and +#' non-formula interface +#' (such as the \code{terms} object) +#' } +#' The return value will also have a class related to the fitted model (e.g. +#' `"_kmeans"`) before the base class of `"cluster_fit"`. +#' +#' @seealso [set_engine()], [control_cluster()], `cluster_spec`, +#' `cluster_fit` +#' @param x A matrix, sparse matrix, or data frame of predictors. Only some +#' models have support for sparse matrix input. See +#' `tidyclust::get_encoding_tidyclust()` for details. `x` should have column names. +#' @param case_weights An optional classed vector of numeric case weights. This +#' must return `TRUE` when [hardhat::is_case_weights()] is run on it. See +#' [hardhat::frequency_weights()] and [hardhat::importance_weights()] for +#' examples. +#' @rdname fit +#' @export +#' @export fit.cluster_spec +fit.cluster_spec <- function(object, + formula, + data, + control = control_cluster(), + ...) { + if (object$mode == "unknown") { + rlang::abort("Please set the mode in the model specification.") + } + # if (!inherits(control, "control_cluster")) { + # rlang::abort("The 'control' argument should have class 'control_cluster'.") + # } + dots <- quos(...) + if (is.null(object$engine)) { + eng_vals <- possible_engines(object) + object$engine <- eng_vals[1] + if (control$verbosity > 0) { + rlang::warn(glue::glue("Engine set to `{object$engine}`.")) + } + } + + if (all(c("x", "y") %in% names(dots))) { + rlang::abort("`fit.cluster_spec()` is for the formula methods. Use `fit_xy()` instead.") + } + cl <- match.call(expand.dots = TRUE) + # Create an environment with the evaluated argument objects. This will be + # used when a model call is made later. + eval_env <- rlang::env() + + eval_env$data <- data + eval_env$formula <- formula + fit_interface <- + check_interface(eval_env$formula, eval_env$data, cl, object) + + # populate `method` with the details for this model type + object <- add_methods(object, engine = object$engine) + + check_installs(object) + + interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") + + # Now call the wrappers that transition between the interface + # called here ("fit" interface) that will direct traffic to + # what the underlying model uses. For example, if a formula is + # used here, `fit_interface_formula` will determine if a + # translation has to be made if the model interface is x/y/ + res <- + switch(interfaces, + # homogeneous combinations: + formula_formula = + form_form( + object = object, + control = control, + env = eval_env + ), + + # heterogenous combinations + formula_matrix = + form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + formula_data.frame = + form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + rlang::abort(glue::glue("{interfaces} is unknown.")) + ) + model_classes <- class(res$fit) + class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") + res +} + +check_interface <- function(formula, data, cl, model) { + inher(formula, "formula", cl) + + # Determine the `fit()` interface + form_interface <- !is.null(formula) & !is.null(data) + + if (form_interface) { + return("formula") + } + rlang::abort("Error when checking the interface.") +} + +inher <- function(x, cls, cl) { + if (!is.null(x) && !inherits(x, cls)) { + call <- match.call() + obj <- deparse(call[["x"]]) + if (length(cls) > 1) { + rlang::abort( + glue::glue( + "`{obj}` should be one of the following classes: ", + glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ") + ) + ) + } else { + rlang::abort( + glue::glue("`{obj}` should be a {cls} object") + ) + } + } + invisible(x) +} + +add_methods <- function(x, engine) { + x$engine <- engine + check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) + x$method <- get_cluster_spec(specific_model(x), x$mode, x$engine) + x +} + +# ------------------------------------------------------------------------------ + +eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) { + if (capture) { + if (catch) { + junk <- capture.output(res <- try(rlang::eval_tidy(e, ...), silent = TRUE)) + } else { + junk <- capture.output(res <- rlang::eval_tidy(e, ...)) + } + } else { + if (catch) { + res <- try(rlang::eval_tidy(e, ...), silent = TRUE) + } else { + res <- rlang::eval_tidy(e, ...) + } + } + res +} + +# ------------------------------------------------------------------------------ + +#' @rdname fit +#' @export +#' @export fit_xy.cluster_spec +fit_xy.cluster_spec <- + function(object, x, case_weights = NULL, control = control_cluster(), ...) { + # if (!inherits(control, "control_cluster")) { + # rlang::abort("The 'control' argument should have class 'control_cluster'.") + # } + if (is.null(colnames(x))) { + rlang::abort("'x' should have column names.") + } + + dots <- quos(...) + if (is.null(object$engine)) { + eng_vals <- possible_engines(object) + object$engine <- eng_vals[1] + if (control$verbosity > 0) { + rlang::warn(glue::glue("Engine set to `{object$engine}`.")) + } + } + + cl <- match.call(expand.dots = TRUE) + eval_env <- rlang::env() + eval_env$x <- x + fit_interface <- check_x_interface(eval_env$x, cl, object) + + # populate `method` with the details for this model type + object <- add_methods(object, engine = object$engine) + + check_installs(object) + + interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") + + # Now call the wrappers that transition between the interface + # called here ("fit" interface) that will direct traffic to + # what the underlying model uses. For example, if a formula is + # used here, `fit_interface_formula` will determine if a + # translation has to be made if the model interface is x/y/ + res <- + switch(interfaces, + # homogeneous combinations: + matrix_matrix = , + data.frame_matrix = + x_x( + object = object, + env = eval_env, + control = control, + target = "matrix", + ... + ), + data.frame_data.frame = , + matrix_data.frame = + x_x( + object = object, + env = eval_env, + control = control, + target = "data.frame", + ... + ), + + # heterogenous combinations + matrix_formula = , + data.frame_formula = + x_form( + object = object, + env = eval_env, + control = control, + ... + ), + rlang::abort(glue::glue("{interfaces} is unknown.")) + ) + model_classes <- class(res$fit) + class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") + res + } + +check_x_interface <- function(x, cl, model) { + sparse_ok <- allow_sparse(model) + sparse_x <- inherits(x, "dgCMatrix") + if (!sparse_ok & sparse_x) { + rlang::abort("Sparse matrices not supported by this model/engine combination.") + } + + if (sparse_ok) { + inher(x, c("data.frame", "matrix", "dgCMatrix"), cl) + } else { + inher(x, c("data.frame", "matrix"), cl) + } + + if (sparse_ok) { + matrix_interface <- !is.null(x) && (is.matrix(x) | sparse_x) + } else { + matrix_interface <- !is.null(x) && is.matrix(x) + } + + df_interface <- !is.null(x) && is.data.frame(x) + + if (matrix_interface) { + return("matrix") + } + if (df_interface) { + return("data.frame") + } + rlang::abort("Error when checking the interface") +} + +allow_sparse <- function(x) { + res <- get_from_env_tidyclust(paste0(class(x)[1], "_encoding")) + all(res$allow_sparse_x[res$engine == x$engine]) +} diff --git a/R/hier_clust.R b/R/hier_clust.R index 36554e97..d50b7450 100644 --- a/R/hier_clust.R +++ b/R/hier_clust.R @@ -1,96 +1,96 @@ -#' Hierarchical (Agglomerative) Clustering -#' -#' @description -#' -#' `hier_clust()` defines a model that fits clusters based on a distance-based -#' dendrogram -#' -#' @param mode A single character string for the type of model. -#' The only possible value for this model is "partition". -#' @param engine A single character string specifying what computational engine -#' to use for fitting. Possible engines are listed below. The default for this -#' model is `"stats"`. -#' @param k Positive integer, number of clusters in model (optional). -#' @param h Positive double, height at which to cut dendrogram to obtain cluster -#' assignments (only used if `k` is `NULL`) -#' @param linkage_method the agglomeration method to be used. This should be (an -#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`, -#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"` -#' (= WPGMC) or `"centroid"` (= UPGMC). -#' @param dist_fun A distance function to use -#' -#' @examples -#' # show_engines("hier_clust") -#' -#' hier_clust() -#' @export -hier_clust <- - function(mode = "partition", - engine = "stats", - k = NULL, - h = NULL, - linkage_method = "complete") { - args <- list( - k = enquo(k), - h = enquo(h), - linkage_method = enquo(linkage_method) - ) - - new_cluster_spec( - "hier_clust", - args = args, - eng_args = NULL, - mode = mode, - method = NULL, - engine = engine - ) - } - -#' @export -print.hier_clust <- function(x, ...) { - cat("Hierarchical Clustering Specification (", x$mode, ")\n\n", sep = "") - model_printer(x, ...) - - if (!is.null(x$method$fit$args)) { - cat("Model fit template:\n") - print(show_call(x)) - } - - invisible(x) -} - -#' @export -translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) { - x <- translate_tidyclust.default(x, engine, ...) - x -} - -# ------------------------------------------------------------------------------ - -#' Simple Wrapper around hclust function -#' -#' This wrapper prepares the data into a distance matrix to send to -#' `stats::hclust` and retains the parameters `k` or `h` as an attribute. -#' -#' @param x matrix or data frame -#' @param k the number of clusters -#' @param h the height to cut the dendrogram -#' @param linkage_method the agglomeration method to be used. This should be (an -#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`, -#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"` -#' (= WPGMC) or `"centroid"` (= UPGMC). -#' @param dist_fun A distance function to use -#' -#' @return A dendrogram -#' @keywords internal -#' @export -hclust_fit <- function(x, k = NULL, cut_height = NULL, - linkage_method = NULL, - dist_fun = Rfast::Dist) { - dmat <- dist_fun(x) - res <- hclust(as.dist(dmat), method = linkage_method) - attr(res, "k") <- k - attr(res, "cut_height") <- cut_height - attr(res, "training_data") <- x - return(res) -} +#' Hierarchical (Agglomerative) Clustering +#' +#' @description +#' +#' `hier_clust()` defines a model that fits clusters based on a distance-based +#' dendrogram +#' +#' @param mode A single character string for the type of model. +#' The only possible value for this model is "partition". +#' @param engine A single character string specifying what computational engine +#' to use for fitting. Possible engines are listed below. The default for this +#' model is `"stats"`. +#' @param k Positive integer, number of clusters in model (optional). +#' @param h Positive double, height at which to cut dendrogram to obtain cluster +#' assignments (only used if `k` is `NULL`) +#' @param linkage_method the agglomeration method to be used. This should be (an +#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`, +#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"` +#' (= WPGMC) or `"centroid"` (= UPGMC). +#' @param dist_fun A distance function to use +#' +#' @examples +#' # show_engines("hier_clust") +#' +#' hier_clust() +#' @export +hier_clust <- + function(mode = "partition", + engine = "stats", + k = NULL, + h = NULL, + linkage_method = "complete") { + args <- list( + k = enquo(k), + h = enquo(h), + linkage_method = enquo(linkage_method) + ) + + new_cluster_spec( + "hier_clust", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = engine + ) + } + +#' @export +print.hier_clust <- function(x, ...) { + cat("Hierarchical Clustering Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if (!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +#' @export +translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) { + x <- translate_tidyclust.default(x, engine, ...) + x +} + +# ------------------------------------------------------------------------------ + +#' Simple Wrapper around hclust function +#' +#' This wrapper prepares the data into a distance matrix to send to +#' `stats::hclust` and retains the parameters `k` or `h` as an attribute. +#' +#' @param x matrix or data frame +#' @param k the number of clusters +#' @param h the height to cut the dendrogram +#' @param linkage_method the agglomeration method to be used. This should be (an +#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`, +#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"` +#' (= WPGMC) or `"centroid"` (= UPGMC). +#' @param dist_fun A distance function to use +#' +#' @return A dendrogram +#' @keywords internal +#' @export +hclust_fit <- function(x, k = NULL, cut_height = NULL, + linkage_method = NULL, + dist_fun = Rfast::Dist) { + dmat <- dist_fun(x) + res <- hclust(as.dist(dmat), method = linkage_method) + attr(res, "k") <- k + attr(res, "cut_height") <- cut_height + attr(res, "training_data") <- x + return(res) +} diff --git a/R/hier_clust_data.R b/R/hier_clust_data.R index edc93559..b4cf14e1 100644 --- a/R/hier_clust_data.R +++ b/R/hier_clust_data.R @@ -1,132 +1,132 @@ -set_new_model_tidyclust("hier_clust") - -set_model_mode_tidyclust("hier_clust", "partition") - -# ------------------------------------------------------------------------------ - -set_model_engine_tidyclust("hier_clust", "partition", "stats") -set_dependency_tidyclust("hier_clust", "stats", "stats") - -set_fit_tidyclust( - model = "hier_clust", - eng = "stats", - mode = "partition", - value = list( - interface = "matrix", - protect = c("data"), - func = c(pkg = "tidyclust", fun = "hclust_fit"), - defaults = list() - ) -) - -set_encoding_tidyclust( - model = "hier_clust", - eng = "stats", - mode = "partition", - options = list( - predictor_indicators = "traditional", - compute_intercept = TRUE, - remove_intercept = TRUE, - allow_sparse_x = FALSE - ) -) - -set_model_arg_tidyclust( - model = "hier_clust", - eng = "stats", - tidyclust = "k", - original = "k", - func = list(pkg = "tidyclust", fun = "k"), - has_submodel = TRUE -) - -set_model_arg_tidyclust( - model = "hier_clust", - eng = "stats", - tidyclust = "linkage_method", - original = "linkage_method", - func = list(pkg = "tidyclust", fun = "linkage_method"), - has_submodel = TRUE -) - -set_model_arg_tidyclust( - model = "hier_clust", - eng = "stats", - tidyclust = "cut_height", - original = "cut_height", - func = list(pkg = "tidyclust", fun = "cut_height"), - has_submodel = TRUE -) - -set_pred_tidyclust( - model = "hier_clust", - eng = "stats", - mode = "partition", - type = "cluster", - value = list( - pre = NULL, - post = NULL, - func = c(fun = "stats_hier_clust_predict"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) - ) -) - -# ------------------------------------------------------------------------------ -# -# set_model_engine_tidyclust("k_means", "partition", "ClusterR") -# set_dependency_tidyclust("k_means", "ClusterR", "ClusterR") -# -# set_fit_tidyclust( -# model = "k_means", -# eng = "ClusterR", -# mode = "partition", -# value = list( -# interface = "matrix", -# data = c(x = "data"), -# protect = c("data", "clusters"), -# func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"), -# defaults = list() -# ) -# ) -# -# set_encoding_tidyclust( -# model = "k_means", -# eng = "ClusterR", -# mode = "partition", -# options = list( -# predictor_indicators = "traditional", -# compute_intercept = TRUE, -# remove_intercept = TRUE, -# allow_sparse_x = FALSE -# ) -# ) -# -# set_model_arg_tidyclust( -# model = "k_means", -# eng = "ClusterR", -# tidyclust = "k", -# original = "clusters", -# func = list(pkg = "dials", fun = "k"), -# has_submodel = TRUE -# ) -# -# set_pred_tidyclust( -# model = "k_means", -# eng = "ClusterR", -# mode = "partition", -# type = "cluster", -# value = list( -# pre = NULL, -# post = NULL, -# func = c(fun = "clusterR_kmeans_predict"), -# args = -# list( -# object = rlang::expr(object$fit), -# new_data = rlang::expr(new_data) -# ) -# ) -# ) +set_new_model_tidyclust("hier_clust") + +set_model_mode_tidyclust("hier_clust", "partition") + +# ------------------------------------------------------------------------------ + +set_model_engine_tidyclust("hier_clust", "partition", "stats") +set_dependency_tidyclust("hier_clust", "stats", "stats") + +set_fit_tidyclust( + model = "hier_clust", + eng = "stats", + mode = "partition", + value = list( + interface = "matrix", + protect = c("data"), + func = c(pkg = "tidyclust", fun = "hclust_fit"), + defaults = list() + ) +) + +set_encoding_tidyclust( + model = "hier_clust", + eng = "stats", + mode = "partition", + options = list( + predictor_indicators = "traditional", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_model_arg_tidyclust( + model = "hier_clust", + eng = "stats", + tidyclust = "k", + original = "k", + func = list(pkg = "tidyclust", fun = "k"), + has_submodel = TRUE +) + +set_model_arg_tidyclust( + model = "hier_clust", + eng = "stats", + tidyclust = "linkage_method", + original = "linkage_method", + func = list(pkg = "tidyclust", fun = "linkage_method"), + has_submodel = TRUE +) + +set_model_arg_tidyclust( + model = "hier_clust", + eng = "stats", + tidyclust = "cut_height", + original = "cut_height", + func = list(pkg = "tidyclust", fun = "cut_height"), + has_submodel = TRUE +) + +set_pred_tidyclust( + model = "hier_clust", + eng = "stats", + mode = "partition", + type = "cluster", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "stats_hier_clust_predict"), + args = + list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ +# +# set_model_engine_tidyclust("k_means", "partition", "ClusterR") +# set_dependency_tidyclust("k_means", "ClusterR", "ClusterR") +# +# set_fit_tidyclust( +# model = "k_means", +# eng = "ClusterR", +# mode = "partition", +# value = list( +# interface = "matrix", +# data = c(x = "data"), +# protect = c("data", "clusters"), +# func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"), +# defaults = list() +# ) +# ) +# +# set_encoding_tidyclust( +# model = "k_means", +# eng = "ClusterR", +# mode = "partition", +# options = list( +# predictor_indicators = "traditional", +# compute_intercept = TRUE, +# remove_intercept = TRUE, +# allow_sparse_x = FALSE +# ) +# ) +# +# set_model_arg_tidyclust( +# model = "k_means", +# eng = "ClusterR", +# tidyclust = "k", +# original = "clusters", +# func = list(pkg = "dials", fun = "k"), +# has_submodel = TRUE +# ) +# +# set_pred_tidyclust( +# model = "k_means", +# eng = "ClusterR", +# mode = "partition", +# type = "cluster", +# value = list( +# pre = NULL, +# post = NULL, +# func = c(fun = "clusterR_kmeans_predict"), +# args = +# list( +# object = rlang::expr(object$fit), +# new_data = rlang::expr(new_data) +# ) +# ) +# ) diff --git a/R/k_means.R b/R/k_means.R index abdc2261..5b37ea22 100644 --- a/R/k_means.R +++ b/R/k_means.R @@ -1,153 +1,153 @@ -#' K-Means -#' -#' @description -#' -#' `k_means()` defines a model that fits clusters based on distances to a number -#' of centers. -#' -#' @param mode A single character string for the type of model. -#' The only possible value for this model is "partition". -#' @param engine A single character string specifying what computational engine -#' to use for fitting. Possible engines are listed below. The default for this -#' model is `"stats"`. -#' @param num_clusters Positive integer, number of clusters in model. -#' -#' @examples -#' # show_engines("k_means") -#' -#' k_means() -#' @export -k_means <- - function(mode = "partition", - engine = "stats", - num_clusters = NULL) { - args <- list( - num_clusters = enquo(num_clusters) - ) - - new_cluster_spec( - "k_means", - args = args, - eng_args = NULL, - mode = mode, - method = NULL, - engine = engine - ) - } - -#' @export -print.k_means <- function(x, ...) { - cat("K Means Cluster Specification (", x$mode, ")\n\n", sep = "") - model_printer(x, ...) - - if (!is.null(x$method$fit$args)) { - cat("Model fit template:\n") - print(show_call(x)) - } - - invisible(x) -} - -#' @export -translate_tidyclust.k_means <- function(x, engine = x$engine, ...) { - x <- translate_tidyclust.default(x, engine, ...) - x -} - -# ------------------------------------------------------------------------------ - -#' @method update k_means -#' @rdname tidyclust_update -#' @export -update.k_means <- function(object, - parameters = NULL, - num_clusters = NULL, - fresh = FALSE, ...) { - - eng_args <- parsnip::update_engine_parameters(object$eng_args, ...) - - if (!is.null(parameters)) { - parameters <- parsnip::check_final_param(parameters) - } - args <- list( - num_clusters = enquo(num_clusters) - ) - - args <- parsnip::update_main_parameters(args, parameters) - - if (fresh) { - object$args <- args - object$eng_args <- eng_args - } else { - null_args <- map_lgl(args, null_value) - if (any(null_args)) - args <- args[!null_args] - if (length(args) > 0) - object$args[names(args)] <- args - if (length(eng_args) > 0) - object$eng_args[names(eng_args)] <- eng_args - } - - new_cluster_spec( - "k_means", - args = object$args, - eng_args = object$eng_args, - mode = object$mode, - method = NULL, - engine = object$engine - ) -} - -# # ------------------------------------------------------------------------------ - -check_args.k_means <- function(object) { - - args <- lapply(object$args, rlang::eval_tidy) - - if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0)) - rlang::abort("The number of centers should be >= 0.") - - invisible(object) -} - -# ------------------------------------------------------------------------------ - -#' Simple Wrapper around ClusterR kmeans -#' -#' This wrapper runs `ClusterR::KMeans_rcpp` and adds column names to the -#' `centroids` field. -#' -#' @param data matrix or data frame -#' @param clusters the number of clusters -#' @param num_init number of times the algorithm will be run with different -#' centroid seeds -#' @param max_iters the maximum number of clustering iterations -#' @param initializer the method of initialization. One of, optimal_init, -#' quantile_init, kmeans++ and random. See details for more information -#' @param fuzzy either TRUE or FALSE. If TRUE, then prediction probabilities -#' will be calculated using the distance between observations and centroids -#' @param verbose either TRUE or FALSE, indicating whether progress is printed -#' during clustering. -#' @param CENTROIDS a matrix of initial cluster centroids. The rows of the -#' CENTROIDS matrix should be equal to the number of clusters and the columns -#' should be equal to the columns of the data. -#' @param tol a float number. If, in case of an iteration (iteration > 1 and -#' iteration < max_iters) 'tol' is greater than the squared norm of the -#' centroids, then kmeans has converged -#' @param tol_optimal_init tolerance value for the 'optimal_init' initializer. -#' The higher this value is, the far appart from each other the centroids are. -#' @param seed integer value for random number generator (RNG) -#' -#' @return a list with the following attributes: clusters, fuzzy_clusters (if -#' fuzzy = TRUE), centroids, total_SSE, best_initialization, WCSS_per_cluster, -#' obs_per_cluster, between.SS_DIV_total.SS -#' @keywords internal -#' @export -ClusterR_kmeans_fit <- function(data, clusters, num_init = 1, max_iters = 100, - initializer = "kmeans++", fuzzy = FALSE, - verbose = FALSE, CENTROIDS = NULL, tol = 1e-04, - tol_optimal_init = 0.3, seed = 1) { - res <- ClusterR::KMeans_rcpp(data, clusters) - colnames(res$centroids) <- colnames(data) - res -} +#' K-Means +#' +#' @description +#' +#' `k_means()` defines a model that fits clusters based on distances to a number +#' of centers. +#' +#' @param mode A single character string for the type of model. +#' The only possible value for this model is "partition". +#' @param engine A single character string specifying what computational engine +#' to use for fitting. Possible engines are listed below. The default for this +#' model is `"stats"`. +#' @param num_clusters Positive integer, number of clusters in model. +#' +#' @examples +#' # show_engines("k_means") +#' +#' k_means() +#' @export +k_means <- + function(mode = "partition", + engine = "stats", + num_clusters = NULL) { + args <- list( + num_clusters = enquo(num_clusters) + ) + + new_cluster_spec( + "k_means", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = engine + ) + } + +#' @export +print.k_means <- function(x, ...) { + cat("K Means Cluster Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if (!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +#' @export +translate_tidyclust.k_means <- function(x, engine = x$engine, ...) { + x <- translate_tidyclust.default(x, engine, ...) + x +} + +# ------------------------------------------------------------------------------ + +#' @method update k_means +#' @rdname tidyclust_update +#' @export +update.k_means <- function(object, + parameters = NULL, + num_clusters = NULL, + fresh = FALSE, ...) { + + eng_args <- parsnip::update_engine_parameters(object$eng_args, ...) + + if (!is.null(parameters)) { + parameters <- parsnip::check_final_param(parameters) + } + args <- list( + num_clusters = enquo(num_clusters) + ) + + args <- parsnip::update_main_parameters(args, parameters) + + if (fresh) { + object$args <- args + object$eng_args <- eng_args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args + } + + new_cluster_spec( + "k_means", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + +# # ------------------------------------------------------------------------------ + +check_args.k_means <- function(object) { + + args <- lapply(object$args, rlang::eval_tidy) + + if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0)) + rlang::abort("The number of centers should be >= 0.") + + invisible(object) +} + +# ------------------------------------------------------------------------------ + +#' Simple Wrapper around ClusterR kmeans +#' +#' This wrapper runs `ClusterR::KMeans_rcpp` and adds column names to the +#' `centroids` field. +#' +#' @param data matrix or data frame +#' @param clusters the number of clusters +#' @param num_init number of times the algorithm will be run with different +#' centroid seeds +#' @param max_iters the maximum number of clustering iterations +#' @param initializer the method of initialization. One of, optimal_init, +#' quantile_init, kmeans++ and random. See details for more information +#' @param fuzzy either TRUE or FALSE. If TRUE, then prediction probabilities +#' will be calculated using the distance between observations and centroids +#' @param verbose either TRUE or FALSE, indicating whether progress is printed +#' during clustering. +#' @param CENTROIDS a matrix of initial cluster centroids. The rows of the +#' CENTROIDS matrix should be equal to the number of clusters and the columns +#' should be equal to the columns of the data. +#' @param tol a float number. If, in case of an iteration (iteration > 1 and +#' iteration < max_iters) 'tol' is greater than the squared norm of the +#' centroids, then kmeans has converged +#' @param tol_optimal_init tolerance value for the 'optimal_init' initializer. +#' The higher this value is, the far appart from each other the centroids are. +#' @param seed integer value for random number generator (RNG) +#' +#' @return a list with the following attributes: clusters, fuzzy_clusters (if +#' fuzzy = TRUE), centroids, total_SSE, best_initialization, WCSS_per_cluster, +#' obs_per_cluster, between.SS_DIV_total.SS +#' @keywords internal +#' @export +ClusterR_kmeans_fit <- function(data, clusters, num_init = 1, max_iters = 100, + initializer = "kmeans++", fuzzy = FALSE, + verbose = FALSE, CENTROIDS = NULL, tol = 1e-04, + tol_optimal_init = 0.3, seed = 1) { + res <- ClusterR::KMeans_rcpp(data, clusters) + colnames(res$centroids) <- colnames(data) + res +} diff --git a/R/k_means_data.R b/R/k_means_data.R index 4b43d1c5..35b2238c 100644 --- a/R/k_means_data.R +++ b/R/k_means_data.R @@ -1,114 +1,114 @@ -set_new_model_tidyclust("k_means") - -set_model_mode_tidyclust("k_means", "partition") - -# ------------------------------------------------------------------------------ - -set_model_engine_tidyclust("k_means", "partition", "stats") -set_dependency_tidyclust("k_means", "stats", "stats") - -set_fit_tidyclust( - model = "k_means", - eng = "stats", - mode = "partition", - value = list( - interface = "matrix", - protect = c("x", "centers"), - func = c(pkg = "stats", fun = "kmeans"), - defaults = list() - ) -) - -set_encoding_tidyclust( - model = "k_means", - eng = "stats", - mode = "partition", - options = list( - predictor_indicators = "traditional", - compute_intercept = TRUE, - remove_intercept = TRUE, - allow_sparse_x = FALSE - ) -) - -set_model_arg_tidyclust( - model = "k_means", - eng = "stats", - tidyclust = "num_clusters", - original = "centers", - func = list(pkg = "tidyclust", fun = "num_clusters"), - has_submodel = TRUE -) - -set_pred_tidyclust( - model = "k_means", - eng = "stats", - mode = "partition", - type = "cluster", - value = list( - pre = NULL, - post = NULL, - func = c(fun = "stats_kmeans_predict"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) - ) -) - -# ------------------------------------------------------------------------------ - -set_model_engine_tidyclust("k_means", "partition", "ClusterR") -set_dependency_tidyclust("k_means", "ClusterR", "ClusterR") - -set_fit_tidyclust( - model = "k_means", - eng = "ClusterR", - mode = "partition", - value = list( - interface = "matrix", - data = c(x = "data"), - protect = c("data", "clusters"), - func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"), - defaults = list() - ) -) - -set_encoding_tidyclust( - model = "k_means", - eng = "ClusterR", - mode = "partition", - options = list( - predictor_indicators = "traditional", - compute_intercept = TRUE, - remove_intercept = TRUE, - allow_sparse_x = FALSE - ) -) - -set_model_arg_tidyclust( - model = "k_means", - eng = "ClusterR", - tidyclust = "num_clusters", - original = "clusters", - func = list(pkg = "tidyclust", fun = "num_clusters"), - has_submodel = TRUE -) - -set_pred_tidyclust( - model = "k_means", - eng = "ClusterR", - mode = "partition", - type = "cluster", - value = list( - pre = NULL, - post = NULL, - func = c(fun = "clusterR_kmeans_predict"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) - ) -) +set_new_model_tidyclust("k_means") + +set_model_mode_tidyclust("k_means", "partition") + +# ------------------------------------------------------------------------------ + +set_model_engine_tidyclust("k_means", "partition", "stats") +set_dependency_tidyclust("k_means", "stats", "stats") + +set_fit_tidyclust( + model = "k_means", + eng = "stats", + mode = "partition", + value = list( + interface = "matrix", + protect = c("x", "centers"), + func = c(pkg = "stats", fun = "kmeans"), + defaults = list() + ) +) + +set_encoding_tidyclust( + model = "k_means", + eng = "stats", + mode = "partition", + options = list( + predictor_indicators = "traditional", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_model_arg_tidyclust( + model = "k_means", + eng = "stats", + tidyclust = "num_clusters", + original = "centers", + func = list(pkg = "tidyclust", fun = "num_clusters"), + has_submodel = TRUE +) + +set_pred_tidyclust( + model = "k_means", + eng = "stats", + mode = "partition", + type = "cluster", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "stats_kmeans_predict"), + args = + list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine_tidyclust("k_means", "partition", "ClusterR") +set_dependency_tidyclust("k_means", "ClusterR", "ClusterR") + +set_fit_tidyclust( + model = "k_means", + eng = "ClusterR", + mode = "partition", + value = list( + interface = "matrix", + data = c(x = "data"), + protect = c("data", "clusters"), + func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"), + defaults = list() + ) +) + +set_encoding_tidyclust( + model = "k_means", + eng = "ClusterR", + mode = "partition", + options = list( + predictor_indicators = "traditional", + compute_intercept = TRUE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_model_arg_tidyclust( + model = "k_means", + eng = "ClusterR", + tidyclust = "num_clusters", + original = "clusters", + func = list(pkg = "tidyclust", fun = "num_clusters"), + has_submodel = TRUE +) + +set_pred_tidyclust( + model = "k_means", + eng = "ClusterR", + mode = "partition", + type = "cluster", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "clusterR_kmeans_predict"), + args = + list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) + ) +) diff --git a/R/metric-silhouette.R b/R/metric-silhouette.R index 6a78c4dd..ef138210 100644 --- a/R/metric-silhouette.R +++ b/R/metric-silhouette.R @@ -1,107 +1,107 @@ -#' Measures silhouettes between clusters -#' -#' @param object A fitted tidyclust model -#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering. -#' @param dists A distance matrix. Used if `new_data` is `NULL`. -#' @param dist_fun A function for calculating distances between observations. -#' Defaults to Euclidean distance on processed data. -#' -#' @return A tibble giving the silhouettes for each observation. -#' -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' dists <- mtcars %>% -#' as.matrix() %>% -#' dist() -#' -#' silhouettes(kmeans_fit, dists = dists) -#' @export -silhouettes <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist) { - preproc <- prep_data_dist(object, new_data, dists, dist_fun) - - clust_int <- as.integer(gsub("Cluster_", "", preproc$clusters)) - - sil <- cluster::silhouette(clust_int, preproc$dists) - - sil %>% - unclass() %>% - tibble::as_tibble() %>% - dplyr::mutate( - cluster = factor(paste0("Cluster_", cluster)), - neighbor = factor(paste0("Cluster_", neighbor)), - sil_width = as.numeric(sil_width) - ) -} - -#' Measures average silhouette across all observations -#' -#' @param object A fitted kmeans tidyclust model -#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering. -#' @param dists A distance matrix. Used if `new_data` is `NULL`. -#' @param dist_fun A function for calculating distances between observations. -#' Defaults to Euclidean distance on processed data. -#' @param ... Other arguments passed to methods. -#' -#' @return A double; the average silhouette. -#' -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' dists <- mtcars %>% -#' as.matrix() %>% -#' dist() -#' -#' avg_silhouette(kmeans_fit, dists = dists) -#' -#' avg_silhouette_vec(kmeans_fit, dists = dists) -#' @export -avg_silhouette <- function(object, ...) { - UseMethod("avg_silhouette") -} - -avg_silhouette <- new_cluster_metric( - avg_silhouette, - direction = "zero" -) - -#' @export -#' @rdname avg_silhouette -avg_silhouette.cluster_fit <- function(object, new_data = NULL, dists = NULL, - dist_fun = NULL, ...) { - if (is.null(dist_fun)) { - dist_fun <- Rfast::Dist - } - - res <- avg_silhouette_impl(object, new_data, dists, dist_fun, ...) - - tibble::tibble( - .metric = "avg_silhouette", - .estimator = "standard", - .estimate = res - ) -} - -#' @export -#' @rdname avg_silhouette -avg_silhouette.workflow <- avg_silhouette.cluster_fit - -#' @export -#' @rdname avg_silhouette -avg_silhouette_vec <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { - avg_silhouette_impl(object, new_data, dists, dist_fun, ...) - -} - -avg_silhouette_impl <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { - mean(silhouettes(object, new_data, dists, dist_fun, ...)$sil_width) -} +#' Measures silhouettes between clusters +#' +#' @param object A fitted tidyclust model +#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering. +#' @param dists A distance matrix. Used if `new_data` is `NULL`. +#' @param dist_fun A function for calculating distances between observations. +#' Defaults to Euclidean distance on processed data. +#' +#' @return A tibble giving the silhouettes for each observation. +#' +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' dists <- mtcars %>% +#' as.matrix() %>% +#' dist() +#' +#' silhouettes(kmeans_fit, dists = dists) +#' @export +silhouettes <- function(object, new_data = NULL, dists = NULL, + dist_fun = Rfast::Dist) { + preproc <- prep_data_dist(object, new_data, dists, dist_fun) + + clust_int <- as.integer(gsub("Cluster_", "", preproc$clusters)) + + sil <- cluster::silhouette(clust_int, preproc$dists) + + sil %>% + unclass() %>% + tibble::as_tibble() %>% + dplyr::mutate( + cluster = factor(paste0("Cluster_", cluster)), + neighbor = factor(paste0("Cluster_", neighbor)), + sil_width = as.numeric(sil_width) + ) +} + +#' Measures average silhouette across all observations +#' +#' @param object A fitted kmeans tidyclust model +#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering. +#' @param dists A distance matrix. Used if `new_data` is `NULL`. +#' @param dist_fun A function for calculating distances between observations. +#' Defaults to Euclidean distance on processed data. +#' @param ... Other arguments passed to methods. +#' +#' @return A double; the average silhouette. +#' +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' dists <- mtcars %>% +#' as.matrix() %>% +#' dist() +#' +#' avg_silhouette(kmeans_fit, dists = dists) +#' +#' avg_silhouette_vec(kmeans_fit, dists = dists) +#' @export +avg_silhouette <- function(object, ...) { + UseMethod("avg_silhouette") +} + +avg_silhouette <- new_cluster_metric( + avg_silhouette, + direction = "zero" +) + +#' @export +#' @rdname avg_silhouette +avg_silhouette.cluster_fit <- function(object, new_data = NULL, dists = NULL, + dist_fun = NULL, ...) { + if (is.null(dist_fun)) { + dist_fun <- Rfast::Dist + } + + res <- avg_silhouette_impl(object, new_data, dists, dist_fun, ...) + + tibble::tibble( + .metric = "avg_silhouette", + .estimator = "standard", + .estimate = res + ) +} + +#' @export +#' @rdname avg_silhouette +avg_silhouette.workflow <- avg_silhouette.cluster_fit + +#' @export +#' @rdname avg_silhouette +avg_silhouette_vec <- function(object, new_data = NULL, dists = NULL, + dist_fun = Rfast::Dist, ...) { + avg_silhouette_impl(object, new_data, dists, dist_fun, ...) + +} + +avg_silhouette_impl <- function(object, new_data = NULL, dists = NULL, + dist_fun = Rfast::Dist, ...) { + mean(silhouettes(object, new_data, dists, dist_fun, ...)$sil_width) +} diff --git a/R/predict.R b/R/predict.R index a8cf691d..c6c898ce 100644 --- a/R/predict.R +++ b/R/predict.R @@ -1,144 +1,144 @@ -#' Model predictions -#' -#' Apply a model to create different types of predictions. -#' `predict()` can be used for all types of models and uses the -#' "type" argument for more specificity. -#' -#' @param object An object of class `cluster_fit` -#' @param new_data A rectangular data object, such as a data frame. -#' @param type A single character value or `NULL`. Possible values -#' are "cluster", or "raw". When `NULL`, `predict()` will choose an -#' appropriate value based on the model's mode. -#' @param opts A list of optional arguments to the underlying -#' predict function that will be used when `type = "raw"`. The -#' list should not include options for the model object or the -#' new data being predicted. -#' @param ... Arguments to the underlying model's prediction -#' function cannot be passed here (see `opts`). -#' @details If "type" is not supplied to `predict()`, then a choice -#' is made: -#' -#' * `type = "cluster"` for clustering models -#' -#' `predict()` is designed to provide a tidy result (see "Value" -#' section below) in a tibble output format. -#' -#' @return With the exception of `type = "raw"`, the results of -#' `predict.cluster_fit()` will be a tibble as many rows in the output -#' as there are rows in `new_data` and the column names will be -#' predictable. -#' -#' For clustering results the tibble will have a `.pred_cluster` column. -#' -#' Using `type = "raw"` with `predict.cluster_fit()` will return -#' the unadulterated results of the prediction function. -#' -#' When the model fit failed and the error was captured, the -#' `predict()` function will return the same structure as above but -#' filled with missing values. This does not currently work for -#' multivariate models. -#' -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) %>% -#' set_engine("stats") -#' -#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) -#' -#' kmeans_fit %>% -#' predict(new_data = mtcars) -#' @method predict cluster_fit -#' @export predict.cluster_fit -#' @export -predict.cluster_fit <- function(object, new_data, type = NULL, opts = list(), ...) { - if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") - return(NULL) - } - - check_installs(object$spec) - load_libs(object$spec, quiet = TRUE) - - type <- check_pred_type(object, type) - - res <- switch(type, - cluster = predict_cluster(object = object, new_data = new_data, ...), - raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), - - rlang::abort(glue::glue("I don't know about type = '{type}'")) - ) - - res <- switch(type, - cluster = format_cluster(res), - res - ) - res -} - - -check_pred_type <- function(object, type, ...) { - if (is.null(type)) { - type <- - switch(object$spec$mode, - partition = "cluster", - rlang::abort("`type` should be 'cluster'.") - ) - } - if (!(type %in% pred_types)) { - rlang::abort( - glue::glue( - "`type` should be one of: ", - glue::glue_collapse(pred_types, sep = ", ", last = " and ") - ) - ) - } - type -} - -format_cluster <- function(x) { - tibble::tibble(.pred_cluster = unname(x)) -} - -#' Prepare data based on parsnip encoding information -#' @param object A parsnip model object -#' @param new_data A data frame -#' @return A data frame or matrix -#' @keywords internal -#' @export -prepare_data <- function(object, new_data) { - fit_interface <- object$spec$method$fit$interface - - pp_names <- names(object$preproc) - if (any(pp_names == "terms") | any(pp_names == "x_var")) { - # Translation code - if (fit_interface == "formula") { - new_data <- .convert_x_to_form_new(object$preproc, new_data) - } else { - new_data <- .convert_form_to_x_new(object$preproc, new_data)$x - } - } - - remove_intercept <- - get_encoding_tidyclust(class(object$spec)[1]) %>% - dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% - dplyr::pull(remove_intercept) - if (remove_intercept & any(grepl("Intercept", names(new_data)))) { - new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) - } - - switch(fit_interface, - none = new_data, - data.frame = as.data.frame(new_data), - matrix = as.matrix(new_data), - new_data - ) -} - -make_pred_call <- function(x) { - if ("pkg" %in% names(x$func)) { - cl <- rlang::call2(x$func["fun"], !!!x$args, .ns = x$func["pkg"]) - } else { - cl <- rlang::call2(x$func["fun"], !!!x$args) - } - - cl -} +#' Model predictions +#' +#' Apply a model to create different types of predictions. +#' `predict()` can be used for all types of models and uses the +#' "type" argument for more specificity. +#' +#' @param object An object of class `cluster_fit` +#' @param new_data A rectangular data object, such as a data frame. +#' @param type A single character value or `NULL`. Possible values +#' are "cluster", or "raw". When `NULL`, `predict()` will choose an +#' appropriate value based on the model's mode. +#' @param opts A list of optional arguments to the underlying +#' predict function that will be used when `type = "raw"`. The +#' list should not include options for the model object or the +#' new data being predicted. +#' @param ... Arguments to the underlying model's prediction +#' function cannot be passed here (see `opts`). +#' @details If "type" is not supplied to `predict()`, then a choice +#' is made: +#' +#' * `type = "cluster"` for clustering models +#' +#' `predict()` is designed to provide a tidy result (see "Value" +#' section below) in a tibble output format. +#' +#' @return With the exception of `type = "raw"`, the results of +#' `predict.cluster_fit()` will be a tibble as many rows in the output +#' as there are rows in `new_data` and the column names will be +#' predictable. +#' +#' For clustering results the tibble will have a `.pred_cluster` column. +#' +#' Using `type = "raw"` with `predict.cluster_fit()` will return +#' the unadulterated results of the prediction function. +#' +#' When the model fit failed and the error was captured, the +#' `predict()` function will return the same structure as above but +#' filled with missing values. This does not currently work for +#' multivariate models. +#' +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) %>% +#' set_engine("stats") +#' +#' kmeans_fit <- fit(kmeans_spec, ~., mtcars) +#' +#' kmeans_fit %>% +#' predict(new_data = mtcars) +#' @method predict cluster_fit +#' @export predict.cluster_fit +#' @export +predict.cluster_fit <- function(object, new_data, type = NULL, opts = list(), ...) { + if (inherits(object$fit, "try-error")) { + rlang::warn("Model fit failed; cannot make predictions.") + return(NULL) + } + + check_installs(object$spec) + load_libs(object$spec, quiet = TRUE) + + type <- check_pred_type(object, type) + + res <- switch(type, + cluster = predict_cluster(object = object, new_data = new_data, ...), + raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), + + rlang::abort(glue::glue("I don't know about type = '{type}'")) + ) + + res <- switch(type, + cluster = format_cluster(res), + res + ) + res +} + + +check_pred_type <- function(object, type, ...) { + if (is.null(type)) { + type <- + switch(object$spec$mode, + partition = "cluster", + rlang::abort("`type` should be 'cluster'.") + ) + } + if (!(type %in% pred_types)) { + rlang::abort( + glue::glue( + "`type` should be one of: ", + glue::glue_collapse(pred_types, sep = ", ", last = " and ") + ) + ) + } + type +} + +format_cluster <- function(x) { + tibble::tibble(.pred_cluster = unname(x)) +} + +#' Prepare data based on parsnip encoding information +#' @param object A parsnip model object +#' @param new_data A data frame +#' @return A data frame or matrix +#' @keywords internal +#' @export +prepare_data <- function(object, new_data) { + fit_interface <- object$spec$method$fit$interface + + pp_names <- names(object$preproc) + if (any(pp_names == "terms") | any(pp_names == "x_var")) { + # Translation code + if (fit_interface == "formula") { + new_data <- .convert_x_to_form_new(object$preproc, new_data) + } else { + new_data <- .convert_form_to_x_new(object$preproc, new_data)$x + } + } + + remove_intercept <- + get_encoding_tidyclust(class(object$spec)[1]) %>% + dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% + dplyr::pull(remove_intercept) + if (remove_intercept & any(grepl("Intercept", names(new_data)))) { + new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) + } + + switch(fit_interface, + none = new_data, + data.frame = as.data.frame(new_data), + matrix = as.matrix(new_data), + new_data + ) +} + +make_pred_call <- function(x) { + if ("pkg" %in% names(x$func)) { + cl <- rlang::call2(x$func["fun"], !!!x$args, .ns = x$func["pkg"]) + } else { + cl <- rlang::call2(x$func["fun"], !!!x$args) + } + + cl +} diff --git a/R/predict_helpers.R b/R/predict_helpers.R index 55a00f8b..3e667bb7 100644 --- a/R/predict_helpers.R +++ b/R/predict_helpers.R @@ -1,101 +1,101 @@ -stats_kmeans_predict <- function(object, new_data) { - reorder_clusts <- unique(object$cluster) - res <- apply(flexclust::dist2(object$centers[reorder_clusts, , drop = FALSE], new_data), 2, which.min) - res <- paste0("Cluster_", res) - factor(res) -} - -clusterR_kmeans_predict <- function(object, new_data) { - reorder_clusts <- unique(object$clusters) - res <- apply(flexclust::dist2(object$centroids[reorder_clusts, , drop = FALSE], new_data), 2, which.min) - res <- paste0("Cluster_", res) - factor(res) -} - -stats_hier_clust_predict <- function(object, new_data) { - - linkage_method <- object$method - - new_data <- as.matrix(new_data) - - training_data <- as.matrix(attr(object, "training_data")) - clusters <- extract_cluster_assignment(object) - - if (linkage_method %in% c("single", "complete", "average", "median")) { - - ## complete, single, average, and median linkage_methods are basically the same idea, - ## just different summary distance to cluster - - cluster_dist_fun <- switch(linkage_method, - "single" = min, - "complete" = max, - "average" = mean, - "median" = median - ) - - # need this to be obs on rows, dist to new data on cols - dists_new <- Rfast::dista(new_data, training_data, trans = TRUE) - - cluster_dists <- bind_cols(data.frame(dists_new), clusters) %>% - group_by(.cluster) %>% - summarize_all(cluster_dist_fun) - - pred_clusts_num <- cluster_dists %>% - select(-.cluster) %>% - purrr::map_dbl(which.min) - - } else if (linkage_method == "centroid") { - - ## Centroid linkage_method, dist to center - - cluster_centers <- extract_centroids(object) %>% select(-.cluster) - dists_means <- Rfast::dista(new_data, cluster_centers) - - pred_clusts_num <- apply(dists_means, 1, which.min) - - } else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) { - - ## Ward linkage_method: lowest change in ESS - ## dendrograms created from already-squared distances - ## use Ward.D2 on these plain distances for Ward.D - - cluster_centers <- extract_centroids(object) - n_clust <- nrow(cluster_centers) - cluster_names <- cluster_centers[[1]] - cluster_centers <- as.matrix(cluster_centers[, -1]) - - d_means <- purrr::map(1:n_clust, - ~t(t(training_data[clusters$.cluster == cluster_names[.x],]) - cluster_centers[.x, ])) - - n <- nrow(training_data) - - d_new_list <- purrr::map(1:nrow(new_data), - function(new_obs) { - purrr::map(1:n_clust, - ~ t(t(training_data[clusters$.cluster == cluster_names[.x],]) - - new_data[new_obs,]) - ) - } - ) - - change_in_ess <- purrr::map(d_new_list, - function(v) { - purrr::map2_dbl(d_means, v, - ~ sum((n*.x + .y)^2/(n+1)^2 - .x^2) - )} - ) - - pred_clusts_num <- purrr::map_dbl(change_in_ess, which.min) - - } else { - - stop(glue::glue("linkage_method {linkage_method} is not supported for prediction.")) - - } - - - pred_clusts <- unique(clusters$.cluster)[pred_clusts_num] - - return(factor(pred_clusts)) - -} +stats_kmeans_predict <- function(object, new_data) { + reorder_clusts <- unique(object$cluster) + res <- apply(flexclust::dist2(object$centers[reorder_clusts, , drop = FALSE], new_data), 2, which.min) + res <- paste0("Cluster_", res) + factor(res) +} + +clusterR_kmeans_predict <- function(object, new_data) { + reorder_clusts <- unique(object$clusters) + res <- apply(flexclust::dist2(object$centroids[reorder_clusts, , drop = FALSE], new_data), 2, which.min) + res <- paste0("Cluster_", res) + factor(res) +} + +stats_hier_clust_predict <- function(object, new_data) { + + linkage_method <- object$method + + new_data <- as.matrix(new_data) + + training_data <- as.matrix(attr(object, "training_data")) + clusters <- extract_cluster_assignment(object) + + if (linkage_method %in% c("single", "complete", "average", "median")) { + + ## complete, single, average, and median linkage_methods are basically the same idea, + ## just different summary distance to cluster + + cluster_dist_fun <- switch(linkage_method, + "single" = min, + "complete" = max, + "average" = mean, + "median" = median + ) + + # need this to be obs on rows, dist to new data on cols + dists_new <- Rfast::dista(new_data, training_data, trans = TRUE) + + cluster_dists <- bind_cols(data.frame(dists_new), clusters) %>% + group_by(.cluster) %>% + summarize_all(cluster_dist_fun) + + pred_clusts_num <- cluster_dists %>% + select(-.cluster) %>% + purrr::map_dbl(which.min) + + } else if (linkage_method == "centroid") { + + ## Centroid linkage_method, dist to center + + cluster_centers <- extract_centroids(object) %>% select(-.cluster) + dists_means <- Rfast::dista(new_data, cluster_centers) + + pred_clusts_num <- apply(dists_means, 1, which.min) + + } else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) { + + ## Ward linkage_method: lowest change in ESS + ## dendrograms created from already-squared distances + ## use Ward.D2 on these plain distances for Ward.D + + cluster_centers <- extract_centroids(object) + n_clust <- nrow(cluster_centers) + cluster_names <- cluster_centers[[1]] + cluster_centers <- as.matrix(cluster_centers[, -1]) + + d_means <- purrr::map(1:n_clust, + ~t(t(training_data[clusters$.cluster == cluster_names[.x],]) - cluster_centers[.x, ])) + + n <- nrow(training_data) + + d_new_list <- purrr::map(1:nrow(new_data), + function(new_obs) { + purrr::map(1:n_clust, + ~ t(t(training_data[clusters$.cluster == cluster_names[.x],]) + - new_data[new_obs,]) + ) + } + ) + + change_in_ess <- purrr::map(d_new_list, + function(v) { + purrr::map2_dbl(d_means, v, + ~ sum((n*.x + .y)^2/(n+1)^2 - .x^2) + )} + ) + + pred_clusts_num <- purrr::map_dbl(change_in_ess, which.min) + + } else { + + stop(glue::glue("linkage_method {linkage_method} is not supported for prediction.")) + + } + + + pred_clusts <- unique(clusters$.cluster)[pred_clusts_num] + + return(factor(pred_clusts)) + +} diff --git a/R/reexports.R b/R/reexports.R index 6360ac59..356c98ff 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -1,72 +1,72 @@ -#' @importFrom magrittr %>% -#' @export -magrittr::`%>%` - -#' @importFrom generics fit -#' @export -generics::fit - -#' @importFrom generics tidy -#' @export -generics::tidy - -#' @importFrom generics glance -#' @export -generics::glance - -#' @importFrom generics augment -#' @export -generics::augment - -#' @importFrom generics fit_xy -#' @export -generics::fit_xy - -#' @importFrom hardhat extract_parameter_set_dials -#' @export -hardhat::extract_parameter_set_dials - -#' @importFrom hardhat tune -#' @export -hardhat::tune - -#' @importFrom hardhat extract_spec_parsnip -#' @export -hardhat::extract_spec_parsnip - -#' @importFrom generics min_grid -#' @export -generics::min_grid - -#' @importFrom hardhat extract_preprocessor -#' @export -hardhat::extract_preprocessor - -#' @importFrom hardhat extract_fit_parsnip -#' @export -hardhat::extract_fit_parsnip - -#' @importFrom tune load_pkgs -#' @export -tune::load_pkgs - -#' @importFrom generics required_pkgs -#' @export -generics::required_pkgs - -#' @importFrom parsnip predict_raw -#' @export -parsnip::predict_raw - -#' @importFrom parsnip set_args -#' @export -parsnip::set_args - -#' @importFrom parsnip set_engine -#' @export -parsnip::set_engine - -#' @importFrom parsnip set_mode -#' @export -parsnip::set_mode - +#' @importFrom magrittr %>% +#' @export +magrittr::`%>%` + +#' @importFrom generics fit +#' @export +generics::fit + +#' @importFrom generics tidy +#' @export +generics::tidy + +#' @importFrom generics glance +#' @export +generics::glance + +#' @importFrom generics augment +#' @export +generics::augment + +#' @importFrom generics fit_xy +#' @export +generics::fit_xy + +#' @importFrom hardhat extract_parameter_set_dials +#' @export +hardhat::extract_parameter_set_dials + +#' @importFrom hardhat tune +#' @export +hardhat::tune + +#' @importFrom hardhat extract_spec_parsnip +#' @export +hardhat::extract_spec_parsnip + +#' @importFrom generics min_grid +#' @export +generics::min_grid + +#' @importFrom hardhat extract_preprocessor +#' @export +hardhat::extract_preprocessor + +#' @importFrom hardhat extract_fit_parsnip +#' @export +hardhat::extract_fit_parsnip + +#' @importFrom tune load_pkgs +#' @export +tune::load_pkgs + +#' @importFrom generics required_pkgs +#' @export +generics::required_pkgs + +#' @importFrom parsnip predict_raw +#' @export +parsnip::predict_raw + +#' @importFrom parsnip set_args +#' @export +parsnip::set_args + +#' @importFrom parsnip set_engine +#' @export +parsnip::set_engine + +#' @importFrom parsnip set_mode +#' @export +parsnip::set_mode + diff --git a/R/translate.R b/R/translate.R index 6a534a64..2c4d6598 100644 --- a/R/translate.R +++ b/R/translate.R @@ -1,151 +1,151 @@ -#' Resolve a Model Specification for a Computational Engine -#' -#' `translate_tidyclust()` will translate_tidyclust a model specification into a code -#' object that is specific to a particular engine (e.g. R package). -#' It translate_tidyclusts generic parameters to their counterparts. -#' -#' @param x A model specification. -#' @param engine The computational engine for the model (see `?set_engine`). -#' @param ... Not currently used. -#' @details -#' `translate_tidyclust()` produces a _template_ call that lacks the specific -#' argument values (such as `data`, etc). These are filled in once -#' `fit()` is called with the specifics of the data for the model. -#' The call may also include `tune()` arguments if these are in -#' the specification. To handle the `tune()` arguments, you need to use the -#' [tune package](https://tune.tidymodels.org/). For more information -#' see -#' -#' It does contain the resolved argument names that are specific to -#' the model fitting function/engine. -#' -#' This function can be useful when you need to understand how -#' `tidyclust` goes from a generic model specific to a model fitting -#' function. -#' -#' **Note**: this function is used internally and users should only use it -#' to understand what the underlying syntax would be. It should not be used -#' to modify the cluster specification. -#' -#' @export -translate_tidyclust <- function(x, ...) { - UseMethod("translate_tidyclust") -} - -#' @rdname translate_tidyclust -#' @export -#' @export translate_tidyclust.default -translate_tidyclust.default <- function(x, engine = x$engine, ...) { - check_empty_ellipse_tidyclust(...) - if (is.null(engine)) { - rlang::abort("Please set an engine.") - } - - mod_name <- specific_model(x) - - x$engine <- engine - if (x$mode == "unknown") { - rlang::abort("Model code depends on the mode; please specify one.") - } - - check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) - - if (is.null(x$method)) { - x$method <- get_cluster_spec(mod_name, x$mode, engine) - } - - arg_key <- get_args(mod_name, engine) - - # deharmonize primary arguments - actual_args <- deharmonize(x$args, arg_key) - - # check secondary arguments to see if they are in the final - # expression unless there are dots, warn if protected args are - # being altered - x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) - - # keep only modified args - modifed_args <- !map_lgl(actual_args, null_value) - actual_args <- actual_args[modifed_args] - - # look for defaults if not modified in other - if (length(x$method$fit$defaults) > 0) { - in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) - x$defaults <- x$method$fit$defaults[!in_other] - } - - # combine primary, eng_args, and defaults - protected <- lapply(x$method$fit$protect, function(x) rlang::expr(missing_arg())) - names(protected) <- x$method$fit$protect - - x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults) - - x -} - -# ------------------------------------------------------------------------------ -# new code for revised model data structures - -get_cluster_spec <- function(model, mode, engine) { - m_env <- get_model_env_tidyclust() - env_obj <- rlang::env_names(m_env) - env_obj <- grep(model, env_obj, value = TRUE) - - res <- list() - res$libs <- - rlang::env_get(m_env, paste0(model, "_pkgs")) %>% - dplyr::filter(engine == !!engine) %>% - .[["pkg"]] %>% - .[[1]] - - res$fit <- - rlang::env_get(m_env, paste0(model, "_fit")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::pull(value) %>% - .[[1]] - - pred_code <- - rlang::env_get(m_env, paste0(model, "_predict")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::select(-engine, -mode) - - res$pred <- pred_code[["value"]] - names(res$pred) <- pred_code$type - - res -} - -get_args <- function(model, engine) { - m_env <- get_model_env_tidyclust() - rlang::env_get(m_env, paste0(model, "_args")) %>% - dplyr::filter(engine == !!engine) %>% - dplyr::select(-engine) -} - -# to replace harmonize -deharmonize <- function(args, key) { - if (length(args) == 0) { - return(args) - } - parsn <- tibble::tibble(tidyclust = names(args), order = seq_along(args)) - merged <- - dplyr::left_join(parsn, key, by = "tidyclust") %>% - dplyr::arrange(order) - # TODO correct for bad merge? - - names(args) <- merged$original - args[!is.na(merged$original)] -} - -#' Check to ensure that ellipses are empty -#' @param ... Extra arguments. -#' @return If an error is not thrown (from non-empty ellipses), a NULL list. -#' @keywords internal -#' @export -check_empty_ellipse_tidyclust <- function(...) { - terms <- quos(...) - if (!rlang::is_empty(terms)) { - rlang::abort("Please pass other arguments to the model function via `set_engine()`.") - } - terms -} +#' Resolve a Model Specification for a Computational Engine +#' +#' `translate_tidyclust()` will translate_tidyclust a model specification into a code +#' object that is specific to a particular engine (e.g. R package). +#' It translate_tidyclusts generic parameters to their counterparts. +#' +#' @param x A model specification. +#' @param engine The computational engine for the model (see `?set_engine`). +#' @param ... Not currently used. +#' @details +#' `translate_tidyclust()` produces a _template_ call that lacks the specific +#' argument values (such as `data`, etc). These are filled in once +#' `fit()` is called with the specifics of the data for the model. +#' The call may also include `tune()` arguments if these are in +#' the specification. To handle the `tune()` arguments, you need to use the +#' [tune package](https://tune.tidymodels.org/). For more information +#' see +#' +#' It does contain the resolved argument names that are specific to +#' the model fitting function/engine. +#' +#' This function can be useful when you need to understand how +#' `tidyclust` goes from a generic model specific to a model fitting +#' function. +#' +#' **Note**: this function is used internally and users should only use it +#' to understand what the underlying syntax would be. It should not be used +#' to modify the cluster specification. +#' +#' @export +translate_tidyclust <- function(x, ...) { + UseMethod("translate_tidyclust") +} + +#' @rdname translate_tidyclust +#' @export +#' @export translate_tidyclust.default +translate_tidyclust.default <- function(x, engine = x$engine, ...) { + check_empty_ellipse_tidyclust(...) + if (is.null(engine)) { + rlang::abort("Please set an engine.") + } + + mod_name <- specific_model(x) + + x$engine <- engine + if (x$mode == "unknown") { + rlang::abort("Model code depends on the mode; please specify one.") + } + + check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) + + if (is.null(x$method)) { + x$method <- get_cluster_spec(mod_name, x$mode, engine) + } + + arg_key <- get_args(mod_name, engine) + + # deharmonize primary arguments + actual_args <- deharmonize(x$args, arg_key) + + # check secondary arguments to see if they are in the final + # expression unless there are dots, warn if protected args are + # being altered + x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) + + # keep only modified args + modifed_args <- !map_lgl(actual_args, null_value) + actual_args <- actual_args[modifed_args] + + # look for defaults if not modified in other + if (length(x$method$fit$defaults) > 0) { + in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) + x$defaults <- x$method$fit$defaults[!in_other] + } + + # combine primary, eng_args, and defaults + protected <- lapply(x$method$fit$protect, function(x) rlang::expr(missing_arg())) + names(protected) <- x$method$fit$protect + + x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults) + + x +} + +# ------------------------------------------------------------------------------ +# new code for revised model data structures + +get_cluster_spec <- function(model, mode, engine) { + m_env <- get_model_env_tidyclust() + env_obj <- rlang::env_names(m_env) + env_obj <- grep(model, env_obj, value = TRUE) + + res <- list() + res$libs <- + rlang::env_get(m_env, paste0(model, "_pkgs")) %>% + dplyr::filter(engine == !!engine) %>% + .[["pkg"]] %>% + .[[1]] + + res$fit <- + rlang::env_get(m_env, paste0(model, "_fit")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + .[[1]] + + pred_code <- + rlang::env_get(m_env, paste0(model, "_predict")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) + + res$pred <- pred_code[["value"]] + names(res$pred) <- pred_code$type + + res +} + +get_args <- function(model, engine) { + m_env <- get_model_env_tidyclust() + rlang::env_get(m_env, paste0(model, "_args")) %>% + dplyr::filter(engine == !!engine) %>% + dplyr::select(-engine) +} + +# to replace harmonize +deharmonize <- function(args, key) { + if (length(args) == 0) { + return(args) + } + parsn <- tibble::tibble(tidyclust = names(args), order = seq_along(args)) + merged <- + dplyr::left_join(parsn, key, by = "tidyclust") %>% + dplyr::arrange(order) + # TODO correct for bad merge? + + names(args) <- merged$original + args[!is.na(merged$original)] +} + +#' Check to ensure that ellipses are empty +#' @param ... Extra arguments. +#' @return If an error is not thrown (from non-empty ellipses), a NULL list. +#' @keywords internal +#' @export +check_empty_ellipse_tidyclust <- function(...) { + terms <- quos(...) + if (!rlang::is_empty(terms)) { + rlang::abort("Please pass other arguments to the model function via `set_engine()`.") + } + terms +} diff --git a/R/tunable.R b/R/tunable.R index 6b7a9335..41a4bc1f 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -1,81 +1,81 @@ -# Lazily registered in .onLoad() -# Unit tests are in extratests -# nocov start -tunable_cluster_spec <- function(x, ...) { - mod_env <- rlang::ns_env("tidyclust")$tidyclust - - if (is.null(x$engine)) { - abort("Please declare an engine first using `set_engine()`.", call. = FALSE) - } - - arg_name <- paste0(mod_type(x), "_args") - if (!(any(arg_name == names(mod_env)))) { - abort( - paste( - "The `tidyclust` model database doesn't know about the arguments for ", - "model `", mod_type(x), "`. Was it registered?", - sep = "" - ), - call. = FALSE - ) - } - - arg_vals <- - mod_env[[arg_name]] %>% - dplyr::filter(engine == x$engine) %>% - dplyr::select(name = tidyclust, call_info = func) %>% - dplyr::full_join( - tibble::tibble(name = c(names(x$args), names(x$eng_args))), - by = "name" - ) %>% - dplyr::mutate( - source = "cluster_spec", - component = mod_type(x), - component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") - ) - - if (nrow(arg_vals) > 0) { - has_info <- map_lgl(arg_vals$call_info, is.null) - rm_list <- !(has_info & (arg_vals$component_id == "main")) - - arg_vals <- arg_vals[rm_list, ] - } - arg_vals %>% dplyr::select(name, call_info, source, component, component_id) -} - -mod_type <- function(.mod) class(.mod)[class(.mod) != "cluster_spec"][1] - -add_engine_parameters <- function(pset, engines) { - is_engine_param <- pset$name %in% engines$name - if (any(is_engine_param)) { - engine_names <- pset$name[is_engine_param] - pset <- pset[!is_engine_param, ] - pset <- - dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name)) - } - pset -} - -# Lazily registered in .onLoad() -tunable_k_means <- function(x, ...) { - res <- NextMethod() - if (x$engine == "stats") { - res <- add_engine_parameters(res, stats_k_means_engine_args) - } - res -} - -stats_k_means_engine_args <- - tibble::tibble( - name = c( - "centers" - ), - call_info = list( - list(pkg = "tidyclust", fun = "num_clusters") - ), - source = "cluster_spec", - component = "k_means", - component_id = "engine" - ) - -# nocov end +# Lazily registered in .onLoad() +# Unit tests are in extratests +# nocov start +tunable_cluster_spec <- function(x, ...) { + mod_env <- rlang::ns_env("tidyclust")$tidyclust + + if (is.null(x$engine)) { + abort("Please declare an engine first using `set_engine()`.", call. = FALSE) + } + + arg_name <- paste0(mod_type(x), "_args") + if (!(any(arg_name == names(mod_env)))) { + abort( + paste( + "The `tidyclust` model database doesn't know about the arguments for ", + "model `", mod_type(x), "`. Was it registered?", + sep = "" + ), + call. = FALSE + ) + } + + arg_vals <- + mod_env[[arg_name]] %>% + dplyr::filter(engine == x$engine) %>% + dplyr::select(name = tidyclust, call_info = func) %>% + dplyr::full_join( + tibble::tibble(name = c(names(x$args), names(x$eng_args))), + by = "name" + ) %>% + dplyr::mutate( + source = "cluster_spec", + component = mod_type(x), + component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") + ) + + if (nrow(arg_vals) > 0) { + has_info <- map_lgl(arg_vals$call_info, is.null) + rm_list <- !(has_info & (arg_vals$component_id == "main")) + + arg_vals <- arg_vals[rm_list, ] + } + arg_vals %>% dplyr::select(name, call_info, source, component, component_id) +} + +mod_type <- function(.mod) class(.mod)[class(.mod) != "cluster_spec"][1] + +add_engine_parameters <- function(pset, engines) { + is_engine_param <- pset$name %in% engines$name + if (any(is_engine_param)) { + engine_names <- pset$name[is_engine_param] + pset <- pset[!is_engine_param, ] + pset <- + dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name)) + } + pset +} + +# Lazily registered in .onLoad() +tunable_k_means <- function(x, ...) { + res <- NextMethod() + if (x$engine == "stats") { + res <- add_engine_parameters(res, stats_k_means_engine_args) + } + res +} + +stats_k_means_engine_args <- + tibble::tibble( + name = c( + "centers" + ), + call_info = list( + list(pkg = "tidyclust", fun = "num_clusters") + ), + source = "cluster_spec", + component = "k_means", + component_id = "engine" + ) + +# nocov end diff --git a/R/update.R b/R/update.R index f343035e..0e4d3d5f 100644 --- a/R/update.R +++ b/R/update.R @@ -1,28 +1,28 @@ -#' Update a cluster specification -#' -#' @description -#' If parameters of a cluster specification need to be modified, `update()` can -#' be used in lieu of recreating the object from scratch. -#' -#' @inheritParams k_means -#' @param object A cluster specification. -#' @param parameters A 1-row tibble or named list with _main_ -#' parameters to update. Use **either** `parameters` **or** the main arguments -#' directly when updating. If the main arguments are used, -#' these will supersede the values in `parameters`. Also, using -#' engine arguments in this object will result in an error. -#' @param ... Not used for `update()`. -#' @param fresh A logical for whether the arguments should be -#' modified in-place or replaced wholesale. -#' @return An updated cluster specification. -#' @name tidyclust_update -#' @examples -#' kmeans_spec <- k_means(num_clusters = 5) -#' kmeans_spec -#' update(kmeans_spec, num_clusters = 1) -#' update(kmeans_spec, num_clusters = 1, fresh = TRUE) -#' -#' param_values <- tibble::tibble(num_clusters = 10) -#' -#' kmeans_spec %>% update(param_values) -NULL +#' Update a cluster specification +#' +#' @description +#' If parameters of a cluster specification need to be modified, `update()` can +#' be used in lieu of recreating the object from scratch. +#' +#' @inheritParams k_means +#' @param object A cluster specification. +#' @param parameters A 1-row tibble or named list with _main_ +#' parameters to update. Use **either** `parameters` **or** the main arguments +#' directly when updating. If the main arguments are used, +#' these will supersede the values in `parameters`. Also, using +#' engine arguments in this object will result in an error. +#' @param ... Not used for `update()`. +#' @param fresh A logical for whether the arguments should be +#' modified in-place or replaced wholesale. +#' @return An updated cluster specification. +#' @name tidyclust_update +#' @examples +#' kmeans_spec <- k_means(num_clusters = 5) +#' kmeans_spec +#' update(kmeans_spec, num_clusters = 1) +#' update(kmeans_spec, num_clusters = 1, fresh = TRUE) +#' +#' param_values <- tibble::tibble(num_clusters = 10) +#' +#' kmeans_spec %>% update(param_values) +NULL diff --git a/README.Rmd b/README.Rmd index 5ec8f3f7..8761424c 100644 --- a/README.Rmd +++ b/README.Rmd @@ -1,74 +1,74 @@ ---- -output: github_document ---- - - - -```{r, include = FALSE} -knitr::opts_chunk$set( - collapse = TRUE, - comment = "#>", - fig.path = "man/figures/README-", - out.width = "100%" -) -``` - -# tidyclust - - -[![Codecov test coverage](https://codecov.io/gh/EmilHvitfeldt/tidyclust/branch/main/graph/badge.svg)](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main) -[![R-CMD-check](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml) - - -The goal of tidyclust is to provide a tidy, unified interface to clustering models. The packages is closely modeled after the [parsnip](https://parsnip.tidymodels.org/) package. - -## Installation - -You can install the development version of tidyclust from [GitHub](https://github.com/) with: - -``` r -# install.packages("devtools") -devtools::install_github("EmilHvitfeldt/tidyclust") -``` - -Please note that this package currently requires a [branch of the workflows](https://github.com/tidymodels/workflows/tree/tidyclust) package to work. Use with caution. - -## Example - -The first thing you do is to create a `cluster specification`. For this example we are creating a K-means model, using the `stats` engine. - -```{r} -library(tidyclust) - -kmeans_spec <- k_means(num_clusters = 3) %>% - set_engine("stats") - -kmeans_spec -``` - -This specification can then be fit using data. - -```{r} -kmeans_spec_fit <- kmeans_spec %>% - fit(~., data = mtcars) -kmeans_spec_fit -``` - -Once you have a fitted tidyclust object, you can do a number of things. `predict()` returns the cluster a new observation belongs to - -```{r} -predict(kmeans_spec_fit, mtcars[1:4, ]) -``` - -`extract_cluster_assignment()` returns the cluster assignments of the training observations - -```{r} -extract_cluster_assignment(kmeans_spec_fit) -``` - -and `extract_clusters()` returns the locations of the clusters - -```{r} -extract_centroids(kmeans_spec_fit) -``` - +--- +output: github_document +--- + + + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>", + fig.path = "man/figures/README-", + out.width = "100%" +) +``` + +# tidyclust + + +[![Codecov test coverage](https://codecov.io/gh/EmilHvitfeldt/tidyclust/branch/main/graph/badge.svg)](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main) +[![R-CMD-check](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml) + + +The goal of tidyclust is to provide a tidy, unified interface to clustering models. The packages is closely modeled after the [parsnip](https://parsnip.tidymodels.org/) package. + +## Installation + +You can install the development version of tidyclust from [GitHub](https://github.com/) with: + +``` r +# install.packages("devtools") +devtools::install_github("EmilHvitfeldt/tidyclust") +``` + +Please note that this package currently requires a [branch of the workflows](https://github.com/tidymodels/workflows/tree/tidyclust) package to work. Use with caution. + +## Example + +The first thing you do is to create a `cluster specification`. For this example we are creating a K-means model, using the `stats` engine. + +```{r} +library(tidyclust) + +kmeans_spec <- k_means(num_clusters = 3) %>% + set_engine("stats") + +kmeans_spec +``` + +This specification can then be fit using data. + +```{r} +kmeans_spec_fit <- kmeans_spec %>% + fit(~., data = mtcars) +kmeans_spec_fit +``` + +Once you have a fitted tidyclust object, you can do a number of things. `predict()` returns the cluster a new observation belongs to + +```{r} +predict(kmeans_spec_fit, mtcars[1:4, ]) +``` + +`extract_cluster_assignment()` returns the cluster assignments of the training observations + +```{r} +extract_cluster_assignment(kmeans_spec_fit) +``` + +and `extract_clusters()` returns the locations of the clusters + +```{r} +extract_centroids(kmeans_spec_fit) +``` + diff --git a/README.md b/README.md index 50ce6989..98ce0a0d 100644 --- a/README.md +++ b/README.md @@ -1,145 +1,145 @@ - - - -# tidyclust - - - -[![Codecov test -coverage](https://codecov.io/gh/EmilHvitfeldt/tidyclust/branch/main/graph/badge.svg)](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main) -[![R-CMD-check](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml) - - -The goal of tidyclust is to provide a tidy, unified interface to -clustering models. The packages is closely modeled after the -[parsnip](https://parsnip.tidymodels.org/) package. - -## Installation - -You can install the development version of tidyclust from -[GitHub](https://github.com/) with: - -``` r -# install.packages("devtools") -devtools::install_github("EmilHvitfeldt/tidyclust") -``` - -Please note that this package currently requires a [branch of the -workflows](https://github.com/tidymodels/workflows/tree/tidyclust) -package to work. Use with caution. - -## Example - -The first thing you do is to create a `cluster specification`. For this -example we are creating a K-means model, using the `stats` engine. - -``` r -library(tidyclust) - -kmeans_spec <- k_means(num_clusters = 3) %>% - set_engine("stats") - -kmeans_spec -#> K Means Cluster Specification (partition) -#> -#> Main Arguments: -#> num_clusters = 3 -#> -#> Computational engine: stats -``` - -This specification can then be fit using data. - -``` r -kmeans_spec_fit <- kmeans_spec %>% - fit(~., data = mtcars) -kmeans_spec_fit -#> tidyclust cluster object -#> -#> K-means clustering with 3 clusters of sizes 9, 16, 7 -#> -#> Cluster means: -#> mpg cyl disp hp drat wt qsec vs -#> 1 14.64444 8.000000 388.2222 232.1111 3.343333 4.161556 16.40444 0.0000000 -#> 2 24.50000 4.625000 122.2937 96.8750 4.002500 2.518000 18.54312 0.7500000 -#> 3 17.01429 7.428571 276.0571 150.7143 2.994286 3.601429 18.11857 0.2857143 -#> am gear carb -#> 1 0.2222222 3.444444 4.000000 -#> 2 0.6875000 4.125000 2.437500 -#> 3 0.0000000 3.000000 2.142857 -#> -#> Clustering vector: -#> Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive -#> 2 2 2 3 -#> Hornet Sportabout Valiant Duster 360 Merc 240D -#> 1 3 1 2 -#> Merc 230 Merc 280 Merc 280C Merc 450SE -#> 2 2 2 3 -#> Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental -#> 3 3 1 1 -#> Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla -#> 1 2 2 2 -#> Toyota Corona Dodge Challenger AMC Javelin Camaro Z28 -#> 2 3 3 1 -#> Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa -#> 1 2 2 2 -#> Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E -#> 1 2 1 2 -#> -#> Within cluster sum of squares by cluster: -#> [1] 46659.32 32838.00 11846.09 -#> (between_SS / total_SS = 85.3 %) -#> -#> Available components: -#> -#> [1] "cluster" "centers" "totss" "withinss" "tot.withinss" -#> [6] "betweenss" "size" "iter" "ifault" -``` - -Once you have a fitted tidyclust object, you can do a number of things. -`predict()` returns the cluster a new observation belongs to - -``` r -predict(kmeans_spec_fit, mtcars[1:4, ]) -#> # A tibble: 4 × 1 -#> .pred_cluster -#> -#> 1 Cluster_1 -#> 2 Cluster_1 -#> 3 Cluster_1 -#> 4 Cluster_2 -``` - -`extract_cluster_assignment()` returns the cluster assignments of the -training observations - -``` r -extract_cluster_assignment(kmeans_spec_fit) -#> # A tibble: 32 × 1 -#> .cluster -#> -#> 1 Cluster_1 -#> 2 Cluster_1 -#> 3 Cluster_1 -#> 4 Cluster_2 -#> 5 Cluster_3 -#> 6 Cluster_2 -#> 7 Cluster_3 -#> 8 Cluster_1 -#> 9 Cluster_1 -#> 10 Cluster_1 -#> # … with 22 more rows -#> # ℹ Use `print(n = ...)` to see more rows -``` - -and `extract_clusters()` returns the locations of the clusters - -``` r -extract_centroids(kmeans_spec_fit) -#> # A tibble: 3 × 12 -#> .cluster mpg cyl disp hp drat wt qsec vs am gear carb -#> -#> 1 Cluster_1 17.0 7.43 276. 151. 2.99 3.60 18.1 0.286 0 3 2.14 -#> 2 Cluster_2 14.6 8 388. 232. 3.34 4.16 16.4 0 0.222 3.44 4 -#> 3 Cluster_3 24.5 4.62 122. 96.9 4.00 2.52 18.5 0.75 0.688 4.12 2.44 -``` + + + +# tidyclust + + + +[![Codecov test +coverage](https://codecov.io/gh/EmilHvitfeldt/tidyclust/branch/main/graph/badge.svg)](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main) +[![R-CMD-check](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml) + + +The goal of tidyclust is to provide a tidy, unified interface to +clustering models. The packages is closely modeled after the +[parsnip](https://parsnip.tidymodels.org/) package. + +## Installation + +You can install the development version of tidyclust from +[GitHub](https://github.com/) with: + +``` r +# install.packages("devtools") +devtools::install_github("EmilHvitfeldt/tidyclust") +``` + +Please note that this package currently requires a [branch of the +workflows](https://github.com/tidymodels/workflows/tree/tidyclust) +package to work. Use with caution. + +## Example + +The first thing you do is to create a `cluster specification`. For this +example we are creating a K-means model, using the `stats` engine. + +``` r +library(tidyclust) + +kmeans_spec <- k_means(num_clusters = 3) %>% + set_engine("stats") + +kmeans_spec +#> K Means Cluster Specification (partition) +#> +#> Main Arguments: +#> num_clusters = 3 +#> +#> Computational engine: stats +``` + +This specification can then be fit using data. + +``` r +kmeans_spec_fit <- kmeans_spec %>% + fit(~., data = mtcars) +kmeans_spec_fit +#> tidyclust cluster object +#> +#> K-means clustering with 3 clusters of sizes 9, 16, 7 +#> +#> Cluster means: +#> mpg cyl disp hp drat wt qsec vs +#> 1 14.64444 8.000000 388.2222 232.1111 3.343333 4.161556 16.40444 0.0000000 +#> 2 24.50000 4.625000 122.2937 96.8750 4.002500 2.518000 18.54312 0.7500000 +#> 3 17.01429 7.428571 276.0571 150.7143 2.994286 3.601429 18.11857 0.2857143 +#> am gear carb +#> 1 0.2222222 3.444444 4.000000 +#> 2 0.6875000 4.125000 2.437500 +#> 3 0.0000000 3.000000 2.142857 +#> +#> Clustering vector: +#> Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive +#> 2 2 2 3 +#> Hornet Sportabout Valiant Duster 360 Merc 240D +#> 1 3 1 2 +#> Merc 230 Merc 280 Merc 280C Merc 450SE +#> 2 2 2 3 +#> Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental +#> 3 3 1 1 +#> Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla +#> 1 2 2 2 +#> Toyota Corona Dodge Challenger AMC Javelin Camaro Z28 +#> 2 3 3 1 +#> Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa +#> 1 2 2 2 +#> Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E +#> 1 2 1 2 +#> +#> Within cluster sum of squares by cluster: +#> [1] 46659.32 32838.00 11846.09 +#> (between_SS / total_SS = 85.3 %) +#> +#> Available components: +#> +#> [1] "cluster" "centers" "totss" "withinss" "tot.withinss" +#> [6] "betweenss" "size" "iter" "ifault" +``` + +Once you have a fitted tidyclust object, you can do a number of things. +`predict()` returns the cluster a new observation belongs to + +``` r +predict(kmeans_spec_fit, mtcars[1:4, ]) +#> # A tibble: 4 × 1 +#> .pred_cluster +#> +#> 1 Cluster_1 +#> 2 Cluster_1 +#> 3 Cluster_1 +#> 4 Cluster_2 +``` + +`extract_cluster_assignment()` returns the cluster assignments of the +training observations + +``` r +extract_cluster_assignment(kmeans_spec_fit) +#> # A tibble: 32 × 1 +#> .cluster +#> +#> 1 Cluster_1 +#> 2 Cluster_1 +#> 3 Cluster_1 +#> 4 Cluster_2 +#> 5 Cluster_3 +#> 6 Cluster_2 +#> 7 Cluster_3 +#> 8 Cluster_1 +#> 9 Cluster_1 +#> 10 Cluster_1 +#> # … with 22 more rows +#> # ℹ Use `print(n = ...)` to see more rows +``` + +and `extract_clusters()` returns the locations of the clusters + +``` r +extract_centroids(kmeans_spec_fit) +#> # A tibble: 3 × 12 +#> .cluster mpg cyl disp hp drat wt qsec vs am gear carb +#> +#> 1 Cluster_1 17.0 7.43 276. 151. 2.99 3.60 18.1 0.286 0 3 2.14 +#> 2 Cluster_2 14.6 8 388. 232. 3.34 4.16 16.4 0 0.222 3.44 4 +#> 3 Cluster_3 24.5 4.62 122. 96.9 4.00 2.52 18.5 0.75 0.688 4.12 2.44 +``` diff --git a/_pkgdown.yml b/_pkgdown.yml index 09616d33..2d3474c3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -1,67 +1,67 @@ -url: ~ -template: - bootstrap: 5 - -reference: -- title: Specifications - desc: > - These cluster specification fucntion are used to specify the type of model - you want to do. These functions work in a similar fashion to the [model - specification function from parsnip](https://parsnip.tidymodels.org/reference/index.html#models). - contents: - - k_means -- title: Fit and Inspect - desc: > - These Functions are the generics are are supported for specifications - created with tidyclust. - contents: - - fit.cluster_spec - - augment.cluster_fit - - glance.cluster_fit - - tidy.cluster_fit -- title: Prediction - desc: > - Once the cluster specification have been fit, you are likely to want to look - at where the clusters are and which observations are associated with which - cluster. - contents: - - predict.cluster_fit - - extract_cluster_assignment - - extract_centroids -- title: Parameter Objects - desc: > - Parameter objects for tuning. Similar to - [parameter objects from dials package](https://dials.tidymodels.org/reference/index.html#parameter-objects) - contents: - - num_clusters -- title: Model based performance metrics - desc: > - These metrics use the fitted clustering model to extract values denoting how - well the model works. - contents: - - cluster_metric_set - - avg_silhouette - - tot_sse - - sse_ratio - - tot_wss -- title: Tuning - desc: > - Functions to allow multiple cluster specifications to be fit at once. - contents: - - control_cluster - - tidyclust_update - - finalize_model_tidyclust - - tune_cluster -- title: Developer tools - contents: - - extract_fit_summary - - get_centroid_dists - - new_cluster_metric - - prep_data_dist - - reconcile_clusterings - - translate_tidyclust - - within_cluster_sse -- title: Other - contents: - - enrichment - - silhouettes +url: ~ +template: + bootstrap: 5 + +reference: +- title: Specifications + desc: > + These cluster specification fucntion are used to specify the type of model + you want to do. These functions work in a similar fashion to the [model + specification function from parsnip](https://parsnip.tidymodels.org/reference/index.html#models). + contents: + - k_means +- title: Fit and Inspect + desc: > + These Functions are the generics are are supported for specifications + created with tidyclust. + contents: + - fit.cluster_spec + - augment.cluster_fit + - glance.cluster_fit + - tidy.cluster_fit +- title: Prediction + desc: > + Once the cluster specification have been fit, you are likely to want to look + at where the clusters are and which observations are associated with which + cluster. + contents: + - predict.cluster_fit + - extract_cluster_assignment + - extract_centroids +- title: Parameter Objects + desc: > + Parameter objects for tuning. Similar to + [parameter objects from dials package](https://dials.tidymodels.org/reference/index.html#parameter-objects) + contents: + - num_clusters +- title: Model based performance metrics + desc: > + These metrics use the fitted clustering model to extract values denoting how + well the model works. + contents: + - cluster_metric_set + - avg_silhouette + - tot_sse + - sse_ratio + - tot_wss +- title: Tuning + desc: > + Functions to allow multiple cluster specifications to be fit at once. + contents: + - control_cluster + - tidyclust_update + - finalize_model_tidyclust + - tune_cluster +- title: Developer tools + contents: + - extract_fit_summary + - get_centroid_dists + - new_cluster_metric + - prep_data_dist + - reconcile_clusterings + - translate_tidyclust + - within_cluster_sse +- title: Other + contents: + - enrichment + - silhouettes diff --git a/dev/cross_val_kmeans.R b/dev/cross_val_kmeans.R index 30880662..538985db 100644 --- a/dev/cross_val_kmeans.R +++ b/dev/cross_val_kmeans.R @@ -1,108 +1,108 @@ -library(tidymodels) -library(tidyverse) -library(tidyclust) - -## "Cross-validation" for kmeans - -ir <- iris %>% select(-Species) - -cvs <- vfold_cv(ir, v = 5) - -res <- data.frame( - k = NA, - i = NA, - wss = NA, - sil = NA, - wss_2 = NA -) - -for (k in 2:10) { - - km <- k_means(k = k) %>% - set_engine("stats") - - - for (i in 1:5) { - - tmp_train <- training(cvs$splits[[i]]) - tmp_test <- testing(cvs$splits[[i]]) - - km_fit <- km %>% fit(~., data = tmp_train) - - wss <- km_fit %>% - tot_wss(tmp_test) - - wss_2 <- km_fit$fit$tot.withinss - - sil <- km_fit %>% - avg_silhouette(tmp_test) - - res <- rbind(res, - c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) - - } - -} - -res %>% - drop_na() %>% - ggplot(aes(x = factor(k), y = sil)) + - geom_point() - - -### Second idea -## What if we cluster the whole data, then see how well subsamples are reclassified? -## This needs "predict" -## Doesn't really make sense yet - -cvs <- vfold_cv(ir, v = 10) - -res <- data.frame( - k = NA, - i = NA, - acc = NA, - f1 = NA -) - -for (k in 2:10) { - - km <- k_means(k = k) %>% - set_engine("stats") - - full_fit <- km %>% fit(~., data = ir) - - - for (i in 1:10) { - - tmp_train <- training(cvs$splits[[i]]) - tmp_test <- testing(cvs$splits[[i]]) - - km_fit <- km %>% fit(~., data = tmp_train) - - dat <- tmp_test %>% - mutate( - truth = predict(full_fit, tmp_test)$.pred_cluster, - estimate = predict(km_fit, tmp_test)$.pred_cluster - ) - - thing <- reconcile_clusterings(dat$truth, dat$estimate) - - acc <- accuracy(thing, clusters_1, clusters_2) - f1 <- f_meas(thing, clusters_1, clusters_2) - - res <- rbind(res, - c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate)) - - } - -} - - -res %>% - ggplot(aes(x = factor(k), y = f1)) + - geom_point() - - -### use orders from reconciling to order centers and check center similarity? -### or to get "raw probabilities" - what does that mean though? -### to do predict = raw +library(tidymodels) +library(tidyverse) +library(tidyclust) + +## "Cross-validation" for kmeans + +ir <- iris %>% select(-Species) + +cvs <- vfold_cv(ir, v = 5) + +res <- data.frame( + k = NA, + i = NA, + wss = NA, + sil = NA, + wss_2 = NA +) + +for (k in 2:10) { + + km <- k_means(k = k) %>% + set_engine("stats") + + + for (i in 1:5) { + + tmp_train <- training(cvs$splits[[i]]) + tmp_test <- testing(cvs$splits[[i]]) + + km_fit <- km %>% fit(~., data = tmp_train) + + wss <- km_fit %>% + tot_wss(tmp_test) + + wss_2 <- km_fit$fit$tot.withinss + + sil <- km_fit %>% + avg_silhouette(tmp_test) + + res <- rbind(res, + c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) + + } + +} + +res %>% + drop_na() %>% + ggplot(aes(x = factor(k), y = sil)) + + geom_point() + + +### Second idea +## What if we cluster the whole data, then see how well subsamples are reclassified? +## This needs "predict" +## Doesn't really make sense yet + +cvs <- vfold_cv(ir, v = 10) + +res <- data.frame( + k = NA, + i = NA, + acc = NA, + f1 = NA +) + +for (k in 2:10) { + + km <- k_means(k = k) %>% + set_engine("stats") + + full_fit <- km %>% fit(~., data = ir) + + + for (i in 1:10) { + + tmp_train <- training(cvs$splits[[i]]) + tmp_test <- testing(cvs$splits[[i]]) + + km_fit <- km %>% fit(~., data = tmp_train) + + dat <- tmp_test %>% + mutate( + truth = predict(full_fit, tmp_test)$.pred_cluster, + estimate = predict(km_fit, tmp_test)$.pred_cluster + ) + + thing <- reconcile_clusterings(dat$truth, dat$estimate) + + acc <- accuracy(thing, clusters_1, clusters_2) + f1 <- f_meas(thing, clusters_1, clusters_2) + + res <- rbind(res, + c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate)) + + } + +} + + +res %>% + ggplot(aes(x = factor(k), y = f1)) + + geom_point() + + +### use orders from reconciling to order centers and check center similarity? +### or to get "raw probabilities" - what does that mean though? +### to do predict = raw diff --git a/dev/kmeans.Rmd b/dev/kmeans.Rmd index 0fa5245c..ef08b912 100644 --- a/dev/kmeans.Rmd +++ b/dev/kmeans.Rmd @@ -1,132 +1,132 @@ ---- -title: "k-means Clustering" -output: rmarkdown::html_vignette -vignette: > - %\VignetteIndexEntry{k-means Clustering} - %\VignetteEngine{knitr::rmarkdown} - %\VignetteEncoding{UTF-8} ---- - -```{r, include = FALSE} -knitr::opts_chunk$set( - collapse = TRUE, - comment = "#>" -) -``` - -```{r setup} -library(tidyclust) -library(palmerpenguins) -library(tidymodels) -``` - -## Fit - -```{r} -kmeans_spec <- k_means(k = 5) %>% - set_engine("stats") - -penguins_rec_1 <- recipe(~ ., data = penguins) %>% - update_role(species, island, new_role = "demographic") %>% - step_dummy(sex) - - -penguins_rec_2 <- recipe(species ~ ., data = penguins) %>% - step_dummy(sex, island) - -wflow_1 <- workflow() %>% - add_model(kmeans_spec) %>% - add_recipe(penguins_rec_1) - - -wflow_2 <- workflow() %>% - add_model(kmeans_spec) %>% - add_recipe(penguins_rec_2) -``` - -We need workflows! - -```{r} -# dropping NA first so rows match up later, this is clunky -pen_sub <- penguins %>% - drop_na() %>% - select(-species, -island, -sex) - -kmeans_fit <- kmeans_spec %>% fit( ~., pen_sub) - -kmeans_fit %>% - predict(new_data = pen_sub) - -### try my new version -kmeans_fit %>% - extract_cluster_assignment() -``` - -* Needed flexclust install; probably not necessary, we could implement for kmeans with just dists - -* Missing values should probably return an NA prediction. Or for k-means, imputation isn't crazy... - -* We want a consistent return, and k-means is randomized. We should agree to a consistent default ordering throughout tidyclust - maybe by size (number of members) or something? - - -```{r} -penguins %>% - drop_na() %>% - mutate( - preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster - ) %>% - count(preds, sex, species) -``` - -## Diagnostics - -* Measure cluster enrichment with demographic variables (an official recipe designation?). Include Chi-Square or ANOVA tests??? - -* Exploit confusion matrix from `yardstick` - -* Characterization of clusters, presumably by centers? - -* Automatic plot?! - - -## talk to Emil - -* Ordering of clusters is really bothering me -something with indices - -* Cluster density followup: is it kind of a model? - -```{r} -recipe( ~ demo1 + predictor1) %>% - step_tidyclust(kmeans_fit) # doesn't quite make sense - -``` - -* Some of these followups feel like they need a recipe. Which vars are we using for within SS? Which vars for enrichments? Do we PCA first? etc. - -```{r} -get_SS(Cluster ~ v1 + v2) -recipe(Cluster ~ v1 + v2) %>% - step_pca() %>% - get_ss() -``` - -... or maybe fit `enrichment()` on a recipe and it automatically uses the variables with the "enrich" role? -how would PCA fit in on this? - -How does `fit()` access the right variables? - -... is this even worth it given we can attach cluster assignments? I say yes, because what if we are trying to "cross-validate". - - -```{r} -penguins_2 <- penguins %>% - drop_na() %>% - mutate( - preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster - ) - -debugonce(enrichment) -penguins_2 %>% - enrichment(preds, species) -``` +--- +title: "k-means Clustering" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{k-means Clustering} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +```{r setup} +library(tidyclust) +library(palmerpenguins) +library(tidymodels) +``` + +## Fit + +```{r} +kmeans_spec <- k_means(k = 5) %>% + set_engine("stats") + +penguins_rec_1 <- recipe(~ ., data = penguins) %>% + update_role(species, island, new_role = "demographic") %>% + step_dummy(sex) + + +penguins_rec_2 <- recipe(species ~ ., data = penguins) %>% + step_dummy(sex, island) + +wflow_1 <- workflow() %>% + add_model(kmeans_spec) %>% + add_recipe(penguins_rec_1) + + +wflow_2 <- workflow() %>% + add_model(kmeans_spec) %>% + add_recipe(penguins_rec_2) +``` + +We need workflows! + +```{r} +# dropping NA first so rows match up later, this is clunky +pen_sub <- penguins %>% + drop_na() %>% + select(-species, -island, -sex) + +kmeans_fit <- kmeans_spec %>% fit( ~., pen_sub) + +kmeans_fit %>% + predict(new_data = pen_sub) + +### try my new version +kmeans_fit %>% + extract_cluster_assignment() +``` + +* Needed flexclust install; probably not necessary, we could implement for kmeans with just dists + +* Missing values should probably return an NA prediction. Or for k-means, imputation isn't crazy... + +* We want a consistent return, and k-means is randomized. We should agree to a consistent default ordering throughout tidyclust - maybe by size (number of members) or something? + + +```{r} +penguins %>% + drop_na() %>% + mutate( + preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster + ) %>% + count(preds, sex, species) +``` + +## Diagnostics + +* Measure cluster enrichment with demographic variables (an official recipe designation?). Include Chi-Square or ANOVA tests??? + +* Exploit confusion matrix from `yardstick` + +* Characterization of clusters, presumably by centers? + +* Automatic plot?! + + +## talk to Emil + +* Ordering of clusters is really bothering me +something with indices + +* Cluster density followup: is it kind of a model? + +```{r} +recipe( ~ demo1 + predictor1) %>% + step_tidyclust(kmeans_fit) # doesn't quite make sense + +``` + +* Some of these followups feel like they need a recipe. Which vars are we using for within SS? Which vars for enrichments? Do we PCA first? etc. + +```{r} +get_SS(Cluster ~ v1 + v2) +recipe(Cluster ~ v1 + v2) %>% + step_pca() %>% + get_ss() +``` + +... or maybe fit `enrichment()` on a recipe and it automatically uses the variables with the "enrich" role? +how would PCA fit in on this? + +How does `fit()` access the right variables? + +... is this even worth it given we can attach cluster assignments? I say yes, because what if we are trying to "cross-validate". + + +```{r} +penguins_2 <- penguins %>% + drop_na() %>% + mutate( + preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster + ) + +debugonce(enrichment) +penguins_2 %>% + enrichment(preds, species) +``` diff --git a/dev/test_hc.R b/dev/test_hc.R index d3ebe427..b6939152 100644 --- a/dev/test_hc.R +++ b/dev/test_hc.R @@ -1,33 +1,33 @@ -library(tidyverse) -library(celery) - -ir <- iris[,-5] - -hclust(dist(ir)) - -bob <- hclust_fit(ir) - -hc <- hier_clust(k = 3) %>% - fit(~ ., data = ir) - - -km <- k_means(k = 3) %>% - fit(~., data = ir) - - -thing <- tibble( - km_c = extract_cluster_assignment(km)$.cluster, - hc_c = extract_cluster_assignment(hc)$.cluster, - truth = iris$Species -) - -thing %>% - count(hc_c,truth) - -cutree(hc$fit, k = 3) - -# hc %>% -# extract_fit_engine() %>% -# cutree(k = 3) - -## reconcile? +library(tidyverse) +library(celery) + +ir <- iris[,-5] + +hclust(dist(ir)) + +bob <- hclust_fit(ir) + +hc <- hier_clust(k = 3) %>% + fit(~ ., data = ir) + + +km <- k_means(k = 3) %>% + fit(~., data = ir) + + +thing <- tibble( + km_c = extract_cluster_assignment(km)$.cluster, + hc_c = extract_cluster_assignment(hc)$.cluster, + truth = iris$Species +) + +thing %>% + count(hc_c,truth) + +cutree(hc$fit, k = 3) + +# hc %>% +# extract_fit_engine() %>% +# cutree(k = 3) + +## reconcile? diff --git a/dev/test_hclust_predict.R b/dev/test_hclust_predict.R index a6caee32..700d1b0a 100644 --- a/dev/test_hclust_predict.R +++ b/dev/test_hclust_predict.R @@ -1,39 +1,39 @@ -library(tidyverse) -library(tidyclust) -# -# my_mod <- hier_clust(k = 3) %>% fit(~., mtcars) -# -# #debugonce(tidyclust:::stats_hier_clust_predict) -# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) -# -# my_mod <- hier_clust(k = 3, linkage_method = "single") %>% fit(~., mtcars) -# my_mod$fit$method -# translate_tidyclust(hier_clust(k = 3, linkage_method = "single")) -# -# #debugonce(tidyclust:::stats_hier_clust_predict) -# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) -# -# my_mod <- hier_clust(k = 3, linkage_method = "average") %>% fit(~., mtcars) -# -# #debugonce(tidyclust:::stats_hier_clust_predict) -# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) -# -# my_mod <- hier_clust(k = 3, linkage_method = "median") %>% fit(~., mtcars) -# -# #debugonce(tidyclust:::stats_hier_clust_predict) -# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) - -my_mod <- hier_clust(k = 3, linkage_method = "centroid") %>% fit(~., mtcars) - -#debugonce(tidyclust:::stats_hier_clust_predict) -tidyclust:::stats_hier_clust_predict(my_mod, mtcars) - -my_mod <- hier_clust(k = 3, linkage_method = "ward.D") %>% fit(~., mtcars) -# debugonce(extract_fit_summary.hclust) -# extract_fit_summary.hclust(my_mod) - -#debugonce(tidyclust:::stats_hier_clust_predict) -tidyclust:::stats_hier_clust_predict(my_mod$fit, mtcars) -predict(my_mod, mtcars) - -avg_silhouette(my_mod) +library(tidyverse) +library(tidyclust) +# +# my_mod <- hier_clust(k = 3) %>% fit(~., mtcars) +# +# #debugonce(tidyclust:::stats_hier_clust_predict) +# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) +# +# my_mod <- hier_clust(k = 3, linkage_method = "single") %>% fit(~., mtcars) +# my_mod$fit$method +# translate_tidyclust(hier_clust(k = 3, linkage_method = "single")) +# +# #debugonce(tidyclust:::stats_hier_clust_predict) +# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) +# +# my_mod <- hier_clust(k = 3, linkage_method = "average") %>% fit(~., mtcars) +# +# #debugonce(tidyclust:::stats_hier_clust_predict) +# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) +# +# my_mod <- hier_clust(k = 3, linkage_method = "median") %>% fit(~., mtcars) +# +# #debugonce(tidyclust:::stats_hier_clust_predict) +# tidyclust:::stats_hier_clust_predict(my_mod, mtcars) + +my_mod <- hier_clust(k = 3, linkage_method = "centroid") %>% fit(~., mtcars) + +#debugonce(tidyclust:::stats_hier_clust_predict) +tidyclust:::stats_hier_clust_predict(my_mod, mtcars) + +my_mod <- hier_clust(k = 3, linkage_method = "ward.D") %>% fit(~., mtcars) +# debugonce(extract_fit_summary.hclust) +# extract_fit_summary.hclust(my_mod) + +#debugonce(tidyclust:::stats_hier_clust_predict) +tidyclust:::stats_hier_clust_predict(my_mod$fit, mtcars) +predict(my_mod, mtcars) + +avg_silhouette(my_mod) diff --git a/dev/to do b/dev/to do index 2fe0da67..91e9e24c 100644 --- a/dev/to do +++ b/dev/to do @@ -1,21 +1,21 @@ -* similarity metrics -- for comparison to external variable -- for stability between runs -- for reconciliation - - -metrics for hclust -parameter object ofr hclust - -+ sticker - -kmeans tutorial - autoplot! -hclust tutorial - -tuning - -predicting and reconciling - - -later: -* augment with response variable, then compute metrics +* similarity metrics +- for comparison to external variable +- for stability between runs +- for reconciliation + + +metrics for hclust +parameter object ofr hclust + ++ sticker + +kmeans tutorial - autoplot! +hclust tutorial + +tuning + +predicting and reconciling + + +later: +* augment with response variable, then compute metrics diff --git a/man/augment.Rd b/man/augment.Rd index 4c27a1d6..8d197956 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -1,31 +1,31 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/augment.R -\name{augment.cluster_fit} -\alias{augment.cluster_fit} -\title{Augment data with predictions} -\usage{ -\method{augment}{cluster_fit}(x, new_data, ...) -} -\arguments{ -\item{x}{A \code{cluster_fit} object produced by \code{\link[=fit.cluster_spec]{fit.cluster_spec()}} or -\code{\link[=fit_xy.cluster_spec]{fit_xy.cluster_spec()}} .} - -\item{new_data}{A data frame or matrix.} - -\item{...}{Not currently used.} -} -\description{ -\code{augment()} will add column(s) for predictions to the given data. -} -\details{ -For partition models, a \code{.pred_cluster} column is added. -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - augment(new_data = mtcars) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augment.R +\name{augment.cluster_fit} +\alias{augment.cluster_fit} +\title{Augment data with predictions} +\usage{ +\method{augment}{cluster_fit}(x, new_data, ...) +} +\arguments{ +\item{x}{A \code{cluster_fit} object produced by \code{\link[=fit.cluster_spec]{fit.cluster_spec()}} or +\code{\link[=fit_xy.cluster_spec]{fit_xy.cluster_spec()}} .} + +\item{new_data}{A data frame or matrix.} + +\item{...}{Not currently used.} +} +\description{ +\code{augment()} will add column(s) for predictions to the given data. +} +\details{ +For partition models, a \code{.pred_cluster} column is added. +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + augment(new_data = mtcars) +} diff --git a/man/avg_silhouette.Rd b/man/avg_silhouette.Rd index f044428e..604a2d33 100644 --- a/man/avg_silhouette.Rd +++ b/man/avg_silhouette.Rd @@ -1,55 +1,55 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-silhouette.R -\name{avg_silhouette} -\alias{avg_silhouette} -\alias{avg_silhouette.cluster_fit} -\alias{avg_silhouette.workflow} -\alias{avg_silhouette_vec} -\title{Measures average silhouette across all observations} -\usage{ -avg_silhouette(object, ...) - -\method{avg_silhouette}{cluster_fit}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...) - -\method{avg_silhouette}{workflow}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...) - -avg_silhouette_vec( - object, - new_data = NULL, - dists = NULL, - dist_fun = Rfast::Dist, - ... -) -} -\arguments{ -\item{object}{A fitted kmeans tidyclust model} - -\item{...}{Other arguments passed to methods.} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.} - -\item{dist_fun}{A function for calculating distances between observations. -Defaults to Euclidean distance on processed data.} -} -\value{ -A double; the average silhouette. -} -\description{ -Measures average silhouette across all observations -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -dists <- mtcars \%>\% - as.matrix() \%>\% - dist() - -avg_silhouette(kmeans_fit, dists = dists) - -avg_silhouette_vec(kmeans_fit, dists = dists) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-silhouette.R +\name{avg_silhouette} +\alias{avg_silhouette} +\alias{avg_silhouette.cluster_fit} +\alias{avg_silhouette.workflow} +\alias{avg_silhouette_vec} +\title{Measures average silhouette across all observations} +\usage{ +avg_silhouette(object, ...) + +\method{avg_silhouette}{cluster_fit}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...) + +\method{avg_silhouette}{workflow}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...) + +avg_silhouette_vec( + object, + new_data = NULL, + dists = NULL, + dist_fun = Rfast::Dist, + ... +) +} +\arguments{ +\item{object}{A fitted kmeans tidyclust model} + +\item{...}{Other arguments passed to methods.} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.} + +\item{dist_fun}{A function for calculating distances between observations. +Defaults to Euclidean distance on processed data.} +} +\value{ +A double; the average silhouette. +} +\description{ +Measures average silhouette across all observations +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +dists <- mtcars \%>\% + as.matrix() \%>\% + dist() + +avg_silhouette(kmeans_fit, dists = dists) + +avg_silhouette_vec(kmeans_fit, dists = dists) +} diff --git a/man/extract_centroids.Rd b/man/extract_centroids.Rd index 330f0909..8a1d03c7 100644 --- a/man/extract_centroids.Rd +++ b/man/extract_centroids.Rd @@ -1,26 +1,26 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract_characterization.R -\name{extract_centroids} -\alias{extract_centroids} -\title{Extract clusters from model} -\usage{ -extract_centroids(object, ...) -} -\arguments{ -\item{object}{An cluster_spec object.} - -\item{...}{Other arguments passed to methods.} -} -\description{ -Extract clusters from model -} -\examples{ -set.seed(1234) -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - extract_centroids() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_characterization.R +\name{extract_centroids} +\alias{extract_centroids} +\title{Extract clusters from model} +\usage{ +extract_centroids(object, ...) +} +\arguments{ +\item{object}{An cluster_spec object.} + +\item{...}{Other arguments passed to methods.} +} +\description{ +Extract clusters from model +} +\examples{ +set.seed(1234) +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + extract_centroids() +} diff --git a/man/extract_cluster_assignment.Rd b/man/extract_cluster_assignment.Rd index 1d8eccbd..29fae8b8 100644 --- a/man/extract_cluster_assignment.Rd +++ b/man/extract_cluster_assignment.Rd @@ -1,25 +1,25 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract_assignment.R -\name{extract_cluster_assignment} -\alias{extract_cluster_assignment} -\title{Extract cluster assignments from model} -\usage{ -extract_cluster_assignment(object, ...) -} -\arguments{ -\item{object}{An cluster_spec object.} - -\item{...}{Other arguments passed to methods.} -} -\description{ -Extract cluster assignments from model -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - extract_cluster_assignment() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_assignment.R +\name{extract_cluster_assignment} +\alias{extract_cluster_assignment} +\title{Extract cluster assignments from model} +\usage{ +extract_cluster_assignment(object, ...) +} +\arguments{ +\item{object}{An cluster_spec object.} + +\item{...}{Other arguments passed to methods.} +} +\description{ +Extract cluster assignments from model +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + extract_cluster_assignment() +} diff --git a/man/extract_fit_summary.Rd b/man/extract_fit_summary.Rd index 56788562..9df27b2e 100644 --- a/man/extract_fit_summary.Rd +++ b/man/extract_fit_summary.Rd @@ -1,28 +1,28 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/extract_summary.R -\name{extract_fit_summary} -\alias{extract_fit_summary} -\title{S3 method to get fitted model summary info depending on engine} -\usage{ -extract_fit_summary(object, ...) -} -\arguments{ -\item{object}{a fitted cluster_spec object} - -\item{...}{other arguments passed to methods} -} -\value{ -A list with various summary elements -} -\description{ -S3 method to get fitted model summary info depending on engine -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - extract_fit_summary() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_summary.R +\name{extract_fit_summary} +\alias{extract_fit_summary} +\title{S3 method to get fitted model summary info depending on engine} +\usage{ +extract_fit_summary(object, ...) +} +\arguments{ +\item{object}{a fitted cluster_spec object} + +\item{...}{other arguments passed to methods} +} +\value{ +A list with various summary elements +} +\description{ +S3 method to get fitted model summary info depending on engine +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + extract_fit_summary() +} diff --git a/man/figures/logo.svg b/man/figures/logo.svg index 045b23b0..6129b854 100644 --- a/man/figures/logo.svg +++ b/man/figures/logo.svg @@ -1,646 +1,646 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/man/finalize_model_tidyclust.Rd b/man/finalize_model_tidyclust.Rd index 443f9e6f..c5281b3c 100644 --- a/man/finalize_model_tidyclust.Rd +++ b/man/finalize_model_tidyclust.Rd @@ -1,35 +1,35 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/finalize.R -\name{finalize_model_tidyclust} -\alias{finalize_model_tidyclust} -\alias{finalize_workflow_tidyclust} -\title{Splice final parameters into objects} -\usage{ -finalize_model_tidyclust(x, parameters) - -finalize_workflow_tidyclust(x, parameters) -} -\arguments{ -\item{x}{A recipe, \code{parsnip} model specification, or workflow.} - -\item{parameters}{A list or 1-row tibble of parameter values. Note that the -column names of the tibble should be the \code{id} fields attached to \code{tune()}. -For example, in the \code{Examples} section below, the model has \code{tune("K")}. In -this case, the parameter tibble should be "K" and not "neighbors".} -} -\value{ -An updated version of \code{x}. -} -\description{ -The \verb{finalize_*} functions take a list or tibble of tuning parameter values and -update objects with those values. -} -\examples{ -kmeans_spec <- k_means(num_clusters = tune()) - -best_params <- data.frame(num_clusters = 5) -best_params - -kmeans_spec -finalize_model_tidyclust(kmeans_spec, best_params) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/finalize.R +\name{finalize_model_tidyclust} +\alias{finalize_model_tidyclust} +\alias{finalize_workflow_tidyclust} +\title{Splice final parameters into objects} +\usage{ +finalize_model_tidyclust(x, parameters) + +finalize_workflow_tidyclust(x, parameters) +} +\arguments{ +\item{x}{A recipe, \code{parsnip} model specification, or workflow.} + +\item{parameters}{A list or 1-row tibble of parameter values. Note that the +column names of the tibble should be the \code{id} fields attached to \code{tune()}. +For example, in the \code{Examples} section below, the model has \code{tune("K")}. In +this case, the parameter tibble should be "K" and not "neighbors".} +} +\value{ +An updated version of \code{x}. +} +\description{ +The \verb{finalize_*} functions take a list or tibble of tuning parameter values and +update objects with those values. +} +\examples{ +kmeans_spec <- k_means(num_clusters = tune()) + +best_params <- data.frame(num_clusters = 5) +best_params + +kmeans_spec +finalize_model_tidyclust(kmeans_spec, best_params) +} diff --git a/man/fit.Rd b/man/fit.Rd index 3437b3b2..c9ecc3da 100644 --- a/man/fit.Rd +++ b/man/fit.Rd @@ -1,105 +1,105 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fit.R -\name{fit.cluster_spec} -\alias{fit.cluster_spec} -\alias{fit_xy.cluster_spec} -\title{Fit a Model Specification to a Data Set} -\usage{ -\method{fit}{cluster_spec}(object, formula, data, control = control_cluster(), ...) - -\method{fit_xy}{cluster_spec}(object, x, case_weights = NULL, control = control_cluster(), ...) -} -\arguments{ -\item{object}{An object of class \code{cluster_spec} that has a chosen engine (via -\code{\link[=set_engine]{set_engine()}}).} - -\item{formula}{An object of class \code{formula} (or one that can be coerced to -that class): a symbolic description of the model to be fitted.} - -\item{data}{Optional, depending on the interface (see Details below). A data -frame containing all relevant variables (e.g. predictors, case weights, -etc). Note: when needed, a \emph{named argument} should be used.} - -\item{control}{A named list with elements \code{verbosity} and \code{catch}. See -\code{\link[=control_cluster]{control_cluster()}}.} - -\item{...}{Not currently used; values passed here will be ignored. Other -options required to fit the model should be passed using -\code{set_engine()}.} - -\item{x}{A matrix, sparse matrix, or data frame of predictors. Only some -models have support for sparse matrix input. See -\code{tidyclust::get_encoding_tidyclust()} for details. \code{x} should have column names.} - -\item{case_weights}{An optional classed vector of numeric case weights. This -must return \code{TRUE} when \code{\link[hardhat:is_case_weights]{hardhat::is_case_weights()}} is run on it. See -\code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}} and \code{\link[hardhat:importance_weights]{hardhat::importance_weights()}} for -examples.} -} -\value{ -A \code{cluster_fit} object that contains several elements: -\itemize{ -\item \code{spec}: The model specification object (\code{object} in the -call to \code{fit}) -\item \code{fit}: when the model is executed without error, this is the -model object. Otherwise, it is a \code{try-error} -object with the error message. -\item \code{preproc}: any objects needed to convert between a formula and -non-formula interface -(such as the \code{terms} object) -} -The return value will also have a class related to the fitted model (e.g. -\code{"_kmeans"}) before the base class of \code{"cluster_fit"}. -} -\description{ -\code{fit()} and \code{fit_xy()} take a model specification, translate_tidyclust the -required code by substituting arguments, and execute the model fit routine. -} -\details{ -\code{fit()} and \code{fit_xy()} substitute the current arguments in the -model specification into the computational engine's code, check them for -validity, then fit the model using the data and the engine-specific code. -Different model functions have different interfaces (e.g. formula or -\code{x}/\code{y}) and these functions translate_tidyclust between the interface used -when \code{fit()} or \code{fit_xy()} was invoked and the one required by the -underlying model. - -When possible, these functions attempt to avoid making copies of the data. -For example, if the underlying model uses a formula and \code{fit()} is invoked, -the original data are references when the model is fit. However, if the -underlying model uses something else, such as \code{x}/\code{y}, the formula is -evaluated and the data are converted to the required format. In this case, -any calls in the resulting model objects reference the temporary objects -used to fit the model. - -If the model engine has not been set, the model's default engine will be -used (as discussed on each model page). If the \code{verbosity} option of -\code{\link[=control_cluster]{control_cluster()}} is greater than zero, a warning will be produced. - -If you would like to use an alternative method for generating contrasts -when supplying a formula to \code{fit()}, set the global option \code{contrasts} to -your preferred method. For example, you might set it to: \code{options(contrasts = c(unordered = "contr.helmert", ordered = "contr.poly"))}. See the help -page for \code{\link[stats:contrast]{stats::contr.treatment()}} for more possible contrast types. -} -\examples{ -library(dplyr) - -kmeans_mod <- k_means(num_clusters = 5) - -using_formula <- - kmeans_mod \%>\% - set_engine("stats") \%>\% - fit(~., data = mtcars) - -using_x <- - kmeans_mod \%>\% - set_engine("stats") \%>\% - fit_xy(x = mtcars) - -using_formula -using_x -} -\seealso{ -\code{\link[=set_engine]{set_engine()}}, \code{\link[=control_cluster]{control_cluster()}}, \code{cluster_spec}, -\code{cluster_fit} -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit.R +\name{fit.cluster_spec} +\alias{fit.cluster_spec} +\alias{fit_xy.cluster_spec} +\title{Fit a Model Specification to a Data Set} +\usage{ +\method{fit}{cluster_spec}(object, formula, data, control = control_cluster(), ...) + +\method{fit_xy}{cluster_spec}(object, x, case_weights = NULL, control = control_cluster(), ...) +} +\arguments{ +\item{object}{An object of class \code{cluster_spec} that has a chosen engine (via +\code{\link[=set_engine]{set_engine()}}).} + +\item{formula}{An object of class \code{formula} (or one that can be coerced to +that class): a symbolic description of the model to be fitted.} + +\item{data}{Optional, depending on the interface (see Details below). A data +frame containing all relevant variables (e.g. predictors, case weights, +etc). Note: when needed, a \emph{named argument} should be used.} + +\item{control}{A named list with elements \code{verbosity} and \code{catch}. See +\code{\link[=control_cluster]{control_cluster()}}.} + +\item{...}{Not currently used; values passed here will be ignored. Other +options required to fit the model should be passed using +\code{set_engine()}.} + +\item{x}{A matrix, sparse matrix, or data frame of predictors. Only some +models have support for sparse matrix input. See +\code{tidyclust::get_encoding_tidyclust()} for details. \code{x} should have column names.} + +\item{case_weights}{An optional classed vector of numeric case weights. This +must return \code{TRUE} when \code{\link[hardhat:is_case_weights]{hardhat::is_case_weights()}} is run on it. See +\code{\link[hardhat:frequency_weights]{hardhat::frequency_weights()}} and \code{\link[hardhat:importance_weights]{hardhat::importance_weights()}} for +examples.} +} +\value{ +A \code{cluster_fit} object that contains several elements: +\itemize{ +\item \code{spec}: The model specification object (\code{object} in the +call to \code{fit}) +\item \code{fit}: when the model is executed without error, this is the +model object. Otherwise, it is a \code{try-error} +object with the error message. +\item \code{preproc}: any objects needed to convert between a formula and +non-formula interface +(such as the \code{terms} object) +} +The return value will also have a class related to the fitted model (e.g. +\code{"_kmeans"}) before the base class of \code{"cluster_fit"}. +} +\description{ +\code{fit()} and \code{fit_xy()} take a model specification, translate_tidyclust the +required code by substituting arguments, and execute the model fit routine. +} +\details{ +\code{fit()} and \code{fit_xy()} substitute the current arguments in the +model specification into the computational engine's code, check them for +validity, then fit the model using the data and the engine-specific code. +Different model functions have different interfaces (e.g. formula or +\code{x}/\code{y}) and these functions translate_tidyclust between the interface used +when \code{fit()} or \code{fit_xy()} was invoked and the one required by the +underlying model. + +When possible, these functions attempt to avoid making copies of the data. +For example, if the underlying model uses a formula and \code{fit()} is invoked, +the original data are references when the model is fit. However, if the +underlying model uses something else, such as \code{x}/\code{y}, the formula is +evaluated and the data are converted to the required format. In this case, +any calls in the resulting model objects reference the temporary objects +used to fit the model. + +If the model engine has not been set, the model's default engine will be +used (as discussed on each model page). If the \code{verbosity} option of +\code{\link[=control_cluster]{control_cluster()}} is greater than zero, a warning will be produced. + +If you would like to use an alternative method for generating contrasts +when supplying a formula to \code{fit()}, set the global option \code{contrasts} to +your preferred method. For example, you might set it to: \code{options(contrasts = c(unordered = "contr.helmert", ordered = "contr.poly"))}. See the help +page for \code{\link[stats:contrast]{stats::contr.treatment()}} for more possible contrast types. +} +\examples{ +library(dplyr) + +kmeans_mod <- k_means(num_clusters = 5) + +using_formula <- + kmeans_mod \%>\% + set_engine("stats") \%>\% + fit(~., data = mtcars) + +using_x <- + kmeans_mod \%>\% + set_engine("stats") \%>\% + fit_xy(x = mtcars) + +using_formula +using_x +} +\seealso{ +\code{\link[=set_engine]{set_engine()}}, \code{\link[=control_cluster]{control_cluster()}}, \code{cluster_spec}, +\code{cluster_fit} +} diff --git a/man/hclust_fit.Rd b/man/hclust_fit.Rd index ec40ab2a..79361804 100644 --- a/man/hclust_fit.Rd +++ b/man/hclust_fit.Rd @@ -1,36 +1,36 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/hier_clust.R -\name{hclust_fit} -\alias{hclust_fit} -\title{Simple Wrapper around hclust function} -\usage{ -hclust_fit( - x, - k = NULL, - cut_height = NULL, - linkage_method = NULL, - dist_fun = Rfast::Dist -) -} -\arguments{ -\item{x}{matrix or data frame} - -\item{k}{the number of clusters} - -\item{linkage_method}{the agglomeration method to be used. This should be (an -unambiguous abbreviation of) one of \code{"ward.D"}, \code{"ward.D2"}, \code{"single"}, -\code{"complete"}, \code{"average"} (= UPGMA), \code{"mcquitty"} (= WPGMA), \code{"median"} -(= WPGMC) or \code{"centroid"} (= UPGMC).} - -\item{dist_fun}{A distance function to use} - -\item{h}{the height to cut the dendrogram} -} -\value{ -A dendrogram -} -\description{ -This wrapper prepares the data into a distance matrix to send to -\code{stats::hclust} and retains the parameters \code{k} or \code{h} as an attribute. -} -\keyword{internal} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/hier_clust.R +\name{hclust_fit} +\alias{hclust_fit} +\title{Simple Wrapper around hclust function} +\usage{ +hclust_fit( + x, + k = NULL, + cut_height = NULL, + linkage_method = NULL, + dist_fun = Rfast::Dist +) +} +\arguments{ +\item{x}{matrix or data frame} + +\item{k}{the number of clusters} + +\item{linkage_method}{the agglomeration method to be used. This should be (an +unambiguous abbreviation of) one of \code{"ward.D"}, \code{"ward.D2"}, \code{"single"}, +\code{"complete"}, \code{"average"} (= UPGMA), \code{"mcquitty"} (= WPGMA), \code{"median"} +(= WPGMC) or \code{"centroid"} (= UPGMC).} + +\item{dist_fun}{A distance function to use} + +\item{h}{the height to cut the dendrogram} +} +\value{ +A dendrogram +} +\description{ +This wrapper prepares the data into a distance matrix to send to +\code{stats::hclust} and retains the parameters \code{k} or \code{h} as an attribute. +} +\keyword{internal} diff --git a/man/hier_clust.Rd b/man/hier_clust.Rd index a6172aee..4c3b7d74 100644 --- a/man/hier_clust.Rd +++ b/man/hier_clust.Rd @@ -1,43 +1,43 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/hier_clust.R -\name{hier_clust} -\alias{hier_clust} -\title{Hierarchical (Agglomerative) Clustering} -\usage{ -hier_clust( - mode = "partition", - engine = "stats", - k = NULL, - h = NULL, - linkage_method = "complete" -) -} -\arguments{ -\item{mode}{A single character string for the type of model. -The only possible value for this model is "partition".} - -\item{engine}{A single character string specifying what computational engine -to use for fitting. Possible engines are listed below. The default for this -model is \code{"stats"}.} - -\item{k}{Positive integer, number of clusters in model (optional).} - -\item{h}{Positive double, height at which to cut dendrogram to obtain cluster -assignments (only used if \code{k} is \code{NULL})} - -\item{linkage_method}{the agglomeration method to be used. This should be (an -unambiguous abbreviation of) one of \code{"ward.D"}, \code{"ward.D2"}, \code{"single"}, -\code{"complete"}, \code{"average"} (= UPGMA), \code{"mcquitty"} (= WPGMA), \code{"median"} -(= WPGMC) or \code{"centroid"} (= UPGMC).} - -\item{dist_fun}{A distance function to use} -} -\description{ -\code{hier_clust()} defines a model that fits clusters based on a distance-based -dendrogram -} -\examples{ -# show_engines("hier_clust") - -hier_clust() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/hier_clust.R +\name{hier_clust} +\alias{hier_clust} +\title{Hierarchical (Agglomerative) Clustering} +\usage{ +hier_clust( + mode = "partition", + engine = "stats", + k = NULL, + h = NULL, + linkage_method = "complete" +) +} +\arguments{ +\item{mode}{A single character string for the type of model. +The only possible value for this model is "partition".} + +\item{engine}{A single character string specifying what computational engine +to use for fitting. Possible engines are listed below. The default for this +model is \code{"stats"}.} + +\item{k}{Positive integer, number of clusters in model (optional).} + +\item{h}{Positive double, height at which to cut dendrogram to obtain cluster +assignments (only used if \code{k} is \code{NULL})} + +\item{linkage_method}{the agglomeration method to be used. This should be (an +unambiguous abbreviation of) one of \code{"ward.D"}, \code{"ward.D2"}, \code{"single"}, +\code{"complete"}, \code{"average"} (= UPGMA), \code{"mcquitty"} (= WPGMA), \code{"median"} +(= WPGMC) or \code{"centroid"} (= UPGMC).} + +\item{dist_fun}{A distance function to use} +} +\description{ +\code{hier_clust()} defines a model that fits clusters based on a distance-based +dendrogram +} +\examples{ +# show_engines("hier_clust") + +hier_clust() +} diff --git a/man/k_means.Rd b/man/k_means.Rd index 8ccfae5c..2600e65f 100644 --- a/man/k_means.Rd +++ b/man/k_means.Rd @@ -1,27 +1,27 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/k_means.R -\name{k_means} -\alias{k_means} -\title{K-Means} -\usage{ -k_means(mode = "partition", engine = "stats", num_clusters = NULL) -} -\arguments{ -\item{mode}{A single character string for the type of model. -The only possible value for this model is "partition".} - -\item{engine}{A single character string specifying what computational engine -to use for fitting. Possible engines are listed below. The default for this -model is \code{"stats"}.} - -\item{num_clusters}{Positive integer, number of clusters in model.} -} -\description{ -\code{k_means()} defines a model that fits clusters based on distances to a number -of centers. -} -\examples{ -# show_engines("k_means") - -k_means() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/k_means.R +\name{k_means} +\alias{k_means} +\title{K-Means} +\usage{ +k_means(mode = "partition", engine = "stats", num_clusters = NULL) +} +\arguments{ +\item{mode}{A single character string for the type of model. +The only possible value for this model is "partition".} + +\item{engine}{A single character string specifying what computational engine +to use for fitting. Possible engines are listed below. The default for this +model is \code{"stats"}.} + +\item{num_clusters}{Positive integer, number of clusters in model.} +} +\description{ +\code{k_means()} defines a model that fits clusters based on distances to a number +of centers. +} +\examples{ +# show_engines("k_means") + +k_means() +} diff --git a/man/num_clusters.Rd b/man/num_clusters.Rd index c8129e2c..47a5dba3 100644 --- a/man/num_clusters.Rd +++ b/man/num_clusters.Rd @@ -1,24 +1,24 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/dials.R -\name{num_clusters} -\alias{num_clusters} -\title{Number of Clusters} -\usage{ -num_clusters(range = c(1L, 10L), trans = NULL) -} -\arguments{ -\item{range}{A two-element vector holding the \emph{defaults} for the smallest and -largest possible values, respectively. If a transformation is specified, -these values should be in the \emph{transformed units}.} - -\item{trans}{A \code{trans} object from the \code{scales} package, such as -\code{scales::log10_trans()} or \code{scales::reciprocal_trans()}. If not provided, -the default is used which matches the units used in \code{range}. If no -transformation, \code{NULL}.} -} -\description{ -Number of Clusters -} -\examples{ -num_clusters() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dials.R +\name{num_clusters} +\alias{num_clusters} +\title{Number of Clusters} +\usage{ +num_clusters(range = c(1L, 10L), trans = NULL) +} +\arguments{ +\item{range}{A two-element vector holding the \emph{defaults} for the smallest and +largest possible values, respectively. If a transformation is specified, +these values should be in the \emph{transformed units}.} + +\item{trans}{A \code{trans} object from the \code{scales} package, such as +\code{scales::log10_trans()} or \code{scales::reciprocal_trans()}. If not provided, +the default is used which matches the units used in \code{range}. If no +transformation, \code{NULL}.} +} +\description{ +Number of Clusters +} +\examples{ +num_clusters() +} diff --git a/man/predict.cluster_fit.Rd b/man/predict.cluster_fit.Rd index dd8ff763..a09c2a66 100644 --- a/man/predict.cluster_fit.Rd +++ b/man/predict.cluster_fit.Rd @@ -1,68 +1,68 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict.R, R/predict_raw.R -\name{predict.cluster_fit} -\alias{predict.cluster_fit} -\alias{predict_raw.cluster_fit} -\title{Model predictions} -\usage{ -\method{predict}{cluster_fit}(object, new_data, type = NULL, opts = list(), ...) - -\method{predict_raw}{cluster_fit}(object, new_data, opts = list(), ...) -} -\arguments{ -\item{object}{An object of class \code{cluster_fit}} - -\item{new_data}{A rectangular data object, such as a data frame.} - -\item{type}{A single character value or \code{NULL}. Possible values -are "cluster", or "raw". When \code{NULL}, \code{predict()} will choose an -appropriate value based on the model's mode.} - -\item{opts}{A list of optional arguments to the underlying -predict function that will be used when \code{type = "raw"}. The -list should not include options for the model object or the -new data being predicted.} - -\item{...}{Arguments to the underlying model's prediction -function cannot be passed here (see \code{opts}).} -} -\value{ -With the exception of \code{type = "raw"}, the results of -\code{predict.cluster_fit()} will be a tibble as many rows in the output -as there are rows in \code{new_data} and the column names will be -predictable. - -For clustering results the tibble will have a \code{.pred_cluster} column. - -Using \code{type = "raw"} with \code{predict.cluster_fit()} will return -the unadulterated results of the prediction function. - -When the model fit failed and the error was captured, the -\code{predict()} function will return the same structure as above but -filled with missing values. This does not currently work for -multivariate models. -} -\description{ -Apply a model to create different types of predictions. -\code{predict()} can be used for all types of models and uses the -"type" argument for more specificity. -} -\details{ -If "type" is not supplied to \code{predict()}, then a choice -is made: -\itemize{ -\item \code{type = "cluster"} for clustering models -} - -\code{predict()} is designed to provide a tidy result (see "Value" -section below) in a tibble output format. -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - predict(new_data = mtcars) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/predict.R, R/predict_raw.R +\name{predict.cluster_fit} +\alias{predict.cluster_fit} +\alias{predict_raw.cluster_fit} +\title{Model predictions} +\usage{ +\method{predict}{cluster_fit}(object, new_data, type = NULL, opts = list(), ...) + +\method{predict_raw}{cluster_fit}(object, new_data, opts = list(), ...) +} +\arguments{ +\item{object}{An object of class \code{cluster_fit}} + +\item{new_data}{A rectangular data object, such as a data frame.} + +\item{type}{A single character value or \code{NULL}. Possible values +are "cluster", or "raw". When \code{NULL}, \code{predict()} will choose an +appropriate value based on the model's mode.} + +\item{opts}{A list of optional arguments to the underlying +predict function that will be used when \code{type = "raw"}. The +list should not include options for the model object or the +new data being predicted.} + +\item{...}{Arguments to the underlying model's prediction +function cannot be passed here (see \code{opts}).} +} +\value{ +With the exception of \code{type = "raw"}, the results of +\code{predict.cluster_fit()} will be a tibble as many rows in the output +as there are rows in \code{new_data} and the column names will be +predictable. + +For clustering results the tibble will have a \code{.pred_cluster} column. + +Using \code{type = "raw"} with \code{predict.cluster_fit()} will return +the unadulterated results of the prediction function. + +When the model fit failed and the error was captured, the +\code{predict()} function will return the same structure as above but +filled with missing values. This does not currently work for +multivariate models. +} +\description{ +Apply a model to create different types of predictions. +\code{predict()} can be used for all types of models and uses the +"type" argument for more specificity. +} +\details{ +If "type" is not supplied to \code{predict()}, then a choice +is made: +\itemize{ +\item \code{type = "cluster"} for clustering models +} + +\code{predict()} is designed to provide a tidy result (see "Value" +section below) in a tibble output format. +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + predict(new_data = mtcars) +} diff --git a/man/reexports.Rd b/man/reexports.Rd index 57d77d70..f59baeb2 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,41 +1,41 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R -\docType{import} -\name{reexports} -\alias{reexports} -\alias{\%>\%} -\alias{fit} -\alias{tidy} -\alias{glance} -\alias{augment} -\alias{fit_xy} -\alias{extract_parameter_set_dials} -\alias{tune} -\alias{extract_spec_parsnip} -\alias{min_grid} -\alias{extract_preprocessor} -\alias{extract_fit_parsnip} -\alias{load_pkgs} -\alias{required_pkgs} -\alias{predict_raw} -\alias{set_args} -\alias{set_engine} -\alias{set_mode} -\title{Objects exported from other packages} -\keyword{internal} -\description{ -These objects are imported from other packages. Follow the links -below to see their documentation. - -\describe{ - \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{fit_xy}}, \code{\link[generics]{glance}}, \code{\link[generics]{min_grid}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}} - - \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_parsnip}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_preprocessor}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{tune}}} - - \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} - - \item{parsnip}{\code{\link[parsnip:predict.model_fit]{predict_raw}}, \code{\link[parsnip]{set_args}}, \code{\link[parsnip]{set_engine}}, \code{\link[parsnip:set_args]{set_mode}}} - - \item{tune}{\code{\link[tune]{load_pkgs}}} -}} - +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/reexports.R +\docType{import} +\name{reexports} +\alias{reexports} +\alias{\%>\%} +\alias{fit} +\alias{tidy} +\alias{glance} +\alias{augment} +\alias{fit_xy} +\alias{extract_parameter_set_dials} +\alias{tune} +\alias{extract_spec_parsnip} +\alias{min_grid} +\alias{extract_preprocessor} +\alias{extract_fit_parsnip} +\alias{load_pkgs} +\alias{required_pkgs} +\alias{predict_raw} +\alias{set_args} +\alias{set_engine} +\alias{set_mode} +\title{Objects exported from other packages} +\keyword{internal} +\description{ +These objects are imported from other packages. Follow the links +below to see their documentation. + +\describe{ + \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{fit_xy}}, \code{\link[generics]{glance}}, \code{\link[generics]{min_grid}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}} + + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_parsnip}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_preprocessor}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{tune}}} + + \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} + + \item{parsnip}{\code{\link[parsnip:predict.model_fit]{predict_raw}}, \code{\link[parsnip]{set_args}}, \code{\link[parsnip]{set_engine}}, \code{\link[parsnip:set_args]{set_mode}}} + + \item{tune}{\code{\link[tune]{load_pkgs}}} +}} + diff --git a/man/silhouettes.Rd b/man/silhouettes.Rd index c7368d20..5eea6d14 100644 --- a/man/silhouettes.Rd +++ b/man/silhouettes.Rd @@ -1,36 +1,36 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-silhouette.R -\name{silhouettes} -\alias{silhouettes} -\title{Measures silhouettes between clusters} -\usage{ -silhouettes(object, new_data = NULL, dists = NULL, dist_fun = Rfast::Dist) -} -\arguments{ -\item{object}{A fitted tidyclust model} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.} - -\item{dist_fun}{A function for calculating distances between observations. -Defaults to Euclidean distance on processed data.} -} -\value{ -A tibble giving the silhouettes for each observation. -} -\description{ -Measures silhouettes between clusters -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -dists <- mtcars \%>\% - as.matrix() \%>\% - dist() - -silhouettes(kmeans_fit, dists = dists) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-silhouette.R +\name{silhouettes} +\alias{silhouettes} +\title{Measures silhouettes between clusters} +\usage{ +silhouettes(object, new_data = NULL, dists = NULL, dist_fun = Rfast::Dist) +} +\arguments{ +\item{object}{A fitted tidyclust model} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.} + +\item{dist_fun}{A function for calculating distances between observations. +Defaults to Euclidean distance on processed data.} +} +\value{ +A tibble giving the silhouettes for each observation. +} +\description{ +Measures silhouettes between clusters +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +dists <- mtcars \%>\% + as.matrix() \%>\% + dist() + +silhouettes(kmeans_fit, dists = dists) +} diff --git a/man/sse_ratio.Rd b/man/sse_ratio.Rd index d3df4a89..0482553e 100644 --- a/man/sse_ratio.Rd +++ b/man/sse_ratio.Rd @@ -1,39 +1,39 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-sse.R -\name{sse_ratio} -\alias{sse_ratio} -\alias{sse_ratio.cluster_fit} -\alias{sse_ratio.workflow} -\alias{sse_ratio_vec} -\title{Compute the ratio of the WSS to the total SSE} -\usage{ -sse_ratio(object, ...) - -\method{sse_ratio}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) - -\method{sse_ratio}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) - -sse_ratio_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) -} -\arguments{ -\item{object}{A fitted kmeans tidyclust model} - -\item{...}{Other arguments passed to methods.} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dist_fun}{A function for calculating distances to centroids. Defaults -to Euclidean distance on processed data.} -} -\description{ -Compute the ratio of the WSS to the total SSE -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - sse_ratio() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-sse.R +\name{sse_ratio} +\alias{sse_ratio} +\alias{sse_ratio.cluster_fit} +\alias{sse_ratio.workflow} +\alias{sse_ratio_vec} +\title{Compute the ratio of the WSS to the total SSE} +\usage{ +sse_ratio(object, ...) + +\method{sse_ratio}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) + +\method{sse_ratio}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) + +sse_ratio_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +} +\arguments{ +\item{object}{A fitted kmeans tidyclust model} + +\item{...}{Other arguments passed to methods.} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dist_fun}{A function for calculating distances to centroids. Defaults +to Euclidean distance on processed data.} +} +\description{ +Compute the ratio of the WSS to the total SSE +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + sse_ratio() +} diff --git a/man/tidyclust_update.Rd b/man/tidyclust_update.Rd index 6b635ac7..3485f2bf 100644 --- a/man/tidyclust_update.Rd +++ b/man/tidyclust_update.Rd @@ -1,42 +1,42 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/k_means.R, R/update.R -\name{update.k_means} -\alias{update.k_means} -\alias{tidyclust_update} -\title{Update a cluster specification} -\usage{ -\method{update}{k_means}(object, parameters = NULL, num_clusters = NULL, fresh = FALSE, ...) -} -\arguments{ -\item{object}{A cluster specification.} - -\item{parameters}{A 1-row tibble or named list with \emph{main} -parameters to update. Use \strong{either} \code{parameters} \strong{or} the main arguments -directly when updating. If the main arguments are used, -these will supersede the values in \code{parameters}. Also, using -engine arguments in this object will result in an error.} - -\item{num_clusters}{Positive integer, number of clusters in model.} - -\item{fresh}{A logical for whether the arguments should be -modified in-place or replaced wholesale.} - -\item{...}{Not used for \code{update()}.} -} -\value{ -An updated cluster specification. -} -\description{ -If parameters of a cluster specification need to be modified, \code{update()} can -be used in lieu of recreating the object from scratch. -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) -kmeans_spec -update(kmeans_spec, num_clusters = 1) -update(kmeans_spec, num_clusters = 1, fresh = TRUE) - -param_values <- tibble::tibble(num_clusters = 10) - -kmeans_spec \%>\% update(param_values) -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/k_means.R, R/update.R +\name{update.k_means} +\alias{update.k_means} +\alias{tidyclust_update} +\title{Update a cluster specification} +\usage{ +\method{update}{k_means}(object, parameters = NULL, num_clusters = NULL, fresh = FALSE, ...) +} +\arguments{ +\item{object}{A cluster specification.} + +\item{parameters}{A 1-row tibble or named list with \emph{main} +parameters to update. Use \strong{either} \code{parameters} \strong{or} the main arguments +directly when updating. If the main arguments are used, +these will supersede the values in \code{parameters}. Also, using +engine arguments in this object will result in an error.} + +\item{num_clusters}{Positive integer, number of clusters in model.} + +\item{fresh}{A logical for whether the arguments should be +modified in-place or replaced wholesale.} + +\item{...}{Not used for \code{update()}.} +} +\value{ +An updated cluster specification. +} +\description{ +If parameters of a cluster specification need to be modified, \code{update()} can +be used in lieu of recreating the object from scratch. +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) +kmeans_spec +update(kmeans_spec, num_clusters = 1) +update(kmeans_spec, num_clusters = 1, fresh = TRUE) + +param_values <- tibble::tibble(num_clusters = 10) + +kmeans_spec \%>\% update(param_values) +} diff --git a/man/tot_sse.Rd b/man/tot_sse.Rd index 1adf5983..cf778c00 100644 --- a/man/tot_sse.Rd +++ b/man/tot_sse.Rd @@ -1,42 +1,42 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-sse.R -\name{tot_sse} -\alias{tot_sse} -\alias{tot_sse.cluster_fit} -\alias{tot_sse.workflow} -\alias{tot_sse_vec} -\title{Compute the total sum of squares} -\usage{ -tot_sse(object, ...) - -\method{tot_sse}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) - -\method{tot_sse}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) - -tot_sse_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) -} -\arguments{ -\item{object}{A fitted kmeans tidyclust model} - -\item{...}{Other arguments passed to methods.} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dist_fun}{A function for calculating distances to centroids. Defaults -to Euclidean distance on processed data.} -} -\description{ -Compute the total sum of squares -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - tot_sse() - -kmeans_fit \%>\% - tot_sse_vec() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-sse.R +\name{tot_sse} +\alias{tot_sse} +\alias{tot_sse.cluster_fit} +\alias{tot_sse.workflow} +\alias{tot_sse_vec} +\title{Compute the total sum of squares} +\usage{ +tot_sse(object, ...) + +\method{tot_sse}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) + +\method{tot_sse}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) + +tot_sse_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +} +\arguments{ +\item{object}{A fitted kmeans tidyclust model} + +\item{...}{Other arguments passed to methods.} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dist_fun}{A function for calculating distances to centroids. Defaults +to Euclidean distance on processed data.} +} +\description{ +Compute the total sum of squares +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + tot_sse() + +kmeans_fit \%>\% + tot_sse_vec() +} diff --git a/man/tot_wss.Rd b/man/tot_wss.Rd index 4dcb80de..71665d19 100644 --- a/man/tot_wss.Rd +++ b/man/tot_wss.Rd @@ -1,42 +1,42 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-sse.R -\name{tot_wss} -\alias{tot_wss} -\alias{tot_wss.cluster_fit} -\alias{tot_wss.workflow} -\alias{tot_wss_vec} -\title{Compute the sum of within-cluster SSE} -\usage{ -tot_wss(object, ...) - -\method{tot_wss}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) - -\method{tot_wss}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) - -tot_wss_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) -} -\arguments{ -\item{object}{A fitted kmeans tidyclust model} - -\item{...}{Other arguments passed to methods.} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dist_fun}{A function for calculating distances to centroids. Defaults -to Euclidean distance on processed data.} -} -\description{ -Compute the sum of within-cluster SSE -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - tot_wss() - -kmeans_fit \%>\% - tot_wss_vec() -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-sse.R +\name{tot_wss} +\alias{tot_wss} +\alias{tot_wss.cluster_fit} +\alias{tot_wss.workflow} +\alias{tot_wss_vec} +\title{Compute the sum of within-cluster SSE} +\usage{ +tot_wss(object, ...) + +\method{tot_wss}{cluster_fit}(object, new_data = NULL, dist_fun = NULL, ...) + +\method{tot_wss}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) + +tot_wss_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +} +\arguments{ +\item{object}{A fitted kmeans tidyclust model} + +\item{...}{Other arguments passed to methods.} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dist_fun}{A function for calculating distances to centroids. Defaults +to Euclidean distance on processed data.} +} +\description{ +Compute the sum of within-cluster SSE +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + tot_wss() + +kmeans_fit \%>\% + tot_wss_vec() +} diff --git a/man/translate_tidyclust.Rd b/man/translate_tidyclust.Rd index ed2f796c..28d9f0dc 100644 --- a/man/translate_tidyclust.Rd +++ b/man/translate_tidyclust.Rd @@ -1,43 +1,43 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/translate.R -\name{translate_tidyclust} -\alias{translate_tidyclust} -\alias{translate_tidyclust.default} -\title{Resolve a Model Specification for a Computational Engine} -\usage{ -translate_tidyclust(x, ...) - -\method{translate_tidyclust}{default}(x, engine = x$engine, ...) -} -\arguments{ -\item{x}{A model specification.} - -\item{...}{Not currently used.} - -\item{engine}{The computational engine for the model (see \code{?set_engine}).} -} -\description{ -\code{translate_tidyclust()} will translate_tidyclust a model specification into a code -object that is specific to a particular engine (e.g. R package). -It translate_tidyclusts generic parameters to their counterparts. -} -\details{ -\code{translate_tidyclust()} produces a \emph{template} call that lacks the specific -argument values (such as \code{data}, etc). These are filled in once -\code{fit()} is called with the specifics of the data for the model. -The call may also include \code{tune()} arguments if these are in -the specification. To handle the \code{tune()} arguments, you need to use the -\href{https://tune.tidymodels.org/}{tune package}. For more information -see \url{https://www.tidymodels.org/start/tuning/} - -It does contain the resolved argument names that are specific to -the model fitting function/engine. - -This function can be useful when you need to understand how -\code{tidyclust} goes from a generic model specific to a model fitting -function. - -\strong{Note}: this function is used internally and users should only use it -to understand what the underlying syntax would be. It should not be used -to modify the cluster specification. -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/translate.R +\name{translate_tidyclust} +\alias{translate_tidyclust} +\alias{translate_tidyclust.default} +\title{Resolve a Model Specification for a Computational Engine} +\usage{ +translate_tidyclust(x, ...) + +\method{translate_tidyclust}{default}(x, engine = x$engine, ...) +} +\arguments{ +\item{x}{A model specification.} + +\item{...}{Not currently used.} + +\item{engine}{The computational engine for the model (see \code{?set_engine}).} +} +\description{ +\code{translate_tidyclust()} will translate_tidyclust a model specification into a code +object that is specific to a particular engine (e.g. R package). +It translate_tidyclusts generic parameters to their counterparts. +} +\details{ +\code{translate_tidyclust()} produces a \emph{template} call that lacks the specific +argument values (such as \code{data}, etc). These are filled in once +\code{fit()} is called with the specifics of the data for the model. +The call may also include \code{tune()} arguments if these are in +the specification. To handle the \code{tune()} arguments, you need to use the +\href{https://tune.tidymodels.org/}{tune package}. For more information +see \url{https://www.tidymodels.org/start/tuning/} + +It does contain the resolved argument names that are specific to +the model fitting function/engine. + +This function can be useful when you need to understand how +\code{tidyclust} goes from a generic model specific to a model fitting +function. + +\strong{Note}: this function is used internally and users should only use it +to understand what the underlying syntax would be. It should not be used +to modify the cluster specification. +} diff --git a/man/within_cluster_sse.Rd b/man/within_cluster_sse.Rd index 1e425a4e..090b35ea 100644 --- a/man/within_cluster_sse.Rd +++ b/man/within_cluster_sse.Rd @@ -1,33 +1,33 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metric-sse.R -\name{within_cluster_sse} -\alias{within_cluster_sse} -\title{Calculates Sum of Squared Error in each cluster} -\usage{ -within_cluster_sse(object, new_data = NULL, dist_fun = Rfast::dista) -} -\arguments{ -\item{object}{A fitted kmeans tidyclust model} - -\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} - -\item{dist_fun}{A function for calculating distances to centroids. Defaults -to Euclidean distance on processed data.} -} -\value{ -A tibble with two columns, the cluster name and the SSE within that -cluster. -} -\description{ -Calculates Sum of Squared Error in each cluster -} -\examples{ -kmeans_spec <- k_means(num_clusters = 5) \%>\% - set_engine("stats") - -kmeans_fit <- fit(kmeans_spec, ~., mtcars) - -kmeans_fit \%>\% - within_cluster_sse() - -} +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/metric-sse.R +\name{within_cluster_sse} +\alias{within_cluster_sse} +\title{Calculates Sum of Squared Error in each cluster} +\usage{ +within_cluster_sse(object, new_data = NULL, dist_fun = Rfast::dista) +} +\arguments{ +\item{object}{A fitted kmeans tidyclust model} + +\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.} + +\item{dist_fun}{A function for calculating distances to centroids. Defaults +to Euclidean distance on processed data.} +} +\value{ +A tibble with two columns, the cluster name and the SSE within that +cluster. +} +\description{ +Calculates Sum of Squared Error in each cluster +} +\examples{ +kmeans_spec <- k_means(num_clusters = 5) \%>\% + set_engine("stats") + +kmeans_fit <- fit(kmeans_spec, ~., mtcars) + +kmeans_fit \%>\% + within_cluster_sse() + +} diff --git a/tests/testthat/_snaps/arguments.md b/tests/testthat/_snaps/arguments.md index 5f699f73..a3c59270 100644 --- a/tests/testthat/_snaps/arguments.md +++ b/tests/testthat/_snaps/arguments.md @@ -1,35 +1,35 @@ -# pipe arguments - - Code - k_means() %>% set_args() - Error - Please pass at least one named argument. - -# pipe engine - - Code - k_means() %>% set_mode() - Error - Available modes for model type k_means are: 'unknown', 'partition' - ---- - - Code - k_means() %>% set_mode(2) - Error - '2' is not a known mode for model `k_means()`. - ---- - - Code - k_means() %>% set_mode("haberdashery") - Error - 'haberdashery' is not a known mode for model `k_means()`. - -# can't set a mode that isn't allowed by the model spec - - Code - set_mode(k_means(), "classification") - Error - 'classification' is not a known mode for model `k_means()`. - +# pipe arguments + + Code + k_means() %>% set_args() + Error + Please pass at least one named argument. + +# pipe engine + + Code + k_means() %>% set_mode() + Error + Available modes for model type k_means are: 'unknown', 'partition' + +--- + + Code + k_means() %>% set_mode(2) + Error + '2' is not a known mode for model `k_means()`. + +--- + + Code + k_means() %>% set_mode("haberdashery") + Error + 'haberdashery' is not a known mode for model `k_means()`. + +# can't set a mode that isn't allowed by the model spec + + Code + set_mode(k_means(), "classification") + Error + 'classification' is not a known mode for model `k_means()`. + diff --git a/tests/testthat/_snaps/hier_clust.md b/tests/testthat/_snaps/hier_clust.md index f96176c8..60904caf 100644 --- a/tests/testthat/_snaps/hier_clust.md +++ b/tests/testthat/_snaps/hier_clust.md @@ -1,59 +1,59 @@ -# bad input - - Code - hier_clust(mode = "bogus") - Error - 'bogus' is not a known mode for model `hier_clust()`. - ---- - - Code - bt <- hier_clust(method = "bogus") %>% set_engine("stats") - Error - unused argument (method = "bogus") - Code - fit(bt, mpg ~ ., mtcars) - Error - object 'bt' not found - ---- - - Code - translate_tidyclust(hier_clust(), engine = NULL) - Error - Please set an engine. - ---- - - Code - translate_tidyclust(hier_clust(formula = ~x)) - Error - unused argument (formula = ~x) - -# printing - - Code - hier_clust() - Output - Hierarchical Clustering Specification (partition) - - Main Arguments: - linkage_method = complete - - Computational engine: stats - - ---- - - Code - hier_clust(k = 10) - Output - Hierarchical Clustering Specification (partition) - - Main Arguments: - k = 10 - linkage_method = complete - - Computational engine: stats - - +# bad input + + Code + hier_clust(mode = "bogus") + Error + 'bogus' is not a known mode for model `hier_clust()`. + +--- + + Code + bt <- hier_clust(method = "bogus") %>% set_engine("stats") + Error + unused argument (method = "bogus") + Code + fit(bt, mpg ~ ., mtcars) + Error + object 'bt' not found + +--- + + Code + translate_tidyclust(hier_clust(), engine = NULL) + Error + Please set an engine. + +--- + + Code + translate_tidyclust(hier_clust(formula = ~x)) + Error + unused argument (formula = ~x) + +# printing + + Code + hier_clust() + Output + Hierarchical Clustering Specification (partition) + + Main Arguments: + linkage_method = complete + + Computational engine: stats + + +--- + + Code + hier_clust(k = 10) + Output + Hierarchical Clustering Specification (partition) + + Main Arguments: + k = 10 + linkage_method = complete + + Computational engine: stats + + diff --git a/tests/testthat/_snaps/hier_clust.new.md b/tests/testthat/_snaps/hier_clust.new.md index 832f4dcd..fa7c44fe 100644 --- a/tests/testthat/_snaps/hier_clust.new.md +++ b/tests/testthat/_snaps/hier_clust.new.md @@ -1,59 +1,59 @@ -# bad input - - Code - hier_clust(mode = "bogus") - Error - 'bogus' is not a known mode for model `hier_clust()`. - ---- - - Code - bt <- hier_clust(linkage_method = "bogus") %>% set_engine("stats") - Error - `object` should have class 'model_spec'. - Code - fit(bt, mpg ~ ., mtcars) - Error - object 'bt' not found - ---- - - Code - translate_tidyclust(hier_clust(), engine = NULL) - Error - Please set an engine. - ---- - - Code - translate_tidyclust(hier_clust(formula = ~x)) - Error - unused argument (formula = ~x) - -# printing - - Code - hier_clust() - Output - Hierarchical Clustering Specification (partition) - - Main Arguments: - linkage_method = complete - - Computational engine: stats - - ---- - - Code - hier_clust(k = 10) - Output - Hierarchical Clustering Specification (partition) - - Main Arguments: - k = 10 - linkage_method = complete - - Computational engine: stats - - +# bad input + + Code + hier_clust(mode = "bogus") + Error + 'bogus' is not a known mode for model `hier_clust()`. + +--- + + Code + bt <- hier_clust(linkage_method = "bogus") %>% set_engine("stats") + Error + `object` should have class 'model_spec'. + Code + fit(bt, mpg ~ ., mtcars) + Error + object 'bt' not found + +--- + + Code + translate_tidyclust(hier_clust(), engine = NULL) + Error + Please set an engine. + +--- + + Code + translate_tidyclust(hier_clust(formula = ~x)) + Error + unused argument (formula = ~x) + +# printing + + Code + hier_clust() + Output + Hierarchical Clustering Specification (partition) + + Main Arguments: + linkage_method = complete + + Computational engine: stats + + +--- + + Code + hier_clust(k = 10) + Output + Hierarchical Clustering Specification (partition) + + Main Arguments: + k = 10 + linkage_method = complete + + Computational engine: stats + + diff --git a/tests/testthat/_snaps/k_means.md b/tests/testthat/_snaps/k_means.md index ed62ffcc..e720765a 100644 --- a/tests/testthat/_snaps/k_means.md +++ b/tests/testthat/_snaps/k_means.md @@ -1,65 +1,65 @@ -# bad input - - Code - k_means(mode = "bogus") - Error - 'bogus' is not a known mode for model `k_means()`. - ---- - - Code - bt <- k_means(num_clusters = -1) %>% set_engine("stats") - fit(bt, mpg ~ ., mtcars) - Error - The number of centers should be >= 0. - ---- - - Code - translate_tidyclust(k_means(), engine = NULL) - Error - Please set an engine. - ---- - - Code - translate_tidyclust(k_means(formula = ~x)) - Error - unused argument (formula = ~x) - -# printing - - Code - k_means() - Output - K Means Cluster Specification (partition) - - Computational engine: stats - - ---- - - Code - k_means(num_clusters = 10) - Output - K Means Cluster Specification (partition) - - Main Arguments: - num_clusters = 10 - - Computational engine: stats - - -# updating - - Code - k_means(num_clusters = 5) %>% update(num_clusters = tune()) - Output - K Means Cluster Specification (partition) - - Main Arguments: - num_clusters = tune() - - Computational engine: stats - - +# bad input + + Code + k_means(mode = "bogus") + Error + 'bogus' is not a known mode for model `k_means()`. + +--- + + Code + bt <- k_means(num_clusters = -1) %>% set_engine("stats") + fit(bt, mpg ~ ., mtcars) + Error + The number of centers should be >= 0. + +--- + + Code + translate_tidyclust(k_means(), engine = NULL) + Error + Please set an engine. + +--- + + Code + translate_tidyclust(k_means(formula = ~x)) + Error + unused argument (formula = ~x) + +# printing + + Code + k_means() + Output + K Means Cluster Specification (partition) + + Computational engine: stats + + +--- + + Code + k_means(num_clusters = 10) + Output + K Means Cluster Specification (partition) + + Main Arguments: + num_clusters = 10 + + Computational engine: stats + + +# updating + + Code + k_means(num_clusters = 5) %>% update(num_clusters = tune()) + Output + K Means Cluster Specification (partition) + + Main Arguments: + num_clusters = tune() + + Computational engine: stats + + diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md index 9b20dcec..0706fec9 100644 --- a/tests/testthat/_snaps/registration.md +++ b/tests/testthat/_snaps/registration.md @@ -1,339 +1,339 @@ -# adding a new model - - Code - set_new_model_tidyclust() - Error - Please supply a character string for a model name (e.g. `'k_means'`) - ---- - - Code - set_new_model_tidyclust(2) - Error - Please supply a character string for a model name (e.g. `'k_means'`) - ---- - - Code - set_new_model_tidyclust(letters[1:2]) - Error - Please supply a character string for a model name (e.g. `'k_means'`) - -# adding a new mode - - Code - set_model_mode_tidyclust("sponge") - Error - Please supply a character string for a mode (e.g. `'partition'`). - -# adding a new engine - - Code - set_model_engine_tidyclust("sponge", eng = "gum") - Error - Please supply a character string for a mode (e.g. `'partition'`). - ---- - - Code - set_model_engine_tidyclust("sponge", mode = "partition") - Error - Please supply a character string for an engine name (e.g. `'stats'`) - ---- - - Code - set_model_engine_tidyclust("sponge", mode = "regression", eng = "gum") - Error - 'regression' is not a known mode for model `sponge()`. - -# adding a new package - - Code - set_dependency_tidyclust("sponge", "gum", letters[1:2]) - Error - Please supply a single character value for the package name. - ---- - - Code - set_dependency_tidyclust("sponge", "gummies", "trident") - Error - The engine 'gummies' has not been registered for model 'sponge'. - ---- - - Code - set_dependency_tidyclust("sponge", "gum", "trident", mode = "regression") - Error - mode 'regression' is not a valid mode for 'sponge' - -# adding a new argument - - Code - set_model_arg_tidyclust(model = "lunchroom", eng = "gum", tidyclust = "modeling", - original = "modelling", func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) - Error - Model `lunchroom` has not been registered. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", - func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) - Error - Please supply a character string for the argument. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", original = "modelling", - func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) - Error - Please supply a character string for the argument. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", - original = "modelling", func = "foo::bar", has_submodel = FALSE) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", - original = "modelling", func = list(pkg = "foo", fun = "bar"), has_submodel = 2) - Error - The `submodels` argument should be a single logical. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", - original = "modelling", func = list(pkg = "foo", fun = "bar")) - Error - argument "has_submodel" is missing, with no default - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", - original = "yodelling", func = c(foo = "a", bar = "b"), has_submodel = FALSE) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", - original = "yodelling", func = c(foo = "a"), has_submodel = FALSE) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - ---- - - Code - set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", - original = "yodelling", func = c(fun = 2, pkg = 1), has_submodel = FALSE) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - -# adding a new fit - - Code - set_fit_tidyclust(model = "cactus", eng = "gum", mode = "partition", value = fit_vals) - Error - Model `cactus` has not been registered. - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "nose", mode = "partition", value = fit_vals) - Error - Engine 'nose' is not supported for `sponge()`. See `show_engines('sponge')`. - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "frog", value = fit_vals) - Error - 'frog' is not a known mode for model `sponge()`. - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ - -i]) - Error - The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ - -i]) - Error - The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ - -i]) - Error - The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ - -i]) - Error - The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_0) - Error - The `interface` element should have a single value of: `data.frame`, `formula`, `matrix` - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_1) - Error - The `defaults` element should be a list: - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_2) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - ---- - - Code - set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_3) - Error - The `interface` element should have a single value of: `data.frame`, `formula`, `matrix` - -# adding a new predict method - - Code - set_pred_tidyclust(model = "cactus", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals) - Error - Model `cactus` has not been registered. - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "nose", mode = "partition", type = "cluster", - value = cluster_vals) - Error - Engine 'nose' is not supported for `sponge()`. See `show_engines('sponge')`. - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "eggs", - value = cluster_vals) - Error - The prediction type should be one of: 'cluster', 'raw' - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "frog", type = "cluster", - value = cluster_vals) - Error - 'frog' is not a known mode for model `sponge()`. - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals[-i]) - Error - The `predict` module should have elements: `args`, `func`, `post`, `pre` - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals[-i]) - Error - The `predict` module should have elements: `args`, `func`, `post`, `pre` - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals[-i]) - Error - The `predict` module should have elements: `args`, `func`, `post`, `pre` - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals[-i]) - Error - The `predict` module should have elements: `args`, `func`, `post`, `pre` - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals_0) - Error - The `pre` module should be null or a function: - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals_1) - Error - The `post` module should be null or a function: - ---- - - Code - set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", - value = cluster_vals_2) - Error - `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. - -# showing model info - - Code - show_model_info_tidyclust("k_means") - Output - Information for `k_means` - modes: unknown, partition - - engines: - partition: ClusterR, stats - - arguments: - stats: - num_clusters --> centers - ClusterR: - num_clusters --> clusters - - fit modules: - engine mode - stats partition - ClusterR partition - - prediction modules: - mode engine methods - partition ClusterR cluster - partition stats cluster - - +# adding a new model + + Code + set_new_model_tidyclust() + Error + Please supply a character string for a model name (e.g. `'k_means'`) + +--- + + Code + set_new_model_tidyclust(2) + Error + Please supply a character string for a model name (e.g. `'k_means'`) + +--- + + Code + set_new_model_tidyclust(letters[1:2]) + Error + Please supply a character string for a model name (e.g. `'k_means'`) + +# adding a new mode + + Code + set_model_mode_tidyclust("sponge") + Error + Please supply a character string for a mode (e.g. `'partition'`). + +# adding a new engine + + Code + set_model_engine_tidyclust("sponge", eng = "gum") + Error + Please supply a character string for a mode (e.g. `'partition'`). + +--- + + Code + set_model_engine_tidyclust("sponge", mode = "partition") + Error + Please supply a character string for an engine name (e.g. `'stats'`) + +--- + + Code + set_model_engine_tidyclust("sponge", mode = "regression", eng = "gum") + Error + 'regression' is not a known mode for model `sponge()`. + +# adding a new package + + Code + set_dependency_tidyclust("sponge", "gum", letters[1:2]) + Error + Please supply a single character value for the package name. + +--- + + Code + set_dependency_tidyclust("sponge", "gummies", "trident") + Error + The engine 'gummies' has not been registered for model 'sponge'. + +--- + + Code + set_dependency_tidyclust("sponge", "gum", "trident", mode = "regression") + Error + mode 'regression' is not a valid mode for 'sponge' + +# adding a new argument + + Code + set_model_arg_tidyclust(model = "lunchroom", eng = "gum", tidyclust = "modeling", + original = "modelling", func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) + Error + Model `lunchroom` has not been registered. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", + func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) + Error + Please supply a character string for the argument. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", original = "modelling", + func = list(pkg = "foo", fun = "bar"), has_submodel = FALSE) + Error + Please supply a character string for the argument. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", + original = "modelling", func = "foo::bar", has_submodel = FALSE) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", + original = "modelling", func = list(pkg = "foo", fun = "bar"), has_submodel = 2) + Error + The `submodels` argument should be a single logical. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "modeling", + original = "modelling", func = list(pkg = "foo", fun = "bar")) + Error + argument "has_submodel" is missing, with no default + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", + original = "yodelling", func = c(foo = "a", bar = "b"), has_submodel = FALSE) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", + original = "yodelling", func = c(foo = "a"), has_submodel = FALSE) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +--- + + Code + set_model_arg_tidyclust(model = "sponge", eng = "gum", tidyclust = "yodeling", + original = "yodelling", func = c(fun = 2, pkg = 1), has_submodel = FALSE) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +# adding a new fit + + Code + set_fit_tidyclust(model = "cactus", eng = "gum", mode = "partition", value = fit_vals) + Error + Model `cactus` has not been registered. + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "nose", mode = "partition", value = fit_vals) + Error + Engine 'nose' is not supported for `sponge()`. See `show_engines('sponge')`. + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "frog", value = fit_vals) + Error + 'frog' is not a known mode for model `sponge()`. + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ + -i]) + Error + The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ + -i]) + Error + The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ + -i]) + Error + The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals[ + -i]) + Error + The `fit` module should have elements: `defaults`, `func`, `interface`, `protect` + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_0) + Error + The `interface` element should have a single value of: `data.frame`, `formula`, `matrix` + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_1) + Error + The `defaults` element should be a list: + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_2) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +--- + + Code + set_fit_tidyclust(model = "sponge", eng = "gum", mode = "partition", value = fit_vals_3) + Error + The `interface` element should have a single value of: `data.frame`, `formula`, `matrix` + +# adding a new predict method + + Code + set_pred_tidyclust(model = "cactus", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals) + Error + Model `cactus` has not been registered. + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "nose", mode = "partition", type = "cluster", + value = cluster_vals) + Error + Engine 'nose' is not supported for `sponge()`. See `show_engines('sponge')`. + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "eggs", + value = cluster_vals) + Error + The prediction type should be one of: 'cluster', 'raw' + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "frog", type = "cluster", + value = cluster_vals) + Error + 'frog' is not a known mode for model `sponge()`. + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals[-i]) + Error + The `predict` module should have elements: `args`, `func`, `post`, `pre` + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals[-i]) + Error + The `predict` module should have elements: `args`, `func`, `post`, `pre` + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals[-i]) + Error + The `predict` module should have elements: `args`, `func`, `post`, `pre` + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals[-i]) + Error + The `predict` module should have elements: `args`, `func`, `post`, `pre` + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals_0) + Error + The `pre` module should be null or a function: + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals_1) + Error + The `post` module should be null or a function: + +--- + + Code + set_pred_tidyclust(model = "sponge", eng = "gum", mode = "partition", type = "cluster", + value = cluster_vals_2) + Error + `func` should be a named vector with element 'fun' and the optional elements 'pkg', 'range', 'trans', and 'values'. `func` and 'pkg' should both be single character strings. + +# showing model info + + Code + show_model_info_tidyclust("k_means") + Output + Information for `k_means` + modes: unknown, partition + + engines: + partition: ClusterR, stats + + arguments: + stats: + num_clusters --> centers + ClusterR: + num_clusters --> clusters + + fit modules: + engine mode + stats partition + ClusterR partition + + prediction modules: + mode engine methods + partition ClusterR cluster + partition stats cluster + + diff --git a/tests/testthat/helper-tidyclust-package.R b/tests/testthat/helper-tidyclust-package.R index cfd0721b..ef04b021 100644 --- a/tests/testthat/helper-tidyclust-package.R +++ b/tests/testthat/helper-tidyclust-package.R @@ -1,27 +1,27 @@ -new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 - -helper_objects_tidyclust <- function() { - rec_tune_1 <- - recipes::recipe(~ ., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) %>% - recipes::step_pca(recipes::all_predictors(), num_comp = tune()) - - rec_no_tune_1 <- - recipes::recipe(~ ., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) - - kmeans_mod_no_tune <- k_means(num_clusters = 2) - - kmeans_mod <- k_means(num_clusters = tune()) - - list( - rec_tune_1 = rec_tune_1, - rec_no_tune_1 = rec_no_tune_1, - kmeans_mod = kmeans_mod, - kmeans_mod_no_tune = kmeans_mod_no_tune - ) -} - -new_empty_quosure <- function(expr) { - rlang::new_quosure(expr, env = rlang::empty_env()) -} +new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 + +helper_objects_tidyclust <- function() { + rec_tune_1 <- + recipes::recipe(~ ., data = mtcars) %>% + recipes::step_normalize(recipes::all_predictors()) %>% + recipes::step_pca(recipes::all_predictors(), num_comp = tune()) + + rec_no_tune_1 <- + recipes::recipe(~ ., data = mtcars) %>% + recipes::step_normalize(recipes::all_predictors()) + + kmeans_mod_no_tune <- k_means(num_clusters = 2) + + kmeans_mod <- k_means(num_clusters = tune()) + + list( + rec_tune_1 = rec_tune_1, + rec_no_tune_1 = rec_no_tune_1, + kmeans_mod = kmeans_mod, + kmeans_mod_no_tune = kmeans_mod_no_tune + ) +} + +new_empty_quosure <- function(expr) { + rlang::new_quosure(expr, env = rlang::empty_env()) +} diff --git a/tests/testthat/test-arguments.R b/tests/testthat/test-arguments.R index e9c1a38f..0df6433b 100644 --- a/tests/testthat/test-arguments.R +++ b/tests/testthat/test-arguments.R @@ -1,45 +1,45 @@ -test_that('pipe arguments', { - mod_1 <- k_means() %>% - set_args(num_clusters = 1) - expect_equal( - rlang::quo_get_expr(mod_1$args$num_clusters), - 1 - ) - expect_equal( - rlang::quo_get_env(mod_1$args$num_clusters), - rlang::empty_env() - ) - - mod_2 <- k_means(num_clusters = 2) %>% - set_args(num_clusters = 1) - - var_env <- rlang::current_env() - - expect_equal( - rlang::quo_get_expr(mod_2$args$num_clusters), - 1 - ) - expect_equal( - rlang::quo_get_env(mod_2$args$num_clusters), - rlang::empty_env() - ) - - expect_snapshot(error = TRUE, k_means() %>% set_args()) -}) - - -test_that('pipe engine', { - mod_1 <- k_means() %>% - set_mode("partition") - expect_equal(mod_1$mode, "partition") - - expect_snapshot(error = TRUE, k_means() %>% set_mode()) - expect_snapshot(error = TRUE, k_means() %>% set_mode(2)) - expect_snapshot(error = TRUE, k_means() %>% set_mode("haberdashery")) -}) - -test_that("can't set a mode that isn't allowed by the model spec", { - expect_snapshot(error = TRUE, - set_mode(k_means(), "classification") - ) -}) +test_that('pipe arguments', { + mod_1 <- k_means() %>% + set_args(num_clusters = 1) + expect_equal( + rlang::quo_get_expr(mod_1$args$num_clusters), + 1 + ) + expect_equal( + rlang::quo_get_env(mod_1$args$num_clusters), + rlang::empty_env() + ) + + mod_2 <- k_means(num_clusters = 2) %>% + set_args(num_clusters = 1) + + var_env <- rlang::current_env() + + expect_equal( + rlang::quo_get_expr(mod_2$args$num_clusters), + 1 + ) + expect_equal( + rlang::quo_get_env(mod_2$args$num_clusters), + rlang::empty_env() + ) + + expect_snapshot(error = TRUE, k_means() %>% set_args()) +}) + + +test_that('pipe engine', { + mod_1 <- k_means() %>% + set_mode("partition") + expect_equal(mod_1$mode, "partition") + + expect_snapshot(error = TRUE, k_means() %>% set_mode()) + expect_snapshot(error = TRUE, k_means() %>% set_mode(2)) + expect_snapshot(error = TRUE, k_means() %>% set_mode("haberdashery")) +}) + +test_that("can't set a mode that isn't allowed by the model spec", { + expect_snapshot(error = TRUE, + set_mode(k_means(), "classification") + ) +}) diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index e973e288..063a7db3 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -1,28 +1,28 @@ -test_that('partition models', { - x <- k_means(num_clusters = 2) - - set.seed(1234) - reg_form <- x %>% fit(~ ., data = mtcars) - set.seed(1234) - reg_xy <- x %>% fit_xy(mtcars) - - expect_equal( - colnames(augment(reg_form, head(mtcars))), - c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster") - ) - expect_equal(nrow(augment(reg_form, head(mtcars))), 6) - - expect_equal( - colnames(augment(reg_xy, head(mtcars))), - c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster") - ) - expect_equal(nrow(augment(reg_xy, head(mtcars))), 6) - - expect_s3_class(augment(reg_form, head(mtcars)), "tbl_df") - - reg_form$spec$mode <- "depeche" - - expect_snapshot(error = TRUE, augment(reg_form, head(mtcars[, -1]))) -}) +test_that('partition models', { + x <- k_means(num_clusters = 2) + + set.seed(1234) + reg_form <- x %>% fit(~ ., data = mtcars) + set.seed(1234) + reg_xy <- x %>% fit_xy(mtcars) + + expect_equal( + colnames(augment(reg_form, head(mtcars))), + c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", + "gear", "carb", ".pred_cluster") + ) + expect_equal(nrow(augment(reg_form, head(mtcars))), 6) + + expect_equal( + colnames(augment(reg_xy, head(mtcars))), + c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", + "gear", "carb", ".pred_cluster") + ) + expect_equal(nrow(augment(reg_xy, head(mtcars))), 6) + + expect_s3_class(augment(reg_form, head(mtcars)), "tbl_df") + + reg_form$spec$mode <- "depeche" + + expect_snapshot(error = TRUE, augment(reg_form, head(mtcars[, -1]))) +}) diff --git a/tests/testthat/test-cluster_metric_set.R b/tests/testthat/test-cluster_metric_set.R index 14038214..bb6a1854 100644 --- a/tests/testthat/test-cluster_metric_set.R +++ b/tests/testthat/test-cluster_metric_set.R @@ -1,42 +1,42 @@ -test_that("cluster_metric_set works", { - kmeans_spec <- k_means(num_clusters = 5) %>% - set_engine("stats") - - kmeans_fit <- fit(kmeans_spec, ~., mtcars) - - my_metrics <- cluster_metric_set(sse_ratio, tot_sse, tot_wss, avg_silhouette) - - exp_res <- tibble::tibble( - .metric = c("sse_ratio", "tot_sse", "tot_wss", "avg_silhouette"), - .estimator = "standard", - .estimate = vapply( - list(sse_ratio_vec, tot_sse_vec, tot_wss_vec, avg_silhouette_vec), - function(x) x(kmeans_fit, new_data = mtcars), - FUN.VALUE = numeric(1) - ) - ) - - expect_equal( - my_metrics(kmeans_fit, new_data = mtcars), - exp_res - ) - - expect_snapshot(error = TRUE, my_metrics(kmeans_fit)) - - my_metrics <- cluster_metric_set(sse_ratio, tot_sse, tot_wss) - - expect_equal( - my_metrics(kmeans_fit), - exp_res[-4, ] - ) -}) - -test_that("cluster_metric_set error with wrong input", { - expect_snapshot(error = TRUE, - cluster_metric_set(mean) - ) - - expect_snapshot(error = TRUE, - cluster_metric_set(sse_ratio, mean) - ) -}) +test_that("cluster_metric_set works", { + kmeans_spec <- k_means(num_clusters = 5) %>% + set_engine("stats") + + kmeans_fit <- fit(kmeans_spec, ~., mtcars) + + my_metrics <- cluster_metric_set(sse_ratio, tot_sse, tot_wss, avg_silhouette) + + exp_res <- tibble::tibble( + .metric = c("sse_ratio", "tot_sse", "tot_wss", "avg_silhouette"), + .estimator = "standard", + .estimate = vapply( + list(sse_ratio_vec, tot_sse_vec, tot_wss_vec, avg_silhouette_vec), + function(x) x(kmeans_fit, new_data = mtcars), + FUN.VALUE = numeric(1) + ) + ) + + expect_equal( + my_metrics(kmeans_fit, new_data = mtcars), + exp_res + ) + + expect_snapshot(error = TRUE, my_metrics(kmeans_fit)) + + my_metrics <- cluster_metric_set(sse_ratio, tot_sse, tot_wss) + + expect_equal( + my_metrics(kmeans_fit), + exp_res[-4, ] + ) +}) + +test_that("cluster_metric_set error with wrong input", { + expect_snapshot(error = TRUE, + cluster_metric_set(mean) + ) + + expect_snapshot(error = TRUE, + cluster_metric_set(sse_ratio, mean) + ) +}) diff --git a/tests/testthat/test-control.R b/tests/testthat/test-control.R index 2203dfea..6925c5dd 100644 --- a/tests/testthat/test-control.R +++ b/tests/testthat/test-control.R @@ -1,14 +1,14 @@ -test_that("control class", { - skip("waiting for workflow PR") - x <- k_means(num_clusters = 5) %>% set_engine("stats") - ctrl <- control_cluster() - class(ctrl) <- c("potato", "chair") - expect_snapshot( - error = TRUE, - fit(x, mpg ~ ., data = mtcars, control = ctrl) - ) - expect_snapshot( - error = TRUE, - fit_xy(x, x = mtcars[, -1], y = mtcars$mpg, control = ctrl) - ) -}) +test_that("control class", { + skip("waiting for workflow PR") + x <- k_means(num_clusters = 5) %>% set_engine("stats") + ctrl <- control_cluster() + class(ctrl) <- c("potato", "chair") + expect_snapshot( + error = TRUE, + fit(x, mpg ~ ., data = mtcars, control = ctrl) + ) + expect_snapshot( + error = TRUE, + fit_xy(x, x = mtcars[, -1], y = mtcars$mpg, control = ctrl) + ) +}) diff --git a/tests/testthat/test-extract_summary.R b/tests/testthat/test-extract_summary.R index 060205e4..f110305b 100644 --- a/tests/testthat/test-extract_summary.R +++ b/tests/testthat/test-extract_summary.R @@ -1,40 +1,40 @@ -test_that("extract summary works for kmeans", { - obj1 <- k_means(num_clusters = mtcars[1:3, ]) %>% - set_engine("stats", algorithm = "MacQueen") %>% - fit(~., mtcars) - - obj2 <- k_means(num_clusters = 3) %>% - set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% - fit(~., mtcars) - - summ1 <- extract_fit_summary(obj1) - summ2 <- extract_fit_summary(obj2) - - expect_equal(names(summ1), names(summ2)) - - # check order - expect_equal(summ1$n_members, c(17, 11, 4)) -}) - -test_that("extract summary works for kmeans when num_clusters = 1", { - obj1 <- k_means(num_clusters = 1) %>% - set_engine("stats") %>% - fit(~., mtcars) - - obj2 <- k_means(num_clusters = 1) %>% - set_engine("ClusterR") %>% - fit(~., mtcars) - - summ1 <- extract_fit_summary(obj1) - summ2 <- extract_fit_summary(obj2) - - expect_equal( - summ1$centroids, - tibble::as_tibble(lapply(mtcars, mean)) - ) - - expect_equal( - summ2$centroids, - tibble::as_tibble(lapply(mtcars, mean)) - ) -}) +test_that("extract summary works for kmeans", { + obj1 <- k_means(num_clusters = mtcars[1:3, ]) %>% + set_engine("stats", algorithm = "MacQueen") %>% + fit(~., mtcars) + + obj2 <- k_means(num_clusters = 3) %>% + set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% + fit(~., mtcars) + + summ1 <- extract_fit_summary(obj1) + summ2 <- extract_fit_summary(obj2) + + expect_equal(names(summ1), names(summ2)) + + # check order + expect_equal(summ1$n_members, c(17, 11, 4)) +}) + +test_that("extract summary works for kmeans when num_clusters = 1", { + obj1 <- k_means(num_clusters = 1) %>% + set_engine("stats") %>% + fit(~., mtcars) + + obj2 <- k_means(num_clusters = 1) %>% + set_engine("ClusterR") %>% + fit(~., mtcars) + + summ1 <- extract_fit_summary(obj1) + summ2 <- extract_fit_summary(obj2) + + expect_equal( + summ1$centroids, + tibble::as_tibble(lapply(mtcars, mean)) + ) + + expect_equal( + summ2$centroids, + tibble::as_tibble(lapply(mtcars, mean)) + ) +}) diff --git a/tests/testthat/test-hier_clust.R b/tests/testthat/test-hier_clust.R index 55a3daec..8e7e8bc2 100644 --- a/tests/testthat/test-hier_clust.R +++ b/tests/testthat/test-hier_clust.R @@ -1,85 +1,85 @@ -test_that("primary arguments", { - basic <- hier_clust(mode = "partition") - basic_stats <- translate_tidyclust(basic %>% set_engine("stats")) - expect_equal( - basic_stats$method$fit$args, - list( - x = rlang::expr(missing_arg()) - ) - ) -}) - -test_that("engine arguments", { - stats_print <- hier_clust(mode = "partition") - expect_equal( - translate_tidyclust( - stats_print %>% - set_engine("stats", linkage_method = "single") - )$method$fit$args, - list( - x = rlang::expr(missing_arg()), - nstart = new_empty_quosure("single") - ) - ) -}) - -test_that("bad input", { - expect_snapshot(error = TRUE, hier_clust(mode = "bogus")) - expect_snapshot(error = TRUE, { - bt <- hier_clust(linkage_method = "bogus") %>% set_engine("stats") - fit(bt, mpg ~ ., mtcars) - }) - expect_snapshot(error = TRUE, translate_tidyclust(hier_clust(), engine = NULL)) - expect_snapshot(error = TRUE, translate_tidyclust(hier_clust(formula = ~x))) -}) - -test_that("predictions", { - set.seed(1234) - hclust_fit <- hier_clust(k = 4) %>% - set_engine("stats") %>% - fit(~., mtcars) - - set.seed(1234) - ref_res <- cutree(hclust(dist(mtcars)), k = 4) - - ref_predictions <- ref_res %>% unname() - - relevel_preds <- function(x) { - factor(unname(x), unique(unname(x))) %>% as.numeric() - } - - expect_equal( - relevel_preds(ref_predictions), - predict(hclust_fit, mtcars)$.pred_cluster %>% as.numeric() - ) - - expect_equal( - relevel_preds(unname(ref_res$cluster)), - extract_cluster_assignment(hclust_fit)$.cluster %>% as.numeric() - ) - - expect_equal( - relevel_preds(predict(hclust_fit, mtcars)$.pred_cluster), - extract_cluster_assignment(hclust_fit)$.cluster %>% as.numeric() - ) -}) - -test_that("Right classes", { - expect_equal(class(hier_clust()), c("hier_clust", "cluster_spec")) -}) - -test_that("printing", { - expect_snapshot( - hier_clust() - ) - expect_snapshot( - hier_clust(k = 10) - ) -}) - -test_that('updating', { - expect_snapshot( - hier_clust(k = 5) %>% - update(k = tune()) - ) -}) +test_that("primary arguments", { + basic <- hier_clust(mode = "partition") + basic_stats <- translate_tidyclust(basic %>% set_engine("stats")) + expect_equal( + basic_stats$method$fit$args, + list( + x = rlang::expr(missing_arg()) + ) + ) +}) + +test_that("engine arguments", { + stats_print <- hier_clust(mode = "partition") + expect_equal( + translate_tidyclust( + stats_print %>% + set_engine("stats", linkage_method = "single") + )$method$fit$args, + list( + x = rlang::expr(missing_arg()), + nstart = new_empty_quosure("single") + ) + ) +}) + +test_that("bad input", { + expect_snapshot(error = TRUE, hier_clust(mode = "bogus")) + expect_snapshot(error = TRUE, { + bt <- hier_clust(linkage_method = "bogus") %>% set_engine("stats") + fit(bt, mpg ~ ., mtcars) + }) + expect_snapshot(error = TRUE, translate_tidyclust(hier_clust(), engine = NULL)) + expect_snapshot(error = TRUE, translate_tidyclust(hier_clust(formula = ~x))) +}) + +test_that("predictions", { + set.seed(1234) + hclust_fit <- hier_clust(k = 4) %>% + set_engine("stats") %>% + fit(~., mtcars) + + set.seed(1234) + ref_res <- cutree(hclust(dist(mtcars)), k = 4) + + ref_predictions <- ref_res %>% unname() + + relevel_preds <- function(x) { + factor(unname(x), unique(unname(x))) %>% as.numeric() + } + + expect_equal( + relevel_preds(ref_predictions), + predict(hclust_fit, mtcars)$.pred_cluster %>% as.numeric() + ) + + expect_equal( + relevel_preds(unname(ref_res$cluster)), + extract_cluster_assignment(hclust_fit)$.cluster %>% as.numeric() + ) + + expect_equal( + relevel_preds(predict(hclust_fit, mtcars)$.pred_cluster), + extract_cluster_assignment(hclust_fit)$.cluster %>% as.numeric() + ) +}) + +test_that("Right classes", { + expect_equal(class(hier_clust()), c("hier_clust", "cluster_spec")) +}) + +test_that("printing", { + expect_snapshot( + hier_clust() + ) + expect_snapshot( + hier_clust(k = 10) + ) +}) + +test_that('updating', { + expect_snapshot( + hier_clust(k = 5) %>% + update(k = tune()) + ) +}) diff --git a/tests/testthat/test-k_means.R b/tests/testthat/test-k_means.R index 4913215e..2e0340a5 100644 --- a/tests/testthat/test-k_means.R +++ b/tests/testthat/test-k_means.R @@ -1,101 +1,101 @@ -test_that("primary arguments", { - basic <- k_means(mode = "partition") - basic_stats <- translate_tidyclust(basic %>% set_engine("stats")) - expect_equal( - basic_stats$method$fit$args, - list( - x = rlang::expr(missing_arg()), - centers = rlang::expr(missing_arg()) - ) - ) - - k <- k_means(num_clusters = 15, mode = "partition") - k_stats <- translate_tidyclust(k %>% set_engine("stats")) - expect_equal( - k_stats$method$fit$args, - list( - x = rlang::expr(missing_arg()), - centers = rlang::expr(missing_arg()), - centers = new_empty_quosure(15) - ) - ) -}) - -test_that("engine arguments", { - stats_print <- k_means(mode = "partition") - expect_equal( - translate_tidyclust( - stats_print %>% - set_engine("stats", nstart = 1L) - )$method$fit$args, - list( - x = rlang::expr(missing_arg()), - centers = rlang::expr(missing_arg()), - nstart = new_empty_quosure(1L) - ) - ) -}) - -test_that("bad input", { - expect_snapshot(error = TRUE, k_means(mode = "bogus")) - expect_snapshot(error = TRUE, { - bt <- k_means(num_clusters = -1) %>% set_engine("stats") - fit(bt, mpg ~ ., mtcars) - }) - expect_snapshot(error = TRUE, translate_tidyclust(k_means(), engine = NULL)) - expect_snapshot(error = TRUE, translate_tidyclust(k_means(formula = ~x))) -}) - -test_that("predictions", { - set.seed(1234) - kmeans_fit <- k_means(num_clusters = 4) %>% - set_engine("stats") %>% - fit(~., mtcars) - - set.seed(1234) - ref_res <- kmeans(mtcars, 4) - - ref_predictions <- ref_res$centers %>% - flexclust::dist2(mtcars) %>% - apply(2, which.min) %>% - unname() - - relevel_preds <- function(x) { - factor(unname(x), unique(unname(x))) %>% as.numeric() - } - - expect_equal( - relevel_preds(ref_predictions), - predict(kmeans_fit, mtcars)$.pred_cluster %>% as.numeric() - ) - - expect_equal( - relevel_preds(unname(ref_res$cluster)), - extract_cluster_assignment(kmeans_fit)$.cluster %>% as.numeric() - ) - - expect_equal( - relevel_preds(predict(kmeans_fit, mtcars)$.pred_cluster), - extract_cluster_assignment(kmeans_fit)$.cluster %>% as.numeric() - ) -}) - -test_that("Right classes", { - expect_equal(class(k_means()), c("k_means", "cluster_spec")) -}) - -test_that("printing", { - expect_snapshot( - k_means() - ) - expect_snapshot( - k_means(num_clusters = 10) - ) -}) - -test_that('updating', { - expect_snapshot( - k_means(num_clusters = 5) %>% - update(num_clusters = tune()) - ) -}) +test_that("primary arguments", { + basic <- k_means(mode = "partition") + basic_stats <- translate_tidyclust(basic %>% set_engine("stats")) + expect_equal( + basic_stats$method$fit$args, + list( + x = rlang::expr(missing_arg()), + centers = rlang::expr(missing_arg()) + ) + ) + + k <- k_means(num_clusters = 15, mode = "partition") + k_stats <- translate_tidyclust(k %>% set_engine("stats")) + expect_equal( + k_stats$method$fit$args, + list( + x = rlang::expr(missing_arg()), + centers = rlang::expr(missing_arg()), + centers = new_empty_quosure(15) + ) + ) +}) + +test_that("engine arguments", { + stats_print <- k_means(mode = "partition") + expect_equal( + translate_tidyclust( + stats_print %>% + set_engine("stats", nstart = 1L) + )$method$fit$args, + list( + x = rlang::expr(missing_arg()), + centers = rlang::expr(missing_arg()), + nstart = new_empty_quosure(1L) + ) + ) +}) + +test_that("bad input", { + expect_snapshot(error = TRUE, k_means(mode = "bogus")) + expect_snapshot(error = TRUE, { + bt <- k_means(num_clusters = -1) %>% set_engine("stats") + fit(bt, mpg ~ ., mtcars) + }) + expect_snapshot(error = TRUE, translate_tidyclust(k_means(), engine = NULL)) + expect_snapshot(error = TRUE, translate_tidyclust(k_means(formula = ~x))) +}) + +test_that("predictions", { + set.seed(1234) + kmeans_fit <- k_means(num_clusters = 4) %>% + set_engine("stats") %>% + fit(~., mtcars) + + set.seed(1234) + ref_res <- kmeans(mtcars, 4) + + ref_predictions <- ref_res$centers %>% + flexclust::dist2(mtcars) %>% + apply(2, which.min) %>% + unname() + + relevel_preds <- function(x) { + factor(unname(x), unique(unname(x))) %>% as.numeric() + } + + expect_equal( + relevel_preds(ref_predictions), + predict(kmeans_fit, mtcars)$.pred_cluster %>% as.numeric() + ) + + expect_equal( + relevel_preds(unname(ref_res$cluster)), + extract_cluster_assignment(kmeans_fit)$.cluster %>% as.numeric() + ) + + expect_equal( + relevel_preds(predict(kmeans_fit, mtcars)$.pred_cluster), + extract_cluster_assignment(kmeans_fit)$.cluster %>% as.numeric() + ) +}) + +test_that("Right classes", { + expect_equal(class(k_means()), c("k_means", "cluster_spec")) +}) + +test_that("printing", { + expect_snapshot( + k_means() + ) + expect_snapshot( + k_means(num_clusters = 10) + ) +}) + +test_that('updating', { + expect_snapshot( + k_means(num_clusters = 5) %>% + update(num_clusters = tune()) + ) +}) diff --git a/tests/testthat/test-k_means_diagnostics.R b/tests/testthat/test-k_means_diagnostics.R index 429cecbe..11ab6962 100644 --- a/tests/testthat/test-k_means_diagnostics.R +++ b/tests/testthat/test-k_means_diagnostics.R @@ -1,103 +1,103 @@ -test_that("kmeans sse metrics work", { - kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% - set_engine("stats", algorithm = "MacQueen") %>% - fit(~., mtcars) - - kmeans_fit_ClusterR <- k_means(k = 3) %>% - set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% - fit(~., mtcars) - - km_orig <- kmeans(mtcars, centers = mtcars[1:3, ], algorithm = "MacQueen") - km_orig_2 <- ClusterR::KMeans_rcpp( - data = mtcars, - clusters = 3, - CENTROIDS = as.matrix(mtcars[1:3, ]) - ) - - expect_equal(within_cluster_sse(kmeans_fit_stats)$wss, - c(42877.103, 76954.010, 7654.146), # hard coded because of order - tolerance = 0.005 - ) - - expect_equal(tot_wss(kmeans_fit_stats), km_orig$tot.withinss, tolerance = 0.005) - expect_equal(tot_sse(kmeans_fit_stats), km_orig$totss, tolerance = 0.005) - expect_equal(sse_ratio(kmeans_fit_stats), km_orig$tot.withinss / km_orig$totss, tolerance = 0.005) - - expect_equal(within_cluster_sse(kmeans_fit_ClusterR)$wss, - c(56041.432, 4665.041, 42877.103), # hard coded because of order - tolerance = 0.005 - ) - - expect_equal(tot_wss(kmeans_fit_ClusterR), sum(km_orig_2$WCSS_per_cluster), tolerance = 0.005) - expect_equal(tot_sse(kmeans_fit_ClusterR), tot_sse(kmeans_fit_stats), tolerance = 0.005) - expect_equal(sse_ratio(kmeans_fit_ClusterR), 0.1661624, tolerance = 0.005) -}) - -test_that("kmeans sse metrics work on new data", { - kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% - set_engine("stats", algorithm = "MacQueen") %>% - fit(~., mtcars) - - new_data <- mtcars[1:4, ] - - expect_equal(within_cluster_sse(kmeans_fit_stats, new_data)$wss, - c(2799.21, 12855.17), - tolerance = 0.005 - ) - - expect_equal(tot_wss(kmeans_fit_stats, new_data), 15654.38, tolerance = 0.005) - expect_equal(tot_sse(kmeans_fit_stats, new_data), 32763.7, tolerance = 0.005) - expect_equal(sse_ratio(kmeans_fit_stats, new_data), 15654.38 / 32763.7, tolerance = 0.005) -}) - -test_that("kmeans sihouette metrics work", { - kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% - set_engine("stats", algorithm = "MacQueen") %>% - fit(~., mtcars) - - kmeans_fit_ClusterR <- k_means(k = 3) %>% - set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% - fit(~., mtcars) - - new_data <- mtcars[1:4, ] - - dists <- mtcars %>% - as.matrix() %>% - dist() - - expect_equal( - names(silhouettes(kmeans_fit_stats, dists = dists)), - names(silhouettes(kmeans_fit_ClusterR, dists = dists)) - ) - - expect_equal(avg_silhouette(kmeans_fit_stats, dists = dists), 0.4993742, - tolerance = 0.005 - ) - expect_equal(avg_silhouette(kmeans_fit_ClusterR, dists = dists), 0.5473414, - tolerance = 0.005 - ) -}) - -test_that("kmeans sihouette metrics work with new data", { - kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% - set_engine("stats", algorithm = "MacQueen") %>% - fit(~., mtcars) - - kmeans_fit_ClusterR <- k_means(k = 3) %>% - set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% - fit(~., mtcars) - - new_data <- mtcars[1:4, ] - - expect_equal( - names(silhouettes(kmeans_fit_stats, new_data = new_data)), - names(silhouettes(kmeans_fit_ClusterR, new_data = new_data)) - ) - - expect_equal(avg_silhouette(kmeans_fit_stats, new_data = new_data), 0.5176315, - tolerance = 0.005 - ) - expect_equal(avg_silhouette(kmeans_fit_ClusterR, new_data = new_data), 0.5176315, - tolerance = 0.005 - ) -}) +test_that("kmeans sse metrics work", { + kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% + set_engine("stats", algorithm = "MacQueen") %>% + fit(~., mtcars) + + kmeans_fit_ClusterR <- k_means(k = 3) %>% + set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% + fit(~., mtcars) + + km_orig <- kmeans(mtcars, centers = mtcars[1:3, ], algorithm = "MacQueen") + km_orig_2 <- ClusterR::KMeans_rcpp( + data = mtcars, + clusters = 3, + CENTROIDS = as.matrix(mtcars[1:3, ]) + ) + + expect_equal(within_cluster_sse(kmeans_fit_stats)$wss, + c(42877.103, 76954.010, 7654.146), # hard coded because of order + tolerance = 0.005 + ) + + expect_equal(tot_wss(kmeans_fit_stats), km_orig$tot.withinss, tolerance = 0.005) + expect_equal(tot_sse(kmeans_fit_stats), km_orig$totss, tolerance = 0.005) + expect_equal(sse_ratio(kmeans_fit_stats), km_orig$tot.withinss / km_orig$totss, tolerance = 0.005) + + expect_equal(within_cluster_sse(kmeans_fit_ClusterR)$wss, + c(56041.432, 4665.041, 42877.103), # hard coded because of order + tolerance = 0.005 + ) + + expect_equal(tot_wss(kmeans_fit_ClusterR), sum(km_orig_2$WCSS_per_cluster), tolerance = 0.005) + expect_equal(tot_sse(kmeans_fit_ClusterR), tot_sse(kmeans_fit_stats), tolerance = 0.005) + expect_equal(sse_ratio(kmeans_fit_ClusterR), 0.1661624, tolerance = 0.005) +}) + +test_that("kmeans sse metrics work on new data", { + kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% + set_engine("stats", algorithm = "MacQueen") %>% + fit(~., mtcars) + + new_data <- mtcars[1:4, ] + + expect_equal(within_cluster_sse(kmeans_fit_stats, new_data)$wss, + c(2799.21, 12855.17), + tolerance = 0.005 + ) + + expect_equal(tot_wss(kmeans_fit_stats, new_data), 15654.38, tolerance = 0.005) + expect_equal(tot_sse(kmeans_fit_stats, new_data), 32763.7, tolerance = 0.005) + expect_equal(sse_ratio(kmeans_fit_stats, new_data), 15654.38 / 32763.7, tolerance = 0.005) +}) + +test_that("kmeans sihouette metrics work", { + kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% + set_engine("stats", algorithm = "MacQueen") %>% + fit(~., mtcars) + + kmeans_fit_ClusterR <- k_means(k = 3) %>% + set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% + fit(~., mtcars) + + new_data <- mtcars[1:4, ] + + dists <- mtcars %>% + as.matrix() %>% + dist() + + expect_equal( + names(silhouettes(kmeans_fit_stats, dists = dists)), + names(silhouettes(kmeans_fit_ClusterR, dists = dists)) + ) + + expect_equal(avg_silhouette(kmeans_fit_stats, dists = dists), 0.4993742, + tolerance = 0.005 + ) + expect_equal(avg_silhouette(kmeans_fit_ClusterR, dists = dists), 0.5473414, + tolerance = 0.005 + ) +}) + +test_that("kmeans sihouette metrics work with new data", { + kmeans_fit_stats <- k_means(k = mtcars[1:3, ]) %>% + set_engine("stats", algorithm = "MacQueen") %>% + fit(~., mtcars) + + kmeans_fit_ClusterR <- k_means(k = 3) %>% + set_engine("ClusterR", CENTROIDS = as.matrix(mtcars[1:3, ])) %>% + fit(~., mtcars) + + new_data <- mtcars[1:4, ] + + expect_equal( + names(silhouettes(kmeans_fit_stats, new_data = new_data)), + names(silhouettes(kmeans_fit_ClusterR, new_data = new_data)) + ) + + expect_equal(avg_silhouette(kmeans_fit_stats, new_data = new_data), 0.5176315, + tolerance = 0.005 + ) + expect_equal(avg_silhouette(kmeans_fit_ClusterR, new_data = new_data), 0.5176315, + tolerance = 0.005 + ) +}) diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index 0660a759..0ef70df0 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -1,10 +1,10 @@ -test_that("partition predictions", { - kmeans_fit <- - k_means(num_clusters = 3, mode = "partition") %>% - set_engine("stats") %>% - fit(~., data = mtcars) - - expect_true(tibble::is_tibble(predict(kmeans_fit, new_data = mtcars))) - expect_true(is.factor(tidyclust:::predict_cluster.cluster_fit(kmeans_fit, new_data = mtcars))) - expect_equal(names(predict(kmeans_fit, new_data = mtcars)), ".pred_cluster") -}) +test_that("partition predictions", { + kmeans_fit <- + k_means(num_clusters = 3, mode = "partition") %>% + set_engine("stats") %>% + fit(~., data = mtcars) + + expect_true(tibble::is_tibble(predict(kmeans_fit, new_data = mtcars))) + expect_true(is.factor(tidyclust:::predict_cluster.cluster_fit(kmeans_fit, new_data = mtcars))) + expect_equal(names(predict(kmeans_fit, new_data = mtcars)), ".pred_cluster") +}) diff --git a/tests/testthat/test-tune_cluster.R b/tests/testthat/test-tune_cluster.R index 8b958f73..e845848b 100644 --- a/tests/testthat/test-tune_cluster.R +++ b/tests/testthat/test-tune_cluster.R @@ -1,383 +1,383 @@ -test_that("tune recipe only", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod_no_tune) - pset <- hardhat::extract_parameter_set_dials(wflow) %>% - update(num_comp = dials::num_comp(c(1, 3))) - grid <- dials::grid_regular(pset, levels = 3) - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss, tot_sse) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - res_est <- tune::collect_metrics(res) - res_workflow <- res$.extracts[[1]]$.extracts[[1]] - - # Ensure tunable parameters in recipe are finalized - num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp - - expect_equal(res$id, folds$id) - expect_equal(nrow(res_est), nrow(grid) * 2) - expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) - expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) - expect_equal(res_est$n, rep(10, nrow(grid) * 2)) - expect_false(identical(num_comp, expr(tune()))) - expect_true(res_workflow$trained) -}) - -test_that("tune model only (with recipe)", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_no_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod) - pset <- hardhat::extract_parameter_set_dials(wflow) - grid <- dials::grid_regular(pset, levels = 3) - grid$num_clusters <- grid$num_clusters + 1 - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss, tot_sse) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - res_est <- tune::collect_metrics(res) - res_workflow <- res$.extracts[[1]]$.extracts[[1]] - - # Ensure tunable parameters in spec are finalized - num_clusters_quo <- res_workflow$fit$fit$spec$args$num_clusters - num_clusters <- rlang::quo_get_expr(num_clusters_quo) - - expect_equal(res$id, folds$id) - expect_equal(nrow(res_est), nrow(grid) * 2) - expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) - expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) - expect_equal(res_est$n, rep(10, nrow(grid) * 2)) - expect_false(identical(num_clusters, expr(tune()))) - expect_true(res_workflow$trained) -}) - -test_that("tune model and recipe", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod) - pset <- hardhat::extract_parameter_set_dials(wflow) %>% - update(num_comp = dials::num_comp(c(1, 3))) - grid <- dials::grid_regular(pset, levels = 3) - grid$num_clusters <- grid$num_clusters + 1 - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss, tot_sse) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - res_est <- tune::collect_metrics(res) - res_workflow <- res$.extracts[[1]]$.extracts[[1]] - - # Ensure tunable parameters in spec are finalized - num_clusters_quo <- res_workflow$fit$fit$spec$args$num_clusters - num_clusters <- rlang::quo_get_expr(num_clusters_quo) - - # Ensure tunable parameters in recipe are finalized - num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp - - expect_equal(res$id, folds$id) - expect_equal( - colnames(res$.metrics[[1]]), - c("num_clusters", "num_comp", ".metric", ".estimator", ".estimate", ".config") - ) - expect_equal(nrow(res_est), nrow(grid) * 2) - expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) - expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) - expect_equal(res_est$n, rep(10, nrow(grid) * 2)) - expect_false(identical(num_clusters, expr(tune()))) - expect_false(identical(num_comp, expr(tune()))) - expect_true(res_workflow$trained) -}) - -test_that('tune model and recipe (parallel_over = "everything")', { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod) - pset <- hardhat::extract_parameter_set_dials(wflow) %>% - update(num_comp = dials::num_comp(c(1, 3))) - grid <- dials::grid_regular(pset, levels = 3) - grid$num_clusters <- grid$num_clusters + 1 - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity, parallel_over = "everything") - metrics <- cluster_metric_set(tot_wss, tot_sse) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - res_est <- tune::collect_metrics(res) - - expect_equal(res$id, folds$id) - expect_equal( - colnames(res$.metrics[[1]]), - c("num_clusters", "num_comp", ".metric", ".estimator", ".estimate", ".config") - ) - expect_equal(nrow(res_est), nrow(grid) * 2) - expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) - expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) - expect_equal(res_est$n, rep(10, nrow(grid) * 2)) -}) - -test_that("tune model only - failure in formula is caught elegantly", { - helper_objects <- helper_objects_tidyclust() - - set.seed(7898) - data_folds <- rsample::vfold_cv(mtcars, v = 2) - - cars_grid <- tibble::tibble(num_clusters = 2) - - # these terms don't exist! - expect_snapshot( - cars_res <- tune_cluster( - helper_objects$kmeans_mod, - ~ z, - resamples = data_folds, - grid = cars_grid, - control = tune::control_grid(extract = function(x) { - 1 - }, save_pred = TRUE) - ) - ) - - notes <- cars_res$.notes - note <- notes[[1]]$note - - extracts <- cars_res$.extracts - predictions <- cars_res$.predictions - - expect_length(notes, 2L) - - # formula failed - no models run - expect_equal(extracts, list(NULL, NULL)) - expect_equal(predictions, list(NULL, NULL)) -}) - -test_that("argument order gives errors for recipes", { - helper_objects <- helper_objects_tidyclust() - - expect_snapshot(error = TRUE, { - tune_cluster( - helper_objects$rec_tune_1, - helper_objects$kmeans_mod_no_tune, - rsample::vfold_cv(mtcars, v = 2) - ) - }) -}) - -test_that("argument order gives errors for formula", { - helper_objects <- helper_objects_tidyclust() - - expect_snapshot(error = TRUE, { - tune_cluster( - mpg ~ ., - helper_objects$kmeans_mod_no_tune, - rsample::vfold_cv(mtcars, v = 2) - ) - }) -}) - -test_that("metrics can be NULL", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod_no_tune) - pset <- hardhat::extract_parameter_set_dials(wflow) %>% - update(num_comp = dials::num_comp(c(1, 3))) - grid <- dials::grid_regular(pset, levels = 3) - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss, tot_sse) - - set.seed(4400) - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control - ) - - set.seed(4400) - res1 <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - - expect_identical(res$.metrics, res1$.metrics) -}) - -test_that("tune recipe only", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod_no_tune) - pset <- hardhat::extract_parameter_set_dials(wflow) %>% - update(num_comp = dials::num_comp(c(1, 3))) - grid <- dials::grid_regular(pset, levels = 3) - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - res_est <- tune::collect_metrics(res) - res_workflow <- res$.extracts[[1]]$.extracts[[1]] - - # Ensure tunable parameters in recipe are finalized - num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp - - expect_equal(res$id, folds$id) - expect_equal(nrow(res_est), nrow(grid)) - expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) - expect_equal(res_est$n, rep(10, nrow(grid))) - expect_false(identical(num_comp, expr(tune()))) - expect_true(res_workflow$trained) -}) - -test_that("ellipses with tune_cluster", { - helper_objects <- helper_objects_tidyclust() - - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod_no_tune) - folds <- rsample::vfold_cv(mtcars) - expect_snapshot( - tune_cluster(wflow, resamples = folds, grid = 3, something = "wrong") - ) -}) - -test_that("determining the grid type", { - grid_1 <- expand.grid(a = 1:100, b = letters[1:2]) - expect_true(tune:::is_regular_grid(grid_1)) - expect_true(tune:::is_regular_grid(grid_1[-(1:10), ])) - expect_false(tune:::is_regular_grid(grid_1[-(1:100), ])) - set.seed(1932) - grid_2 <- data.frame(a = runif(length(letters)), b = letters) - expect_false(tune:::is_regular_grid(grid_2)) -}) - -test_that("retain extra attributes", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_no_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod) - pset <- hardhat::extract_parameter_set_dials(wflow) - grid <- dials::grid_regular(pset, levels = 3) - grid$num_clusters <- grid$num_clusters + 1 - folds <- rsample::vfold_cv(mtcars) - metrics <- cluster_metric_set(tot_wss, tot_sse) - res <- tune_cluster(wflow, resamples = folds, grid = grid, metrics = metrics) - - att <- attributes(res) - att_names <- names(att) - expect_true(any(att_names == "metrics")) - expect_true(any(att_names == "parameters")) - - expect_true(inherits(att$parameters, "parameters")) - expect_true(inherits(att$metrics, "cluster_metric_set")) -}) - -test_that("select_best() and show_best() works", { - helper_objects <- helper_objects_tidyclust() - - set.seed(4400) - wflow <- workflows::workflow() %>% - workflows::add_recipe(helper_objects$rec_no_tune_1) %>% - workflows::add_model(helper_objects$kmeans_mod) - pset <- hardhat::extract_parameter_set_dials(wflow) - grid <- dials::grid_regular(pset, levels = 10) - grid$num_clusters <- grid$num_clusters + 1 - folds <- rsample::vfold_cv(mtcars) - control <- tune::control_grid(extract = identity) - metrics <- cluster_metric_set(tot_wss, tot_sse) - - res <- tune_cluster( - wflow, - resamples = folds, - grid = grid, - control = control, - metrics = metrics - ) - - expect_snapshot(tmp <- tune::show_best(res)) - - expect_equal( - tune::show_best(res, metric = "tot_wss"), - tune::collect_metrics(res) %>% - dplyr::filter(.metric == "tot_wss") %>% - dplyr::slice_min(mean, n = 5, with_ties = FALSE) - ) - - expect_equal( - tune::show_best(res, metric = "tot_sse"), - tune::collect_metrics(res) %>% - dplyr::filter(.metric == "tot_sse") %>% - dplyr::slice_min(mean, n = 5, with_ties = FALSE) - ) - - expect_snapshot(tmp <- tune::select_best(res)) - - expect_equal( - tune::select_best(res, metric = "tot_wss"), - tune::collect_metrics(res) %>% - dplyr::filter(.metric == "tot_wss") %>% - dplyr::slice_min(mean, n = 1, with_ties = FALSE) %>% - dplyr::select(num_clusters, .config) - ) - - expect_equal( - tune::select_best(res, metric = "tot_sse"), - tune::collect_metrics(res) %>% - dplyr::filter(.metric == "tot_sse") %>% - dplyr::slice_min(mean, n = 1, with_ties = FALSE) %>% - dplyr::select(num_clusters, .config) - ) -}) - +test_that("tune recipe only", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod_no_tune) + pset <- hardhat::extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + grid <- dials::grid_regular(pset, levels = 3) + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss, tot_sse) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + res_est <- tune::collect_metrics(res) + res_workflow <- res$.extracts[[1]]$.extracts[[1]] + + # Ensure tunable parameters in recipe are finalized + num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp + + expect_equal(res$id, folds$id) + expect_equal(nrow(res_est), nrow(grid) * 2) + expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) + expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) + expect_equal(res_est$n, rep(10, nrow(grid) * 2)) + expect_false(identical(num_comp, expr(tune()))) + expect_true(res_workflow$trained) +}) + +test_that("tune model only (with recipe)", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_no_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod) + pset <- hardhat::extract_parameter_set_dials(wflow) + grid <- dials::grid_regular(pset, levels = 3) + grid$num_clusters <- grid$num_clusters + 1 + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss, tot_sse) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + res_est <- tune::collect_metrics(res) + res_workflow <- res$.extracts[[1]]$.extracts[[1]] + + # Ensure tunable parameters in spec are finalized + num_clusters_quo <- res_workflow$fit$fit$spec$args$num_clusters + num_clusters <- rlang::quo_get_expr(num_clusters_quo) + + expect_equal(res$id, folds$id) + expect_equal(nrow(res_est), nrow(grid) * 2) + expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) + expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) + expect_equal(res_est$n, rep(10, nrow(grid) * 2)) + expect_false(identical(num_clusters, expr(tune()))) + expect_true(res_workflow$trained) +}) + +test_that("tune model and recipe", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod) + pset <- hardhat::extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + grid <- dials::grid_regular(pset, levels = 3) + grid$num_clusters <- grid$num_clusters + 1 + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss, tot_sse) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + res_est <- tune::collect_metrics(res) + res_workflow <- res$.extracts[[1]]$.extracts[[1]] + + # Ensure tunable parameters in spec are finalized + num_clusters_quo <- res_workflow$fit$fit$spec$args$num_clusters + num_clusters <- rlang::quo_get_expr(num_clusters_quo) + + # Ensure tunable parameters in recipe are finalized + num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp + + expect_equal(res$id, folds$id) + expect_equal( + colnames(res$.metrics[[1]]), + c("num_clusters", "num_comp", ".metric", ".estimator", ".estimate", ".config") + ) + expect_equal(nrow(res_est), nrow(grid) * 2) + expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) + expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) + expect_equal(res_est$n, rep(10, nrow(grid) * 2)) + expect_false(identical(num_clusters, expr(tune()))) + expect_false(identical(num_comp, expr(tune()))) + expect_true(res_workflow$trained) +}) + +test_that('tune model and recipe (parallel_over = "everything")', { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod) + pset <- hardhat::extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + grid <- dials::grid_regular(pset, levels = 3) + grid$num_clusters <- grid$num_clusters + 1 + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity, parallel_over = "everything") + metrics <- cluster_metric_set(tot_wss, tot_sse) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + res_est <- tune::collect_metrics(res) + + expect_equal(res$id, folds$id) + expect_equal( + colnames(res$.metrics[[1]]), + c("num_clusters", "num_comp", ".metric", ".estimator", ".estimate", ".config") + ) + expect_equal(nrow(res_est), nrow(grid) * 2) + expect_equal(sum(res_est$.metric == "tot_sse"), nrow(grid)) + expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) + expect_equal(res_est$n, rep(10, nrow(grid) * 2)) +}) + +test_that("tune model only - failure in formula is caught elegantly", { + helper_objects <- helper_objects_tidyclust() + + set.seed(7898) + data_folds <- rsample::vfold_cv(mtcars, v = 2) + + cars_grid <- tibble::tibble(num_clusters = 2) + + # these terms don't exist! + expect_snapshot( + cars_res <- tune_cluster( + helper_objects$kmeans_mod, + ~ z, + resamples = data_folds, + grid = cars_grid, + control = tune::control_grid(extract = function(x) { + 1 + }, save_pred = TRUE) + ) + ) + + notes <- cars_res$.notes + note <- notes[[1]]$note + + extracts <- cars_res$.extracts + predictions <- cars_res$.predictions + + expect_length(notes, 2L) + + # formula failed - no models run + expect_equal(extracts, list(NULL, NULL)) + expect_equal(predictions, list(NULL, NULL)) +}) + +test_that("argument order gives errors for recipes", { + helper_objects <- helper_objects_tidyclust() + + expect_snapshot(error = TRUE, { + tune_cluster( + helper_objects$rec_tune_1, + helper_objects$kmeans_mod_no_tune, + rsample::vfold_cv(mtcars, v = 2) + ) + }) +}) + +test_that("argument order gives errors for formula", { + helper_objects <- helper_objects_tidyclust() + + expect_snapshot(error = TRUE, { + tune_cluster( + mpg ~ ., + helper_objects$kmeans_mod_no_tune, + rsample::vfold_cv(mtcars, v = 2) + ) + }) +}) + +test_that("metrics can be NULL", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod_no_tune) + pset <- hardhat::extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + grid <- dials::grid_regular(pset, levels = 3) + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss, tot_sse) + + set.seed(4400) + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control + ) + + set.seed(4400) + res1 <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + + expect_identical(res$.metrics, res1$.metrics) +}) + +test_that("tune recipe only", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod_no_tune) + pset <- hardhat::extract_parameter_set_dials(wflow) %>% + update(num_comp = dials::num_comp(c(1, 3))) + grid <- dials::grid_regular(pset, levels = 3) + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + res_est <- tune::collect_metrics(res) + res_workflow <- res$.extracts[[1]]$.extracts[[1]] + + # Ensure tunable parameters in recipe are finalized + num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp + + expect_equal(res$id, folds$id) + expect_equal(nrow(res_est), nrow(grid)) + expect_equal(sum(res_est$.metric == "tot_wss"), nrow(grid)) + expect_equal(res_est$n, rep(10, nrow(grid))) + expect_false(identical(num_comp, expr(tune()))) + expect_true(res_workflow$trained) +}) + +test_that("ellipses with tune_cluster", { + helper_objects <- helper_objects_tidyclust() + + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod_no_tune) + folds <- rsample::vfold_cv(mtcars) + expect_snapshot( + tune_cluster(wflow, resamples = folds, grid = 3, something = "wrong") + ) +}) + +test_that("determining the grid type", { + grid_1 <- expand.grid(a = 1:100, b = letters[1:2]) + expect_true(tune:::is_regular_grid(grid_1)) + expect_true(tune:::is_regular_grid(grid_1[-(1:10), ])) + expect_false(tune:::is_regular_grid(grid_1[-(1:100), ])) + set.seed(1932) + grid_2 <- data.frame(a = runif(length(letters)), b = letters) + expect_false(tune:::is_regular_grid(grid_2)) +}) + +test_that("retain extra attributes", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_no_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod) + pset <- hardhat::extract_parameter_set_dials(wflow) + grid <- dials::grid_regular(pset, levels = 3) + grid$num_clusters <- grid$num_clusters + 1 + folds <- rsample::vfold_cv(mtcars) + metrics <- cluster_metric_set(tot_wss, tot_sse) + res <- tune_cluster(wflow, resamples = folds, grid = grid, metrics = metrics) + + att <- attributes(res) + att_names <- names(att) + expect_true(any(att_names == "metrics")) + expect_true(any(att_names == "parameters")) + + expect_true(inherits(att$parameters, "parameters")) + expect_true(inherits(att$metrics, "cluster_metric_set")) +}) + +test_that("select_best() and show_best() works", { + helper_objects <- helper_objects_tidyclust() + + set.seed(4400) + wflow <- workflows::workflow() %>% + workflows::add_recipe(helper_objects$rec_no_tune_1) %>% + workflows::add_model(helper_objects$kmeans_mod) + pset <- hardhat::extract_parameter_set_dials(wflow) + grid <- dials::grid_regular(pset, levels = 10) + grid$num_clusters <- grid$num_clusters + 1 + folds <- rsample::vfold_cv(mtcars) + control <- tune::control_grid(extract = identity) + metrics <- cluster_metric_set(tot_wss, tot_sse) + + res <- tune_cluster( + wflow, + resamples = folds, + grid = grid, + control = control, + metrics = metrics + ) + + expect_snapshot(tmp <- tune::show_best(res)) + + expect_equal( + tune::show_best(res, metric = "tot_wss"), + tune::collect_metrics(res) %>% + dplyr::filter(.metric == "tot_wss") %>% + dplyr::slice_min(mean, n = 5, with_ties = FALSE) + ) + + expect_equal( + tune::show_best(res, metric = "tot_sse"), + tune::collect_metrics(res) %>% + dplyr::filter(.metric == "tot_sse") %>% + dplyr::slice_min(mean, n = 5, with_ties = FALSE) + ) + + expect_snapshot(tmp <- tune::select_best(res)) + + expect_equal( + tune::select_best(res, metric = "tot_wss"), + tune::collect_metrics(res) %>% + dplyr::filter(.metric == "tot_wss") %>% + dplyr::slice_min(mean, n = 1, with_ties = FALSE) %>% + dplyr::select(num_clusters, .config) + ) + + expect_equal( + tune::select_best(res, metric = "tot_sse"), + tune::collect_metrics(res) %>% + dplyr::filter(.metric == "tot_sse") %>% + dplyr::slice_min(mean, n = 1, with_ties = FALSE) %>% + dplyr::select(num_clusters, .config) + ) +}) +