From c9edff9a422427d7783e0e83bbbe346db9fec50f Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 27 Jan 2025 11:42:58 -0800 Subject: [PATCH] air whole package --- R/aaa.R | 46 ++- R/append.R | 54 ++-- R/arguments.R | 13 +- R/cluster_spec.R | 9 +- R/compat-purrr.R | 15 +- R/convert_data.R | 30 +- R/dials-params.R | 8 +- R/engine_docs.R | 10 +- R/engines.R | 19 +- R/extract_cluster_assignment.R | 18 +- R/extract_fit_summary.R | 25 +- R/extract_parameter_set_dials.R | 2 +- R/fit.R | 102 ++++--- R/fit_helpers.R | 6 +- R/hier_clust.R | 42 +-- R/hier_clust_data.R | 9 +- R/k_means.R | 50 ++-- R/k_means_data.R | 38 ++- R/load_ns.R | 15 +- R/metric-aaa.R | 30 +- R/metric-helpers.R | 20 +- R/metric-silhouette.R | 38 ++- R/metric-sse.R | 85 ++++-- R/misc.R | 14 +- R/model_object_docs.R | 1 - R/predict.R | 30 +- R/predict_helpers.R | 38 ++- R/print.R | 3 +- R/pull.R | 10 +- R/reconcile_clusterings.R | 10 +- R/required_pkgs.R | 2 +- R/translate.R | 18 +- R/tunable.R | 26 +- R/tune_args.R | 16 +- R/tune_cluster.R | 275 ++++++++++-------- R/tune_helpers.R | 100 ++++--- dev/cross_val_kmeans.R | 17 +- dev/test_hc.R | 6 +- tests/testthat/helper-tidyclust-package.R | 12 +- tests/testthat/test-augment.R | 28 +- tests/testthat/test-cluster_metric_set.R | 14 +- tests/testthat/test-extract_centroids.R | 2 +- .../test-extract_cluster_assignment.R | 2 +- tests/testthat/test-extract_fit_summary.R | 2 +- tests/testthat/test-hier_clust-stats.R | 15 +- tests/testthat/test-k_means-clustMixType.R | 1 - tests/testthat/test-k_means-clusterR.R | 8 +- tests/testthat/test-k_means-klaR.R | 14 +- tests/testthat/test-k_means.R | 6 +- tests/testthat/test-k_means_diagnostics.R | 9 +- tests/testthat/test-predict.R | 2 +- tests/testthat/test-predict_formats.R | 4 +- tests/testthat/test-reconcile_clusterings.R | 16 +- tests/testthat/test-tune_cluster.R | 21 +- tests/testthat/test-workflows.R | 4 +- 55 files changed, 855 insertions(+), 555 deletions(-) diff --git a/R/aaa.R b/R/aaa.R index 6865a2da..db456411 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -3,17 +3,47 @@ utils::globalVariables( c( - ".", "..object", ".cluster", ".iter_config", ".iter_model", - ".iter_preprocessor", ".msg_model", ".submodels", "call_info", "cluster", - "component", "component_id", "compute_intercept", "data", "dist", "engine", - "engine2", "exposed", "func", "id", "iteration", "lab", "name", "neighbor", - "new_data", "object", "orig_label", "original", "predictor_indicators", - "remove_intercept", "seed", "sil_width", "splits", "tunable", "type", - "value", "x", "y" + ".", + "..object", + ".cluster", + ".iter_config", + ".iter_model", + ".iter_preprocessor", + ".msg_model", + ".submodels", + "call_info", + "cluster", + "component", + "component_id", + "compute_intercept", + "data", + "dist", + "engine", + "engine2", + "exposed", + "func", + "id", + "iteration", + "lab", + "name", + "neighbor", + "new_data", + "object", + "orig_label", + "original", + "predictor_indicators", + "remove_intercept", + "seed", + "sil_width", + "splits", + "tunable", + "type", + "value", + "x", + "y" ) ) - release_bullets <- function() { c( "Run `knit_engine_docs()` and `devtools::document()` to update docs" diff --git a/R/append.R b/R/append.R index 7ca6d68f..b9519ffc 100644 --- a/R/append.R +++ b/R/append.R @@ -1,9 +1,11 @@ # https://github.com/tidymodels/tune/blob/main/R/pull.R#L136 -append_predictions <- function(collection, - predictions, - split, - control, - .config = NULL) { +append_predictions <- function( + collection, + predictions, + split, + control, + .config = NULL +) { if (!control$save_pred) { return(NULL) } @@ -27,14 +29,16 @@ append_predictions <- function(collection, dplyr::bind_rows(collection, predictions) } -append_metrics <- function(workflow, - collection, - predictions, - metrics, - param_names, - event_level, - split, - .config = NULL) { +append_metrics <- function( + workflow, + collection, + predictions, + metrics, + param_names, + event_level, + split, + .config = NULL +) { if (inherits(predictions, "try-error")) { return(collection) } @@ -54,20 +58,22 @@ append_metrics <- function(workflow, dplyr::bind_rows(collection, tmp_est) } -append_extracts <- function(collection, - workflow, - grid, - split, - ctrl, - .config = NULL) { +append_extracts <- function( + collection, + workflow, + grid, + split, + ctrl, + .config = NULL +) { extracts <- grid %>% - dplyr::bind_cols(labels(split)) %>% - dplyr::mutate( - .extracts = list( - extract_details(workflow, ctrl$extract) + dplyr::bind_cols(labels(split)) %>% + dplyr::mutate( + .extracts = list( + extract_details(workflow, ctrl$extract) + ) ) - ) if (!rlang::is_null(.config)) { extracts <- cbind(extracts, .config) diff --git a/R/arguments.R b/R/arguments.R index eb44bebd..3ffafc9f 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -8,10 +8,12 @@ check_eng_args <- function(args, obj, core_args) { if (length(common_args) > 0) { args <- args[!(names(args) %in% common_args)] common_args <- paste0(common_args, collapse = ", ") - rlang::warn(glue::glue( - "The following arguments cannot be manually modified ", - "and were removed: {common_args}." - )) + rlang::warn( + glue::glue( + "The following arguments cannot be manually modified ", + "and were removed: {common_args}." + ) + ) } args } @@ -25,7 +27,8 @@ make_x_call <- function(object, target) { } object$method$fit$args[[unname(data_args["x"])]] <- - switch(target, + switch( + target, none = rlang::expr(x), data.frame = rlang::expr(maybe_data_frame(x)), matrix = rlang::expr(maybe_matrix(x)), diff --git a/R/cluster_spec.R b/R/cluster_spec.R index 6fe8fa0f..cb718d44 100644 --- a/R/cluster_spec.R +++ b/R/cluster_spec.R @@ -11,15 +11,18 @@ #' @keywords internal new_cluster_spec <- function(cls, args, eng_args, mode, method, engine) { modelenv::check_spec_mode_engine_val( - model = cls, + model = cls, mode = mode, eng = engine, call = rlang::caller_env() ) out <- list( - args = args, eng_args = eng_args, - mode = mode, method = method, engine = engine + args = args, + eng_args = eng_args, + mode = mode, + method = method, + engine = engine ) class(out) <- make_classes_tidyclust(cls) out <- modelenv::new_unsupervised_spec(out) diff --git a/R/compat-purrr.R b/R/compat-purrr.R index e60efc8c..1f042c69 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -79,11 +79,16 @@ imap <- function(.x, .f, ...) { pmap <- function(.l, .f, ...) { .f <- as.function(.f) args <- .rlang_purrr_args_recycle(.l) - do.call("mapply", c( - FUN = list(quote(.f)), - args, MoreArgs = quote(list(...)), - SIMPLIFY = FALSE, USE.NAMES = FALSE - )) + do.call( + "mapply", + c( + FUN = list(quote(.f)), + args, + MoreArgs = quote(list(...)), + SIMPLIFY = FALSE, + USE.NAMES = FALSE + ) + ) } .rlang_purrr_args_recycle <- function(args) { lengths <- map_int(args, length) diff --git a/R/convert_data.R b/R/convert_data.R index 96e42aa5..30e7a009 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -32,13 +32,15 @@ #' @inheritParams fit.cluster_spec #' @rdname convert_helpers #' @keywords internal -.convert_form_to_x_fit <- function(formula, - data, - ..., - na.action = na.omit, - indicators = "traditional", - composition = "data.frame", - remove_intercept = TRUE) { +.convert_form_to_x_fit <- function( + formula, + data, + ..., + na.action = na.omit, + indicators = "traditional", + composition = "data.frame", + remove_intercept = TRUE +) { if (!(composition %in% c("data.frame", "matrix"))) { rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") } @@ -155,9 +157,7 @@ local_one_hot_contrasts <- function(frame = rlang::caller_env()) { #' @inheritParams .convert_form_to_x_fit #' @rdname convert_helpers #' @keywords internal -.convert_x_to_form_fit <- function(x, - weights = NULL, - remove_intercept = TRUE) { +.convert_x_to_form_fit <- function(x, weights = NULL, remove_intercept = TRUE) { if (is.vector(x)) { rlang::abort("`x` cannot be a vector.") } @@ -212,10 +212,12 @@ make_formula <- function(x, short = TRUE) { #' @inheritParams predict.cluster_fit #' @rdname convert_helpers #' @keywords internal -.convert_form_to_x_new <- function(object, - new_data, - na.action = stats::na.pass, - composition = "data.frame") { +.convert_form_to_x_new <- function( + object, + new_data, + na.action = stats::na.pass, + composition = "data.frame" +) { if (!(composition %in% c("data.frame", "matrix"))) { rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") } diff --git a/R/dials-params.R b/R/dials-params.R index 0b569eda..dc39702d 100644 --- a/R/dials-params.R +++ b/R/dials-params.R @@ -40,6 +40,12 @@ linkage_method <- function(values = values_linkage_method) { #' @rdname linkage_method #' @export values_linkage_method <- c( - "ward.D", "ward.D2", "single", "complete", "average", "mcquitty", "median", + "ward.D", + "ward.D2", + "single", + "complete", + "average", + "mcquitty", + "median", "centroid" ) diff --git a/R/engine_docs.R b/R/engine_docs.R index 0a8d37d6..e74164fa 100644 --- a/R/engine_docs.R +++ b/R/engine_docs.R @@ -19,17 +19,17 @@ knit_engine_docs <- function(pattern = NULL) { } outputs <- gsub("Rmd$", "md", files) - res <- map2(files, outputs, ~ try(knitr::knit(.x, .y), silent = TRUE)) - is_error <- map_lgl(res, ~ inherits(.x, "try-error")) + res <- map2(files, outputs, ~try(knitr::knit(.x, .y), silent = TRUE)) + is_error <- map_lgl(res, ~inherits(.x, "try-error")) if (any(is_error)) { # In some cases where there are issues, the md file is empty. errors <- res[which(is_error)] error_nms <- basename(files)[which(is_error)] errors <- - map_chr(errors, ~ cli::ansi_strip(as.character(.x))) %>% - map2_chr(error_nms, ~ paste0(.y, ": ", .x)) %>% - map_chr(~ gsub("Error in .f(.x[[i]], ...) :", "", .x, fixed = TRUE)) + map_chr(errors, ~cli::ansi_strip(as.character(.x))) %>% + map2_chr(error_nms, ~paste0(.y, ": ", .x)) %>% + map_chr(~gsub("Error in .f(.x[[i]], ...) :", "", .x, fixed = TRUE)) cat("There were failures duing knitting:\n\n") cat(errors) cat("\n\n") diff --git a/R/engines.R b/R/engines.R index 3a1004d6..e464f61b 100644 --- a/R/engines.R +++ b/R/engines.R @@ -31,15 +31,16 @@ set_engine.cluster_spec <- function(object, engine, ...) { stop_missing_engine <- function(cls, call = rlang::caller_env()) { info <- modelenv::get_from_env(cls) %>% - dplyr::group_by(mode) %>% - dplyr::summarize( - msg = paste0( - unique(mode), " {", - paste0(unique(engine), collapse = ", "), - "}" - ), - .groups = "drop" - ) + dplyr::group_by(mode) %>% + dplyr::summarize( + msg = paste0( + unique(mode), + " {", + paste0(unique(engine), collapse = ", "), + "}" + ), + .groups = "drop" + ) if (nrow(info) == 0) { rlang::abort(glue::glue("No known engines for `{cls}()`."), call = call) } diff --git a/R/extract_cluster_assignment.R b/R/extract_cluster_assignment.R index af18ce35..c975d22e 100644 --- a/R/extract_cluster_assignment.R +++ b/R/extract_cluster_assignment.R @@ -111,9 +111,11 @@ extract_cluster_assignment.kmodes <- function(object, ...) { } #' @export -extract_cluster_assignment.hclust <- function(object, - ..., - call = rlang::caller_env(0)) { +extract_cluster_assignment.hclust <- function( + object, + ..., + call = rlang::caller_env(0) +) { # if k or h is passed in the dots, use those. Otherwise, use attributes # from original model specification args <- list(...) @@ -159,10 +161,12 @@ extract_cluster_assignment.hclust <- function(object, # ------------------------------------------------------------------------------ -cluster_assignment_tibble <- function(clusters, - n_clusters, - ..., - prefix = "Cluster_") { +cluster_assignment_tibble <- function( + clusters, + n_clusters, + ..., + prefix = "Cluster_" +) { reorder_clusts <- order(union(unique(clusters), seq_len(n_clusters))) names <- paste0(prefix, seq_len(n_clusters)) res <- names[reorder_clusts][clusters] diff --git a/R/extract_fit_summary.R b/R/extract_fit_summary.R index ce6526a5..d77bbb9c 100644 --- a/R/extract_fit_summary.R +++ b/R/extract_fit_summary.R @@ -23,8 +23,11 @@ extract_fit_summary <- function(object, ...) { } #' @export -extract_fit_summary.cluster_spec <- function(object, ..., - call = rlang::caller_env(n = 0)) { +extract_fit_summary.cluster_spec <- function( + object, + ..., + call = rlang::caller_env(n = 0) +) { rlang::abort( paste( "This function requires a fitted model.", @@ -68,9 +71,11 @@ extract_fit_summary.kmeans <- function(object, ..., prefix = "Cluster_") { } #' @export -extract_fit_summary.KMeansCluster <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.KMeansCluster <- function( + object, + ..., + prefix = "Cluster_" +) { names <- paste0(prefix, seq_len(nrow(object$centroids))) names <- factor(names) @@ -93,9 +98,7 @@ extract_fit_summary.KMeansCluster <- function(object, } #' @export -extract_fit_summary.kproto <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.kproto <- function(object, ..., prefix = "Cluster_") { names <- paste0(prefix, seq_len(nrow(object$centers))) names <- factor(names) @@ -118,9 +121,7 @@ extract_fit_summary.kproto <- function(object, } #' @export -extract_fit_summary.kmodes <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.kmodes <- function(object, ..., prefix = "Cluster_") { names <- paste0(prefix, seq_len(nrow(object$modes))) names <- factor(names) @@ -166,7 +167,7 @@ extract_fit_summary.hclust <- function(object, ...) { sse_within_total_total <- map2_dbl( by_clust$data, seq_len(n_clust), - ~ sum(Rfast::dista(centroids[.y, ], .x)) + ~sum(Rfast::dista(centroids[.y, ], .x)) ) list( diff --git a/R/extract_parameter_set_dials.R b/R/extract_parameter_set_dials.R index a3241b6b..620f8921 100644 --- a/R/extract_parameter_set_dials.R +++ b/R/extract_parameter_set_dials.R @@ -10,7 +10,7 @@ extract_parameter_set_dials.cluster_spec <- function(x, ...) { all_args, by = c("name", "source", "component") ) %>% - dplyr::mutate(object = map(call_info, eval_call_info)) + dplyr::mutate(object = map(call_info, eval_call_info)) dials::parameters_constr( res$name, diff --git a/R/fit.R b/R/fit.R index 51d55957..f5a6cfae 100644 --- a/R/fit.R +++ b/R/fit.R @@ -85,11 +85,13 @@ #' @return A fitted [`cluster_fit`] object. #' @export #' @export fit.cluster_spec -fit.cluster_spec <- function(object, - formula, - data, - control = control_cluster(), - ...) { +fit.cluster_spec <- function( + object, + formula, + data, + control = control_cluster(), + ... +) { if (object$mode == "unknown") { rlang::abort("Please set the mode in the model specification.") } @@ -133,32 +135,30 @@ fit.cluster_spec <- function(object, # used here, `fit_interface_formula` will determine if a # translation has to be made if the model interface is x/y/ res <- - switch(interfaces, + switch( + interfaces, # homogeneous combinations: - formula_formula = - form_form( - object = object, - control = control, - env = eval_env - ), + formula_formula = form_form( + object = object, + control = control, + env = eval_env + ), # heterogenous combinations - formula_matrix = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - formula_data.frame = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), + formula_matrix = form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + formula_data.frame = form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), rlang::abort(glue::glue("{interfaces} is unknown.")) ) model_classes <- class(res$fit) @@ -270,36 +270,34 @@ fit_xy.cluster_spec <- # used here, `fit_interface_formula` will determine if a # translation has to be made if the model interface is x/y/ res <- - switch(interfaces, + switch( + interfaces, # homogeneous combinations: matrix_matrix = , - data.frame_matrix = - x_x( - object = object, - env = eval_env, - control = control, - target = "matrix", - ... - ), + data.frame_matrix = x_x( + object = object, + env = eval_env, + control = control, + target = "matrix", + ... + ), data.frame_data.frame = , - matrix_data.frame = - x_x( - object = object, - env = eval_env, - control = control, - target = "data.frame", - ... - ), + matrix_data.frame = x_x( + object = object, + env = eval_env, + control = control, + target = "data.frame", + ... + ), # heterogenous combinations matrix_formula = , - data.frame_formula = - x_form( - object = object, - env = eval_env, - control = control, - ... - ), + data.frame_formula = x_form( + object = object, + env = eval_env, + control = control, + ... + ), rlang::abort(glue::glue("{interfaces} is unknown.")) ) model_classes <- class(res$fit) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 136a1bf1..b2141571 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -41,7 +41,7 @@ form_form <- function(object, control, env, ...) { form_x <- function(object, control, env, target = "none", ...) { encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) indicators <- encoding_info %>% dplyr::pull(predictor_indicators) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) @@ -74,7 +74,7 @@ x_x <- function(object, env, control, target = "none", y = NULL, ...) { } encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) if (remove_intercept) { @@ -120,7 +120,7 @@ x_x <- function(object, env, control, target = "none", y = NULL, ...) { x_form <- function(object, env, control, ...) { encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) diff --git a/R/hier_clust.R b/R/hier_clust.R index 1d5e898c..ed475342 100644 --- a/R/hier_clust.R +++ b/R/hier_clust.R @@ -56,11 +56,13 @@ #' hier_clust() #' @export hier_clust <- - function(mode = "partition", - engine = "stats", - num_clusters = NULL, - cut_height = NULL, - linkage_method = "complete") { + function( + mode = "partition", + engine = "stats", + num_clusters = NULL, + cut_height = NULL, + linkage_method = "complete" + ) { args <- list( num_clusters = enquo(num_clusters), cut_height = enquo(cut_height), @@ -95,15 +97,19 @@ print.hier_clust <- function(x, ...) { #' @method update hier_clust #' @rdname tidyclust_update #' @export -update.hier_clust <- function(object, - parameters = NULL, - num_clusters = NULL, - cut_height = NULL, - linkage_method = NULL, - fresh = FALSE, ...) { +update.hier_clust <- function( + object, + parameters = NULL, + num_clusters = NULL, + cut_height = NULL, + linkage_method = NULL, + fresh = FALSE, + ... +) { eng_args <- parsnip::update_engine_parameters( object$eng_args, - fresh = fresh, ... + fresh = fresh, + ... ) if (!is.null(parameters)) { @@ -182,11 +188,13 @@ translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) { #' @return A dendrogram #' @keywords internal #' @export -.hier_clust_fit_stats <- function(x, - num_clusters = NULL, - cut_height = NULL, - linkage_method = NULL, - dist_fun = Rfast::Dist) { +.hier_clust_fit_stats <- function( + x, + num_clusters = NULL, + cut_height = NULL, + linkage_method = NULL, + dist_fun = Rfast::Dist +) { dmat <- dist_fun(x) res <- stats::hclust(stats::as.dist(dmat), method = linkage_method) attr(res, "num_clusters") <- num_clusters diff --git a/R/hier_clust_data.R b/R/hier_clust_data.R index 695b7ffe..2c9010a8 100644 --- a/R/hier_clust_data.R +++ b/R/hier_clust_data.R @@ -81,11 +81,10 @@ make_hier_clust <- function() { pre = NULL, post = NULL, func = c(fun = ".hier_clust_predict_stats"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) } diff --git a/R/k_means.R b/R/k_means.R index 3595827b..b79d3933 100644 --- a/R/k_means.R +++ b/R/k_means.R @@ -39,9 +39,7 @@ #' k_means() #' @export k_means <- - function(mode = "partition", - engine = "stats", - num_clusters = NULL) { + function(mode = "partition", engine = "stats", num_clusters = NULL) { args <- list( num_clusters = enquo(num_clusters) ) @@ -80,13 +78,17 @@ translate_tidyclust.k_means <- function(x, engine = x$engine, ...) { #' @method update k_means #' @rdname tidyclust_update #' @export -update.k_means <- function(object, - parameters = NULL, - num_clusters = NULL, - fresh = FALSE, ...) { +update.k_means <- function( + object, + parameters = NULL, + num_clusters = NULL, + fresh = FALSE, + ... +) { eng_args <- parsnip::update_engine_parameters( object$eng_args, - fresh = fresh, ... + fresh = fresh, + ... ) if (!is.null(parameters)) { @@ -170,17 +172,19 @@ check_args.k_means <- function(object) { #' obs_per_cluster, between.SS_DIV_total.SS #' @keywords internal #' @export -.k_means_fit_ClusterR <- function(data, - clusters, - num_init = 1, - max_iters = 100, - initializer = "kmeans++", - fuzzy = FALSE, - verbose = FALSE, - CENTROIDS = NULL, - tol = 1e-04, - tol_optimal_init = 0.3, - seed = 1) { +.k_means_fit_ClusterR <- function( + data, + clusters, + num_init = 1, + max_iters = 100, + initializer = "kmeans++", + fuzzy = FALSE, + verbose = FALSE, + CENTROIDS = NULL, + tol = 1e-04, + tol_optimal_init = 0.3, + seed = 1 +) { if (is.null(clusters)) { rlang::abort( "Please specify `num_clust` to be able to fit specification.", @@ -259,8 +263,8 @@ check_args.k_means <- function(object) { c( "Engine `clustMixType` requires both numeric and categorical \\ predictors.", - "x" = "Only numeric predictors where used.", - "i" = "Try using the `stats` engine with \\ + "x" = "Only numeric predictors where used.", + "i" = "Try using the `stats` engine with \\ {.code mod %>% set_engine(\"stats\")}." ), call = call("fit") @@ -271,8 +275,8 @@ check_args.k_means <- function(object) { c( "Engine `clustMixType` requires both numeric and categorical \\ predictors.", - "x" = "Only categorical predictors where used.", - "i" = "Try using the `klaR` engine with \\ + "x" = "Only categorical predictors where used.", + "i" = "Try using the `klaR` engine with \\ {.code mod %>% set_engine(\"klaR\")}." ), call = call("fit") diff --git a/R/k_means_data.R b/R/k_means_data.R index 73f6be6d..09eae346 100644 --- a/R/k_means_data.R +++ b/R/k_means_data.R @@ -64,11 +64,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_stats"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) @@ -131,15 +130,14 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_ClusterR"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) -# ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- modelenv::set_model_engine("k_means", "partition", "clustMixType") modelenv::set_dependency( @@ -198,11 +196,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_clustMixType"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) @@ -265,11 +262,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_klaR"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) } diff --git a/R/load_ns.R b/R/load_ns.R index 715053df..4439d4b3 100644 --- a/R/load_ns.R +++ b/R/load_ns.R @@ -37,6 +37,17 @@ load_namespace <- function(x) { } infra_pkgs <- c( - "tune", "recipes", "tidyclust", "yardstick", "purrr", "dplyr", "tibble", - "dials", "rsample", "workflows", "tidyr", "rlang", "vctrs" + "tune", + "recipes", + "tidyclust", + "yardstick", + "purrr", + "dplyr", + "tibble", + "dials", + "rsample", + "workflows", + "tidyr", + "rlang", + "vctrs" ) diff --git a/R/metric-aaa.R b/R/metric-aaa.R index 9adfea01..d32a97bc 100644 --- a/R/metric-aaa.R +++ b/R/metric-aaa.R @@ -61,10 +61,12 @@ cluster_metric_set <- function(...) { if (fn_cls == "cluster_metric") { make_cluster_metric_function(fns) } else { - rlang::abort(paste0( - "Internal error: `validate_function_class()` should have ", - "errored on unknown classes." - )) + rlang::abort( + paste0( + "Internal error: `validate_function_class()` should have ", + "errored on unknown classes." + ) + ) } } @@ -183,8 +185,11 @@ make_cluster_metric_function <- function(fns) { ) calls <- lapply(fns, rlang::call2, !!!call_args) metric_list <- mapply( - FUN = eval_safely, calls, names(calls), - SIMPLIFY = FALSE, USE.NAMES = FALSE + FUN = eval_safely, + calls, + names(calls), + SIMPLIFY = FALSE, + USE.NAMES = FALSE ) dplyr::bind_rows(metric_list) } @@ -197,11 +202,14 @@ make_cluster_metric_function <- function(fns) { } eval_safely <- function(expr, expr_nm, data = NULL, env = rlang::caller_env()) { - tryCatch(expr = { - rlang::eval_tidy(expr, data = data, env = env) - }, error = function(e) { - rlang::abort(paste0("In metric: `", expr_nm, "`\n", conditionMessage(e))) - }) + tryCatch( + expr = { + rlang::eval_tidy(expr, data = data, env = env) + }, + error = function(e) { + rlang::abort(paste0("In metric: `", expr_nm, "`\n", conditionMessage(e))) + } + ) } #' @export diff --git a/R/metric-helpers.R b/R/metric-helpers.R index 57c0d139..f7304e7a 100644 --- a/R/metric-helpers.R +++ b/R/metric-helpers.R @@ -8,8 +8,12 @@ #' @param dist_fun A custom distance functions. #' #' @return A list -prep_data_dist <- function(object, new_data = NULL, - dists = NULL, dist_fun = Rfast::Dist) { +prep_data_dist <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = Rfast::Dist +) { # Sihouettes requires a distance matrix if (is.null(new_data) && is.null(dists)) { rlang::abort( @@ -45,11 +49,13 @@ prep_data_dist <- function(object, new_data = NULL, dists <- dist_fun(new_data) } - return(list( - clusters = clusters, - data = new_data, - dists = dists - )) + return( + list( + clusters = clusters, + data = new_data, + dists = dists + ) + ) } #' Computes distance from observations to centroids diff --git a/R/metric-silhouette.R b/R/metric-silhouette.R index 2e823a50..2783d92e 100644 --- a/R/metric-silhouette.R +++ b/R/metric-silhouette.R @@ -23,8 +23,12 @@ #' #' silhouette(kmeans_fit, dists = dists) #' @export -silhouette <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist) { +silhouette <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = Rfast::Dist +) { if (inherits(object, "cluster_spec")) { rlang::abort( paste( @@ -43,7 +47,8 @@ silhouette <- function(object, new_data = NULL, dists = NULL, if (!inherits(sil, "silhouette")) { res <- tibble::tibble( cluster = preproc$clusters, - neighbor = factor(rep(NA_character_, length(preproc$clusters)), + neighbor = factor( + rep(NA_character_, length(preproc$clusters)), levels = levels(preproc$clusters) ), sil_width = NA_real_ @@ -113,8 +118,13 @@ silhouette_avg.cluster_spec <- function(object, ...) { #' @export #' @rdname silhouette_avg -silhouette_avg.cluster_fit <- function(object, new_data = NULL, dists = NULL, - dist_fun = NULL, ...) { +silhouette_avg.cluster_fit <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { dist_fun <- Rfast::Dist } @@ -134,12 +144,22 @@ silhouette_avg.workflow <- silhouette_avg.cluster_fit #' @export #' @rdname silhouette_avg -silhouette_avg_vec <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { +silhouette_avg_vec <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = Rfast::Dist, + ... +) { silhouette_avg_impl(object, new_data, dists, dist_fun, ...) } -silhouette_avg_impl <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { +silhouette_avg_impl <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = Rfast::Dist, + ... +) { mean(silhouette(object, new_data, dists, dist_fun, ...)$sil_width) } diff --git a/R/metric-sse.R b/R/metric-sse.R index 84aa0ebb..dd49c176 100644 --- a/R/metric-sse.R +++ b/R/metric-sse.R @@ -47,10 +47,12 @@ sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) { res <- dist_to_centroids %>% tibble::as_tibble(.name_repair = "minimal") %>% - map(~ c( - .cluster = which.min(.x), - dist = min(.x)^2 - )) %>% + map( + ~c( + .cluster = which.min(.x), + dist = min(.x)^2 + ) + ) %>% dplyr::bind_rows() %>% dplyr::mutate( .cluster = factor(paste0("Cluster_", .cluster)) @@ -112,8 +114,12 @@ sse_within_total.cluster_spec <- function(object, ...) { #' @export #' @rdname sse_within_total -sse_within_total.cluster_fit <- function(object, new_data = NULL, - dist_fun = NULL, ...) { +sse_within_total.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { dist_fun <- Rfast::dista } @@ -133,13 +139,21 @@ sse_within_total.workflow <- sse_within_total.cluster_fit #' @export #' @rdname sse_within_total -sse_within_total_vec <- function(object, new_data = NULL, - dist_fun = Rfast::dista, ...) { +sse_within_total_vec <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { sse_within_total_impl(object, new_data, dist_fun, ...) } -sse_within_total_impl <- function(object, new_data = NULL, - dist_fun = Rfast::dista, ...) { +sse_within_total_impl <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { sum(sse_within(object, new_data, dist_fun, ...)$wss, na.rm = TRUE) } @@ -187,8 +201,12 @@ sse_total.cluster_spec <- function(object, ...) { #' @export #' @rdname sse_total -sse_total.cluster_fit <- function(object, new_data = NULL, dist_fun = NULL, - ...) { +sse_total.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { dist_fun <- Rfast::dista } @@ -208,12 +226,21 @@ sse_total.workflow <- sse_total.cluster_fit #' @export #' @rdname sse_total -sse_total_vec <- function(object, new_data = NULL, dist_fun = Rfast::dista, ...) { +sse_total_vec <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { sse_total_impl(object, new_data, dist_fun, ...) } -sse_total_impl <- function(object, new_data = NULL, dist_fun = Rfast::dista, - ...) { +sse_total_impl <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { # Preprocess data before computing distances if appropriate if (inherits(object, "workflow") && !is.null(new_data)) { new_data <- extract_post_preprocessor(object, new_data) @@ -276,8 +303,12 @@ sse_ratio.cluster_spec <- function(object, ...) { #' @export #' @rdname sse_ratio -sse_ratio.cluster_fit <- function(object, new_data = NULL, - dist_fun = NULL, ...) { +sse_ratio.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { dist_fun <- Rfast::dista } @@ -296,17 +327,21 @@ sse_ratio.workflow <- sse_ratio.cluster_fit #' @export #' @rdname sse_ratio -sse_ratio_vec <- function(object, - new_data = NULL, - dist_fun = Rfast::dista, - ...) { +sse_ratio_vec <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { sse_ratio_impl(object, new_data, dist_fun, ...) } -sse_ratio_impl <- function(object, - new_data = NULL, - dist_fun = Rfast::dista, - ...) { +sse_ratio_impl <- function( + object, + new_data = NULL, + dist_fun = Rfast::dista, + ... +) { sse_within_total_vec(object, new_data, dist_fun) / sse_total_vec(object, new_data, dist_fun) } diff --git a/R/misc.R b/R/misc.R index a79caffc..d2598ead 100644 --- a/R/misc.R +++ b/R/misc.R @@ -12,13 +12,15 @@ check_args.default <- function(object) { check_spec_pred_type <- function(object, type) { if (!spec_has_pred_type(object, type)) { possible_preds <- names(object$spec$method$pred) - rlang::abort(c( - glue::glue("No {type} prediction method available for this model."), - glue::glue( - "Value for `type` should be one of: ", - glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", ") + rlang::abort( + c( + glue::glue("No {type} prediction method available for this model."), + glue::glue( + "Value for `type` should be one of: ", + glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", ") + ) ) - )) + ) } invisible(NULL) } diff --git a/R/model_object_docs.R b/R/model_object_docs.R index c851467a..8f7d973d 100644 --- a/R/model_object_docs.R +++ b/R/model_object_docs.R @@ -185,4 +185,3 @@ NULL #' @rdname cluster_fit #' @name cluster_fit NULL - diff --git a/R/predict.R b/R/predict.R index 0451a3e8..6b78028b 100644 --- a/R/predict.R +++ b/R/predict.R @@ -88,11 +88,13 @@ #' @method predict cluster_fit #' @export predict.cluster_fit #' @export -predict.cluster_fit <- function(object, - new_data, - type = NULL, - opts = list(), - ...) { +predict.cluster_fit <- function( + object, + new_data, + type = NULL, + opts = list(), + ... +) { if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") return(NULL) @@ -103,23 +105,22 @@ predict.cluster_fit <- function(object, type <- check_pred_type(object, type) - res <- switch(type, + res <- switch( + type, cluster = predict_cluster(object = object, new_data = new_data, ...), raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), rlang::abort(glue::glue("I don't know about type = '{type}'")) ) - res <- switch(type, - cluster = format_cluster(res), - res - ) + res <- switch(type, cluster = format_cluster(res), res) res } check_pred_type <- function(object, type, ...) { if (is.null(type)) { type <- - switch(object$spec$mode, + switch( + object$spec$mode, partition = "cluster", rlang::abort("`type` should be 'cluster'.") ) @@ -154,13 +155,14 @@ prepare_data <- function(object, new_data) { remove_intercept <- modelenv::get_encoding(class(object$spec)[1]) %>% - dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% - dplyr::pull(remove_intercept) + dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% + dplyr::pull(remove_intercept) if (remove_intercept && any(grepl("Intercept", names(new_data)))) { new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) } - switch(fit_interface, + switch( + fit_interface, none = new_data, data.frame = as.data.frame(new_data), matrix = as.matrix(new_data), diff --git a/R/predict_helpers.R b/R/predict_helpers.R index f4780ae4..cef573c3 100644 --- a/R/predict_helpers.R +++ b/R/predict_helpers.R @@ -18,15 +18,23 @@ make_predictions <- function(x, prefix, n_clusters) { make_predictions(clusters, prefix, n_clusters) } -.k_means_predict_clustMixType <- function(object, new_data, prefix = "Cluster_") { +.k_means_predict_clustMixType <- function( + object, + new_data, + prefix = "Cluster_" +) { clusters <- predict(object, new_data)$cluster n_clusters <- length(object$size) make_predictions(clusters, prefix, n_clusters) } -.k_means_predict_klaR <- function(object, new_data, prefix = "Cluster_", - ties = c("first", "last", "random")) { +.k_means_predict_klaR <- function( + object, + new_data, + prefix = "Cluster_", + ties = c("first", "last", "random") +) { ties <- rlang::arg_match(ties) modes <- object$modes @@ -42,7 +50,6 @@ make_predictions <- function(x, prefix, n_clusters) { which_min <- which(misses == min(misses)) - if (length(which_min) == 1) { clusters[i] <- which_min } else { @@ -58,7 +65,12 @@ make_predictions <- function(x, prefix, n_clusters) { make_predictions(clusters, prefix, n_modes) } -.hier_clust_predict_stats <- function(object, new_data, ..., prefix = "Cluster_") { +.hier_clust_predict_stats <- function( + object, + new_data, + ..., + prefix = "Cluster_" +) { linkage_method <- object$method new_data <- as.matrix(new_data) @@ -75,7 +87,8 @@ make_predictions <- function(x, prefix, n_clusters) { ## complete, single, average, and median linkage_methods are basically the ## same idea, just different summary distance to cluster - cluster_dist_fun <- switch(linkage_method, + cluster_dist_fun <- switch( + linkage_method, "single" = min, "complete" = max, "average" = mean, @@ -111,7 +124,7 @@ make_predictions <- function(x, prefix, n_clusters) { d_means <- map( seq_len(n_clust), - ~ t( + ~t( t(training_data[clusters$.cluster == cluster_names[.x], ]) - cluster_centers[.x, ] ) @@ -122,8 +135,10 @@ make_predictions <- function(x, prefix, n_clusters) { function(new_obs) { map( seq_len(n_clust), - ~ t(t(training_data[clusters$.cluster == cluster_names[.x], ]) - - new_data[new_obs, ]) + ~t( + t(training_data[clusters$.cluster == cluster_names[.x], ]) - + new_data[new_obs, ] + ) ) } ) @@ -134,8 +149,9 @@ make_predictions <- function(x, prefix, n_clusters) { d_new_list, function(v) { map2_dbl( - d_means, v, - ~ sum((n * .x + .y)^2 / (n + 1)^2 - .x^2) + d_means, + v, + ~sum((n * .x + .y)^2 / (n + 1)^2 - .x^2) ) } ) diff --git a/R/print.R b/R/print.R index 90cf9f4e..ee8ad397 100644 --- a/R/print.R +++ b/R/print.R @@ -5,7 +5,8 @@ print.cluster_fit <- function(x, ...) { cat("tidyclust cluster object\n\n") if (!is.na(x$elapsed[["elapsed"]])) { cat( - "Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), + "Fit time: ", + prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n" ) } diff --git a/R/pull.R b/R/pull.R index 3761dfe0..7dd33013 100644 --- a/R/pull.R +++ b/R/pull.R @@ -32,20 +32,20 @@ pulley <- function(resamples, res, col) { if (all(map_lgl(res, inherits, "simpleError"))) { res <- resamples %>% - dplyr::mutate(col = map(splits, ~NULL)) %>% - stats::setNames(c(names(resamples), col)) + dplyr::mutate(col = map(splits, ~NULL)) %>% + stats::setNames(c(names(resamples), col)) return(res) } id_cols <- grep("^id", names(resamples), value = TRUE) resamples <- dplyr::arrange(resamples, !!!rlang::syms(id_cols)) - pulled_vals <- dplyr::bind_rows(map(res, ~ .x[[col]])) + pulled_vals <- dplyr::bind_rows(map(res, ~.x[[col]])) if (nrow(pulled_vals) == 0) { res <- resamples %>% - dplyr::mutate(col = map(splits, ~NULL)) %>% - stats::setNames(c(names(resamples), col)) + dplyr::mutate(col = map(splits, ~NULL)) %>% + stats::setNames(c(names(resamples), col)) return(res) } diff --git a/R/reconcile_clusterings.R b/R/reconcile_clusterings.R index 41d1ed33..a9ba296b 100644 --- a/R/reconcile_clusterings.R +++ b/R/reconcile_clusterings.R @@ -31,10 +31,12 @@ #' factor2 <- c("Dog", "Dog", "Cat", "Dog", "Fish", "Parrot") #' reconcile_clusterings_mapping(factor1, factor2, one_to_one = FALSE) #' @export -reconcile_clusterings_mapping <- function(primary, - alternative, - one_to_one = TRUE, - optimize = "accuracy") { +reconcile_clusterings_mapping <- function( + primary, + alternative, + one_to_one = TRUE, + optimize = "accuracy" +) { rlang::check_installed("RcppHungarian") if (length(primary) != length(alternative)) { rlang::abort( diff --git a/R/required_pkgs.R b/R/required_pkgs.R index c87309fd..ec596cb2 100644 --- a/R/required_pkgs.R +++ b/R/required_pkgs.R @@ -16,7 +16,7 @@ get_pkgs <- function(x, infra) { cls <- class(x)[1] pkgs <- modelenv::get_from_env(paste0(cls, "_pkgs")) %>% - dplyr::filter(engine == x$engine) + dplyr::filter(engine == x$engine) res <- pkgs$pkg[[1]] if (length(res) == 0) { res <- character(0) diff --git a/R/translate.R b/R/translate.R index 65d69ee8..24367020 100644 --- a/R/translate.R +++ b/R/translate.R @@ -103,20 +103,20 @@ get_cluster_spec <- function(model, mode, engine) { res <- list() res$libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) %>% - dplyr::filter(engine == !!engine) %>% - .[["pkg"]] %>% - .[[1]] + dplyr::filter(engine == !!engine) %>% + .[["pkg"]] %>% + .[[1]] res$fit <- rlang::env_get(m_env, paste0(model, "_fit")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::pull(value) %>% - .[[1]] + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + .[[1]] pred_code <- rlang::env_get(m_env, paste0(model, "_predict")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::select(-engine, -mode) + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) res$pred <- pred_code[["value"]] names(res$pred) <- pred_code$type @@ -139,7 +139,7 @@ deharmonize <- function(args, key) { parsn <- tibble::tibble(exposed = names(args), order = seq_along(args)) merged <- dplyr::left_join(parsn, key, by = "exposed") %>% - dplyr::arrange(order) + dplyr::arrange(order) # TODO correct for bad merge? names(args) <- merged$original diff --git a/R/tunable.R b/R/tunable.R index da49bd95..f4c73427 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -16,7 +16,9 @@ tunable.cluster_spec <- function(x, ...) { rlang::abort( paste( "The `tidyclust` model database doesn't know about the arguments for ", - "model `", mod_type(x), "`. Was it registered?", + "model `", + mod_type(x), + "`. Was it registered?", sep = "" ), call. = FALSE @@ -25,17 +27,17 @@ tunable.cluster_spec <- function(x, ...) { arg_vals <- mod_env[[arg_name]] %>% - dplyr::filter(engine == x$engine) %>% - dplyr::select(name = exposed, call_info = func) %>% - dplyr::full_join( - tibble::tibble(name = c(names(x$args), names(x$eng_args))), - by = "name" - ) %>% - dplyr::mutate( - source = "cluster_spec", - component = mod_type(x), - component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") - ) + dplyr::filter(engine == x$engine) %>% + dplyr::select(name = exposed, call_info = func) %>% + dplyr::full_join( + tibble::tibble(name = c(names(x$args), names(x$eng_args))), + by = "name" + ) %>% + dplyr::mutate( + source = "cluster_spec", + component = mod_type(x), + component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") + ) if (nrow(arg_vals) > 0) { has_info <- map_lgl(arg_vals$call_info, is.null) diff --git a/R/tune_args.R b/R/tune_args.R index e0cac86b..1c275d32 100644 --- a/R/tune_args.R +++ b/R/tune_args.R @@ -124,13 +124,15 @@ tune_id <- function(x) { NA_character_ } -tune_tbl <- function(name = character(), - tunable = logical(), - id = character(), - source = character(), - component = character(), - component_id = character(), - full = FALSE) { +tune_tbl <- function( + name = character(), + tunable = logical(), + id = character(), + source = character(), + component = character(), + component_id = character(), + full = FALSE +) { complete_id <- id[!is.na(id)] dups <- duplicated(complete_id) if (any(dups)) { diff --git a/R/tune_cluster.R b/R/tune_cluster.R index 0cf06310..dd58ce8f 100644 --- a/R/tune_cluster.R +++ b/R/tune_cluster.R @@ -71,15 +71,23 @@ tune_cluster.default <- function(object, ...) { #' @export #' @rdname tune_cluster -tune_cluster.cluster_spec <- function(object, preprocessor, resamples, ..., - param_info = NULL, grid = 10, - metrics = NULL, - control = tune::control_grid()) { +tune_cluster.cluster_spec <- function( + object, + preprocessor, + resamples, + ..., + param_info = NULL, + grid = 10, + metrics = NULL, + control = tune::control_grid() +) { if (rlang::is_missing(preprocessor) || !tune::is_preprocessor(preprocessor)) { - rlang::abort(paste( - "To tune a model spec, you must preprocess", - "with a formula or recipe" - )) + rlang::abort( + paste( + "To tune a model spec, you must preprocess", + "with a formula or recipe" + ) + ) } tune::empty_ellipses(...) @@ -106,9 +114,15 @@ tune_cluster.cluster_spec <- function(object, preprocessor, resamples, ..., #' @export #' @rdname tune_cluster -tune_cluster.workflow <- function(object, resamples, ..., param_info = NULL, - grid = 10, metrics = NULL, - control = tune::control_grid()) { +tune_cluster.workflow <- function( + object, + resamples, + ..., + param_info = NULL, + grid = 10, + metrics = NULL, + control = tune::control_grid() +) { tune::empty_ellipses(...) control <- parsnip::condense_control(control, tune::control_grid()) @@ -131,13 +145,15 @@ tune_cluster.workflow <- function(object, resamples, ..., param_info = NULL, # ------------------------------------------------------------------------------ -tune_cluster_workflow <- function(workflow, - resamples, - grid = 10, - metrics = NULL, - pset = NULL, - control = NULL, - rng = TRUE) { +tune_cluster_workflow <- function( + workflow, + resamples, + grid = 10, + metrics = NULL, + pset = NULL, + control = NULL, + rng = TRUE +) { tune::check_rset(resamples) metrics <- check_metrics(metrics, workflow) @@ -185,12 +201,14 @@ tune_cluster_workflow <- function(workflow, ) } -tune_cluster_loop <- function(resamples, - grid, - workflow, - metrics, - control, - rng) { +tune_cluster_loop <- function( + resamples, + grid, + workflow, + metrics, + control, + rng +) { `%op%` <- get_operator(control$allow_par, workflow) `%:%` <- foreach::`%:%` @@ -222,67 +240,71 @@ tune_cluster_loop <- function(resamples, # created by `eval()`. This causes the handler to run much too early. By evaluating in # a local environment, we prevent `defer()`/`on.exit()` from finding the short-lived # context of `%op%`. Instead it looks all the way up here to register the handler. - + results <- local({ suppressPackageStartupMessages( foreach::foreach( - split = splits, - seed = seeds, - .packages = packages, - .errorhandling = "pass" - ) %op% { - # Extract internal function from tune namespace - tune_cluster_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - tune_cluster_loop_iter_safely( - split = split, - grid_info = grid_info, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - ) - }) + split = splits, + seed = seeds, + .packages = packages, + .errorhandling = "pass" + ) %op% + { + # Extract internal function from tune namespace + tune_cluster_loop_iter_safely <- utils::getFromNamespace( + x = "tune_cluster_loop_iter_safely", + ns = "tidyclust" + ) + + tune_cluster_loop_iter_safely( + split = split, + grid_info = grid_info, + workflow = workflow, + metrics = metrics, + control = control, + seed = seed + ) + } + ) + }) } else if (identical(parallel_over, "everything")) { seeds <- generate_seeds(rng, n_resamples * n_grid_info) - results <- local(suppressPackageStartupMessages( - foreach::foreach( - iteration = iterations, - split = splits, - .packages = packages, - .errorhandling = "pass" - ) %:% + results <- local( + suppressPackageStartupMessages( foreach::foreach( - row = rows, - seed = slice_seeds(seeds, iteration, n_grid_info), + iteration = iterations, + split = splits, .packages = packages, - .errorhandling = "pass", - .combine = iter_combine - ) %op% { - # Extract internal function from tidyclust namespace - tune_grid_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - grid_info_row <- vctrs::vec_slice(grid_info, row) - - tune_grid_loop_iter_safely( - split = split, - grid_info = grid_info_row, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - )) + .errorhandling = "pass" + ) %:% + foreach::foreach( + row = rows, + seed = slice_seeds(seeds, iteration, n_grid_info), + .packages = packages, + .errorhandling = "pass", + .combine = iter_combine + ) %op% + { + # Extract internal function from tidyclust namespace + tune_grid_loop_iter_safely <- utils::getFromNamespace( + x = "tune_cluster_loop_iter_safely", + ns = "tidyclust" + ) + + grid_info_row <- vctrs::vec_slice(grid_info, row) + + tune_grid_loop_iter_safely( + split = split, + grid_info = grid_info_row, + workflow = workflow, + metrics = metrics, + control = control, + seed = seed + ) + } + ) + ) } else { rlang::abort("Internal error: Invalid `parallel_over`.") } @@ -311,7 +333,8 @@ compute_grid_info <- function(workflow, grid) { if (any_parameters_preprocessor) { compute_grid_info_model_and_preprocessor( workflow, - grid, parameters_model + grid, + parameters_model ) } else { compute_grid_info_model(workflow, grid, parameters_model) @@ -331,12 +354,14 @@ compute_grid_info <- function(workflow, grid) { } } -tune_cluster_loop_iter <- function(split, - grid_info, - workflow, - metrics, - control, - seed) { +tune_cluster_loop_iter <- function( + split, + grid_info, + workflow, + metrics, + control, + seed +) { load_pkgs(workflow) load_namespace(control$pkgs) @@ -541,12 +566,14 @@ tune_cluster_loop_iter <- function(split, ) } -tune_cluster_loop_iter_safely <- function(split, - grid_info, - workflow, - metrics, - control, - seed) { +tune_cluster_loop_iter_safely <- function( + split, + grid_info, + workflow, + metrics, + control, + seed +) { tune_cluster_loop_iter_wrapper <- super_safely(tune_cluster_loop_iter) time <- proc.time() @@ -617,7 +644,8 @@ super_safely <- function(fn) { expr = tryCatch( expr = list( result = fn(...), - error = NULL, warnings = warnings + error = NULL, + warnings = warnings ), error = handle_error ), @@ -636,37 +664,45 @@ compute_grid_info_model <- function(workflow, grid, parameters_model) { msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L) msgs_preprocessor <- rep(msgs_preprocessor, times = n_fit_models) msgs_model <- new_msgs_model( - i = seq_fit_models, n = n_fit_models, + i = seq_fit_models, + n = n_fit_models, msgs_preprocessor = msgs_preprocessor ) iter_configs <- compute_config_ids(out, "Preprocessor1") out <- tibble::add_column( - .data = out, .iter_preprocessor = 1L, + .data = out, + .iter_preprocessor = 1L, .before = 1L ) out <- tibble::add_column( - .data = out, .msg_preprocessor = msgs_preprocessor, + .data = out, + .msg_preprocessor = msgs_preprocessor, .after = ".iter_preprocessor" ) out <- tibble::add_column( - .data = out, .iter_model = seq_fit_models, + .data = out, + .iter_model = seq_fit_models, .after = ".msg_preprocessor" ) out <- tibble::add_column( - .data = out, .iter_config = iter_configs, + .data = out, + .iter_config = iter_configs, .after = ".iter_model" ) out <- tibble::add_column( - .data = out, .msg_model = msgs_model, + .data = out, + .msg_model = msgs_model, .after = ".iter_config" ) out } # https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L484 -compute_grid_info_model_and_preprocessor <- function(workflow, - grid, - parameters_model) { +compute_grid_info_model_and_preprocessor <- function( + workflow, + grid, + parameters_model +) { parameter_names_model <- parameters_model[["id"]] # Nest model parameters, keep preprocessor parameters outside @@ -751,9 +787,7 @@ compute_grid_info_model_and_preprocessor <- function(workflow, } # https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L359 -compute_grid_info_preprocessor <- function(workflow, - grid, - parameters_model) { +compute_grid_info_preprocessor <- function(workflow, grid, parameters_model) { out <- grid n_preprocessors <- nrow(out) @@ -825,7 +859,8 @@ check_metrics <- function(x, object) { mode <- extract_spec_parsnip(object)$mode if (is.null(x)) { - switch(mode, + switch( + mode, partition = { x <- cluster_metric_set(sse_within_total, sse_total) }, @@ -857,10 +892,12 @@ check_metrics <- function(x, object) { } # https://github.com/tidymodels/tune/blob/main/R/checks.R#L144 -check_parameters <- function(workflow, - pset = NULL, - data, - grid_names = character(0)) { +check_parameters <- function( + workflow, + pset = NULL, + data, + grid_names = character(0) +) { if (is.null(pset)) { pset <- hardhat::extract_parameter_set_dials(workflow) } @@ -934,10 +971,12 @@ check_workflow <- function(x, pset = NULL, check_dials = FALSE) { incompl <- dials::has_unknowns(pset$object) if (any(incompl)) { - rlang::abort(paste0( - "The workflow has arguments whose ranges are not finalized: ", - paste0("'", pset$id[incompl], "'", collapse = ", ") - )) + rlang::abort( + paste0( + "The workflow has arguments whose ranges are not finalized: ", + paste0("'", pset$id[incompl], "'", collapse = ", ") + ) + ) } } @@ -952,11 +991,13 @@ check_param_objects <- function(pset) { params <- map_lgl(pset$object, inherits, "param") if (!all(params)) { - rlang::abort(paste0( - "The workflow has arguments to be tuned that are missing some ", - "parameter objects: ", - paste0("'", pset$id[!params], "'", collapse = ", ") - )) + rlang::abort( + paste0( + "The workflow has arguments to be tuned that are missing some ", + "parameter objects: ", + paste0("'", pset$id[!params], "'", collapse = ", ") + ) + ) } invisible(pset) } diff --git a/R/tune_helpers.R b/R/tune_helpers.R index 86597404..5c387dd8 100644 --- a/R/tune_helpers.R +++ b/R/tune_helpers.R @@ -6,13 +6,20 @@ new_bare_tibble <- function(x, ..., class = character()) { } is_cataclysmic <- function(x) { - is_err <- map_lgl(x$.metrics, inherits, c( - "simpleError", - "error" - )) + is_err <- map_lgl( + x$.metrics, + inherits, + c( + "simpleError", + "error" + ) + ) if (any(!is_err)) { - is_good <- map_lgl(x$.metrics[!is_err], ~ tibble::is_tibble(.x) && - nrow(.x) > 0) + is_good <- map_lgl( + x$.metrics[!is_err], + ~tibble::is_tibble(.x) && + nrow(.x) > 0 + ) is_err[!is_err] <- !is_good } all(is_err) @@ -31,10 +38,13 @@ set_workflow <- function(workflow, control) { "setting `save_workflow = FALSE`." ) cols <- get_tidyclust_colors() - msg <- strwrap(msg, prefix = paste0( - cols$symbol$info(cli::symbol$info), - " " - )) + msg <- strwrap( + msg, + prefix = paste0( + cols$symbol$info(cli::symbol$info), + " " + ) + ) msg <- cols$message$info(paste0(msg, collapse = "\n")) rlang::inform(msg) } @@ -46,8 +56,14 @@ set_workflow <- function(workflow, control) { } # https://github.com/tidymodels/tune/blob/main/R/tune_results.R -new_tune_results <- function(x, parameters, metrics, - rset_info, ..., class = character()) { +new_tune_results <- function( + x, + parameters, + metrics, + rset_info, + ..., + class = character() +) { new_bare_tibble( x = x, parameters = parameters, @@ -92,8 +108,11 @@ new_grid_info_resamples <- function() { ) iter_config <- list("Preprocessor1_Model1") out <- tibble::tibble( - .iter_preprocessor = 1L, .msg_preprocessor = msgs_preprocessor, - .iter_model = 1L, .iter_config = iter_config, .msg_model = msgs_model, + .iter_preprocessor = 1L, + .msg_preprocessor = msgs_preprocessor, + .iter_model = 1L, + .iter_config = iter_config, + .msg_model = msgs_model, .submodels = list(list()) ) out @@ -153,7 +172,7 @@ min_grid.cluster_spec <- function(x, grid, ...) { blank_submodels <- function(grid) { grid %>% dplyr::mutate( - .submodels = map(seq_along(nrow(grid)), ~ list()) + .submodels = map(seq_along(nrow(grid)), ~list()) ) %>% dplyr::mutate_if(is.factor, as.character) } @@ -218,9 +237,7 @@ catcher <- function(expr) { signals <<- append(signals, list(cnd)) rlang::cnd_muffle(cnd) } - res <- try(withCallingHandlers(warning = add_cond, expr), - silent = TRUE - ) + res <- try(withCallingHandlers(warning = add_cond, expr), silent = TRUE) list(res = res, signals = signals) } @@ -232,16 +249,17 @@ siren <- function(x, type = "info") { symb <- dplyr::case_when( type == "warning" ~ tidyclust_color$symbol$warning("!"), type == "go" ~ tidyclust_color$symbol$go(cli::symbol$pointer), - type == "danger" ~ tidyclust_color$symbol$danger("x"), type == - "success" ~ tidyclust_color$symbol$success(tidyclust_symbol$success), + type == "danger" ~ tidyclust_color$symbol$danger("x"), + type == "success" ~ + tidyclust_color$symbol$success(tidyclust_symbol$success), type == "info" ~ tidyclust_color$symbol$info("i") ) msg <- dplyr::case_when( type == "warning" ~ tidyclust_color$message$warning(msg), - type == "go" ~ tidyclust_color$message$go(msg), type == "danger" ~ - tidyclust_color$message$danger(msg), type == "success" ~ - tidyclust_color$message$success(msg), type == "info" ~ - tidyclust_color$message$info(msg) + type == "go" ~ tidyclust_color$message$go(msg), + type == "danger" ~ tidyclust_color$message$danger(msg), + type == "success" ~ tidyclust_color$message$success(msg), + type == "info" ~ tidyclust_color$message$info(msg) ) if (inherits(msg, "character")) { msg <- as.character(msg) @@ -254,15 +272,17 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) { control2$verbose <- TRUE wrn <- res$signals if (length(wrn) > 0) { - wrn_msg <- map_chr(wrn, ~ .x$message) + wrn_msg <- map_chr(wrn, ~.x$message) wrn_msg <- unique(wrn_msg) wrn_msg <- paste(wrn_msg, collapse = ", ") wrn_msg <- tibble::tibble( - location = loc, type = "warning", + location = loc, + type = "warning", note = wrn_msg ) notes <- dplyr::bind_rows(notes, wrn_msg) - wrn_msg <- glue::glue_collapse(paste0(loc, ": ", wrn_msg$note), + wrn_msg <- glue::glue_collapse( + paste0(loc, ": ", wrn_msg$note), width = options()$width - 5 ) tune_log(control2, split, wrn_msg, type = "warning") @@ -271,11 +291,13 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) { err_msg <- as.character(attr(res$res, "condition")) err_msg <- gsub("\n$", "", err_msg) err_msg <- tibble::tibble( - location = loc, type = "error", + location = loc, + type = "error", note = err_msg ) notes <- dplyr::bind_rows(notes, err_msg) - err_msg <- glue::glue_collapse(paste0(loc, ": ", err_msg$note), + err_msg <- glue::glue_collapse( + paste0(loc, ": ", err_msg$note), width = options()$width - 5 ) tune_log(control2, split, err_msg, type = "danger") @@ -319,7 +341,7 @@ merger <- function(x, y, ...) { grid_name <- colnames(y) if (inherits(x, "recipe")) { updater <- update_recipe - step_ids <- map_chr(x$steps, ~ .x$id) + step_ids <- map_chr(x$steps, ~.x$id) } else { updater <- update_model step_ids <- NULL @@ -332,7 +354,7 @@ merger <- function(x, y, ...) { dplyr::mutate( ..object = map( seq_along(nrow(y)), - ~ updater(y[.x, ], x, pset, step_ids, grid_name) + ~updater(y[.x, ], x, pset, step_ids, grid_name) ) ) %>% dplyr::select(x = ..object) @@ -440,8 +462,8 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL) { # Regular predictions tmp_res <- stats::predict(model, x_vals, type = type_iter) %>% - dplyr::mutate(.row = orig_rows) %>% - cbind(grid, row.names = NULL) + dplyr::mutate(.row = orig_rows) %>% + cbind(grid, row.names = NULL) if (!is.null(submodels)) { submod_length <- lengths(submodels) @@ -570,16 +592,18 @@ slice_seeds <- function(x, i, n) { iter_combine <- function(...) { results <- list(...) - metrics <- map(results, ~ .x[[".metrics"]]) - extracts <- map(results, ~ .x[[".extracts"]]) - predictions <- map(results, ~ .x[[".predictions"]]) - notes <- map(results, ~ .x[[".notes"]]) + metrics <- map(results, ~.x[[".metrics"]]) + extracts <- map(results, ~.x[[".extracts"]]) + predictions <- map(results, ~.x[[".predictions"]]) + notes <- map(results, ~.x[[".notes"]]) metrics <- vctrs::vec_c(!!!metrics) extracts <- vctrs::vec_c(!!!extracts) predictions <- vctrs::vec_c(!!!predictions) notes <- vctrs::vec_c(!!!notes) list( - .metrics = metrics, .extracts = extracts, .predictions = predictions, + .metrics = metrics, + .extracts = extracts, + .predictions = predictions, .notes = notes ) } diff --git a/dev/cross_val_kmeans.R b/dev/cross_val_kmeans.R index 55f972ab..eae5e935 100644 --- a/dev/cross_val_kmeans.R +++ b/dev/cross_val_kmeans.R @@ -17,12 +17,10 @@ res <- data.frame( ) for (k in 2:10) { - km <- k_means(k = k) %>% set_engine("stats") for (i in 1:5) { - tmp_train <- training(cvs$splits[[i]]) tmp_test <- testing(cvs$splits[[i]]) @@ -36,11 +34,8 @@ for (k in 2:10) { sil <- km_fit %>% silhouette_avg(tmp_test) - res <- rbind(res, - c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) - + res <- rbind(res, c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) } - } res %>% @@ -63,14 +58,12 @@ res <- data.frame( ) for (k in 2:10) { - km <- k_means(k = k) %>% set_engine("stats") full_fit <- km %>% fit(~., data = ir) for (i in 1:10) { - tmp_train <- training(cvs$splits[[i]]) tmp_test <- testing(cvs$splits[[i]]) @@ -87,11 +80,11 @@ for (k in 2:10) { acc <- accuracy(thing, clusters_1, clusters_2) f1 <- f_meas(thing, clusters_1, clusters_2) - res <- rbind(res, - c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate)) - + res <- rbind( + res, + c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate) + ) } - } res %>% diff --git a/dev/test_hc.R b/dev/test_hc.R index 006ef3d2..963d5c0e 100644 --- a/dev/test_hc.R +++ b/dev/test_hc.R @@ -1,14 +1,14 @@ library(tidyverse) library(celery) -ir <- iris[,-5] +ir <- iris[, -5] hclust(dist(ir)) bob <- hclust_fit(ir) hc <- hier_clust(k = 3) %>% - fit(~ ., data = ir) + fit(~., data = ir) km <- k_means(k = 3) %>% fit(~., data = ir) @@ -20,7 +20,7 @@ thing <- tibble( ) thing %>% - count(hc_c,truth) + count(hc_c, truth) cutree(hc$fit, k = 3) diff --git a/tests/testthat/helper-tidyclust-package.R b/tests/testthat/helper-tidyclust-package.R index 6635a112..aad0010a 100644 --- a/tests/testthat/helper-tidyclust-package.R +++ b/tests/testthat/helper-tidyclust-package.R @@ -1,14 +1,18 @@ -new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 +new_rng_snapshots <- utils::compareVersion( + "3.6.0", + as.character(getRversion()) +) > + 0 helper_objects_tidyclust <- function() { rec_tune_1 <- recipes::recipe(~., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) %>% - recipes::step_pca(recipes::all_predictors(), num_comp = tune()) + recipes::step_normalize(recipes::all_predictors()) %>% + recipes::step_pca(recipes::all_predictors(), num_comp = tune()) rec_no_tune_1 <- recipes::recipe(~., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) + recipes::step_normalize(recipes::all_predictors()) kmeans_mod_no_tune <- k_means(num_clusters = 2) diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index 48cf3951..0b09c75e 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -9,8 +9,18 @@ test_that("partition models", { expect_equal( colnames(augment(reg_form, head(mtcars))), c( - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster" + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb", + ".pred_cluster" ) ) expect_equal(nrow(augment(reg_form, head(mtcars))), 6) @@ -18,8 +28,18 @@ test_that("partition models", { expect_equal( colnames(augment(reg_xy, head(mtcars))), c( - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster" + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb", + ".pred_cluster" ) ) expect_equal(nrow(augment(reg_xy, head(mtcars))), 6) diff --git a/tests/testthat/test-cluster_metric_set.R b/tests/testthat/test-cluster_metric_set.R index bbfabe85..02c36ab3 100644 --- a/tests/testthat/test-cluster_metric_set.R +++ b/tests/testthat/test-cluster_metric_set.R @@ -4,13 +4,23 @@ test_that("cluster_metric_set works", { kmeans_fit <- fit(kmeans_spec, ~., mtcars) - my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within_total, silhouette_avg) + my_metrics <- cluster_metric_set( + sse_ratio, + sse_total, + sse_within_total, + silhouette_avg + ) exp_res <- tibble::tibble( .metric = c("sse_ratio", "sse_total", "sse_within_total", "silhouette_avg"), .estimator = "standard", .estimate = vapply( - list(sse_ratio_vec, sse_total_vec, sse_within_total_vec, silhouette_avg_vec), + list( + sse_ratio_vec, + sse_total_vec, + sse_within_total_vec, + silhouette_avg_vec + ), function(x) x(kmeans_fit, new_data = mtcars), FUN.VALUE = numeric(1) ) diff --git a/tests/testthat/test-extract_centroids.R b/tests/testthat/test-extract_centroids.R index ea5ada38..ebbc8920 100644 --- a/tests/testthat/test-extract_centroids.R +++ b/tests/testthat/test-extract_centroids.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in extract_centroids()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_centroids(spec, prefix = "C_") diff --git a/tests/testthat/test-extract_cluster_assignment.R b/tests/testthat/test-extract_cluster_assignment.R index 7cca9a1f..bf2bb5d0 100644 --- a/tests/testthat/test-extract_cluster_assignment.R +++ b/tests/testthat/test-extract_cluster_assignment.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in extract_cluster_assignment()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_cluster_assignment(spec, prefix = "C_") diff --git a/tests/testthat/test-extract_fit_summary.R b/tests/testthat/test-extract_fit_summary.R index e0de8ae2..a931be38 100644 --- a/tests/testthat/test-extract_fit_summary.R +++ b/tests/testthat/test-extract_fit_summary.R @@ -66,7 +66,7 @@ test_that("extract_fit_summary() errors for cluster spec", { test_that("prefix is passed in extract_fit_summary()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_fit_summary(spec, prefix = "C_") diff --git a/tests/testthat/test-hier_clust-stats.R b/tests/testthat/test-hier_clust-stats.R index 36c9585c..850d5734 100644 --- a/tests/testthat/test-hier_clust-stats.R +++ b/tests/testthat/test-hier_clust-stats.R @@ -53,8 +53,15 @@ test_that("extract_centroids() works", { expect_identical( colnames(centroids), - c(".cluster", "Sepal.Length", "Sepal.Width", "Petal.Length", - "Petal.Width", "Speciesversicolor", "Speciesvirginica") + c( + ".cluster", + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Speciesversicolor", + "Speciesvirginica" + ) ) expect_identical( @@ -73,7 +80,9 @@ test_that("extract_cluster_assignment() works", { clusters <- extract_cluster_assignment(res) expected <- vctrs::vec_cbind( - tibble::tibble(.cluster = factor(paste0("Cluster_", cutree(res$fit, k = 3)))) + tibble::tibble( + .cluster = factor(paste0("Cluster_", cutree(res$fit, k = 3))) + ) ) expect_identical( diff --git a/tests/testthat/test-k_means-clustMixType.R b/tests/testthat/test-k_means-clustMixType.R index 8928ff8e..e07a90e9 100644 --- a/tests/testthat/test-k_means-clustMixType.R +++ b/tests/testthat/test-k_means-clustMixType.R @@ -115,4 +115,3 @@ test_that("modifies errors about suggested other models", { fit(~., data = data.frame(letters, LETTERS)) ) }) - diff --git a/tests/testthat/test-k_means-clusterR.R b/tests/testthat/test-k_means-clusterR.R index 49d2d75c..b630e596 100644 --- a/tests/testthat/test-k_means-clusterR.R +++ b/tests/testthat/test-k_means-clusterR.R @@ -27,8 +27,12 @@ test_that("predicting", { expect_identical( preds, - tibble::tibble(.pred_cluster = factor(paste0("Cluster_", c(1, 1, 1, 2, 2)), - paste0("Cluster_", 1:3))) + tibble::tibble( + .pred_cluster = factor( + paste0("Cluster_", c(1, 1, 1, 2, 2)), + paste0("Cluster_", 1:3) + ) + ) ) }) diff --git a/tests/testthat/test-k_means-klaR.R b/tests/testthat/test-k_means-klaR.R index e205458a..4c554129 100644 --- a/tests/testthat/test-k_means-klaR.R +++ b/tests/testthat/test-k_means-klaR.R @@ -46,8 +46,12 @@ test_that("predicting", { expect_identical( preds, - tibble::tibble(.pred_cluster = factor(paste0("Cluster_", c(1, 1, 1, 1, 2)), - paste0("Cluster_", 1:3))) + tibble::tibble( + .pred_cluster = factor( + paste0("Cluster_", c(1, 1, 1, 1, 2)), + paste0("Cluster_", 1:3) + ) + ) ) }) @@ -89,14 +93,12 @@ test_that("predicting ties argument works", { expect_identical( predict(res, data.frame(x = "C", y = "C"), ties = "first"), - tibble::tibble(.pred_cluster = factor("Cluster_1", - paste0("Cluster_", 1:2))) + tibble::tibble(.pred_cluster = factor("Cluster_1", paste0("Cluster_", 1:2))) ) expect_identical( predict(res, data.frame(x = "C", y = "C"), ties = "last"), - tibble::tibble(.pred_cluster = factor("Cluster_2", - paste0("Cluster_", 1:2))) + tibble::tibble(.pred_cluster = factor("Cluster_2", paste0("Cluster_", 1:2))) ) }) diff --git a/tests/testthat/test-k_means.R b/tests/testthat/test-k_means.R index 9b530602..c9db68da 100644 --- a/tests/testthat/test-k_means.R +++ b/tests/testthat/test-k_means.R @@ -161,7 +161,7 @@ test_that("reordering is done correctly for ClusterR k_means", { expect_identical( summ$n_members, unname(as.integer(table(summ$cluster_assignments))) - ) + ) }) test_that("errors if `num_clust` isn't specified", { @@ -169,13 +169,13 @@ test_that("errors if `num_clust` isn't specified", { error = TRUE, k_means() %>% set_engine("stats") %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) ) expect_snapshot( error = TRUE, k_means() %>% set_engine("ClusterR") %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) ) }) diff --git a/tests/testthat/test-k_means_diagnostics.R b/tests/testthat/test-k_means_diagnostics.R index 98f355e0..dba6415f 100644 --- a/tests/testthat/test-k_means_diagnostics.R +++ b/tests/testthat/test-k_means_diagnostics.R @@ -16,7 +16,8 @@ test_that("kmeans sse metrics work", { clusters = 3 ) - expect_equal(sse_within(kmeans_fit_stats)$wss, + expect_equal( + sse_within(kmeans_fit_stats)$wss, c(42877.103, 76954.010, 7654.146), # hard coded because of order tolerance = 0.005 ) @@ -37,7 +38,8 @@ test_that("kmeans sse metrics work", { tolerance = 0.005 ) - expect_equal(sse_within(kmeans_fit_ClusterR)$wss, + expect_equal( + sse_within(kmeans_fit_ClusterR)$wss, c(42877.103, 56041.432, 4665.041), # hard coded because of order tolerance = 0.005 ) @@ -66,7 +68,8 @@ test_that("kmeans sse metrics work on new data", { new_data <- mtcars[1:4, ] - expect_equal(sse_within(kmeans_fit_stats, new_data)$wss, + expect_equal( + sse_within(kmeans_fit_stats, new_data)$wss, c(2799.21, 12855.17), tolerance = 0.005 ) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 349868e3..9f2003f6 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in predict()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- predict(spec, mtcars, prefix = "C_") diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index a368abae..0fd09bae 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -1,8 +1,8 @@ test_that("partition predictions", { kmeans_fit <- k_means(num_clusters = 3, mode = "partition") %>% - set_engine("stats") %>% - fit(~., data = mtcars) + set_engine("stats") %>% + fit(~., data = mtcars) expect_true(tibble::is_tibble(predict(kmeans_fit, new_data = mtcars))) expect_true( diff --git a/tests/testthat/test-reconcile_clusterings.R b/tests/testthat/test-reconcile_clusterings.R index 2b5d60f2..6ead4c6e 100644 --- a/tests/testthat/test-reconcile_clusterings.R +++ b/tests/testthat/test-reconcile_clusterings.R @@ -1,7 +1,11 @@ test_that("reconciliation works with one-to-one", { primary_cluster_assignment <- c( - "Apple", "Apple", "Carrot", "Carrot", - "Banana", "Banana" + "Apple", + "Apple", + "Carrot", + "Carrot", + "Banana", + "Banana" ) alt_cluster_assignment <- c("Dog", "Dog", "Cat", "Dog", "Fish", "Fish") @@ -18,8 +22,12 @@ test_that("reconciliation works with one-to-one", { test_that("reconciliation works with uneven numbers", { primary_cluster_assignment <- c( - "Apple", "Apple", "Carrot", "Carrot", - "Banana", "Banana" + "Apple", + "Apple", + "Carrot", + "Carrot", + "Banana", + "Banana" ) alt_cluster_assignment <- c("Dog", "Dog", "Cat", "Dog", "Parrot", "Fish") diff --git a/tests/testthat/test-tune_cluster.R b/tests/testthat/test-tune_cluster.R index 166daa02..922a69ae 100644 --- a/tests/testthat/test-tune_cluster.R +++ b/tests/testthat/test-tune_cluster.R @@ -163,7 +163,11 @@ test_that("tune model and recipe", { expect_equal( colnames(res$.metrics[[1]]), c( - "num_clusters", "num_comp", ".metric", ".estimator", ".estimate", + "num_clusters", + "num_comp", + ".metric", + ".estimator", + ".estimate", ".config" ) ) @@ -233,7 +237,11 @@ test_that('tune model and recipe (parallel_over = "everything")', { expect_equal( colnames(res$.metrics[[1]]), c( - "num_clusters", "num_comp", ".metric", ".estimator", ".estimate", + "num_clusters", + "num_comp", + ".metric", + ".estimator", + ".estimate", ".config" ) ) @@ -258,9 +266,12 @@ test_that("tune model only - failure in formula is caught elegantly", { ~z, resamples = data_folds, grid = cars_grid, - control = tune::control_grid(extract = function(x) { - 1 - }, save_pred = TRUE) + control = tune::control_grid( + extract = function(x) { + 1 + }, + save_pred = TRUE + ) ) ) diff --git a/tests/testthat/test-workflows.R b/tests/testthat/test-workflows.R index a0b9a5bc..d5277475 100644 --- a/tests/testthat/test-workflows.R +++ b/tests/testthat/test-workflows.R @@ -27,7 +27,7 @@ test_that("integrates with workflows::add_formula()", { kmeans_spec <- k_means(num_clusters = 2) wf_spec <- workflows::workflow() %>% - workflows::add_formula(~ .) %>% + workflows::add_formula(~.) %>% workflows::add_model(kmeans_spec) expect_no_error( @@ -51,7 +51,7 @@ test_that("integrates with workflows::add_recipe()", { kmeans_spec <- k_means(num_clusters = 2) wf_spec <- workflows::workflow() %>% - workflows::add_recipe(recipes::recipe(~ ., data = mtcars)) %>% + workflows::add_recipe(recipes::recipe(~., data = mtcars)) %>% workflows::add_model(kmeans_spec) expect_no_error(