diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 75096194..a5ae3718 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -9,8 +9,11 @@ on: branches: [main, master] pull_request: branches: [main, master] + workflow_dispatch: -name: R-CMD-check +name: R-CMD-check.yaml + +permissions: read-all jobs: R-CMD-check: @@ -25,21 +28,25 @@ jobs: - {os: macos-latest, r: 'release'} - {os: windows-latest, r: 'release'} - # use 4.1 to check with rtools40's older compiler - - {os: windows-latest, r: '4.1'} + # use 4.0 or 4.1 to check with rtools40's older compiler + - {os: windows-latest, r: 'oldrel-4'} - - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} - - {os: ubuntu-latest, r: 'release'} - - {os: ubuntu-latest, r: 'oldrel-1'} - - {os: ubuntu-latest, r: 'oldrel-2'} - - {os: ubuntu-latest, r: 'oldrel-3'} + - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} + - {os: ubuntu-latest, r: 'release'} + - {os: ubuntu-latest, r: 'oldrel-1'} + - {os: ubuntu-latest, r: 'oldrel-2'} + - {os: ubuntu-latest, r: 'oldrel-3'} + #- {os: ubuntu-latest, r: 'oldrel-4'} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_KEEP_PKG_SOURCE: yes + CXX14: g++ + CXX14STD: -std=c++1y + CXX14FLAGS: -Wall -g -02 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 @@ -57,3 +64,8 @@ jobs: - uses: r-lib/actions/check-r-package@v2 with: upload-snapshots: true + + - name: Show testthat output + if: always() + run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true + shell: bash diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index 087f0b05..4bbce750 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -9,7 +9,9 @@ on: types: [published] workflow_dispatch: -name: pkgdown +name: pkgdown.yaml + +permissions: read-all jobs: pkgdown: @@ -19,8 +21,10 @@ jobs: group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + permissions: + contents: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 @@ -39,7 +43,7 @@ jobs: - name: Deploy to GitHub pages 🚀 if: github.event_name != 'pull_request' - uses: JamesIves/github-pages-deploy-action@v4.4.1 + uses: JamesIves/github-pages-deploy-action@v4.5.0 with: clean: false branch: gh-pages diff --git a/.github/workflows/pr-commands.yaml b/.github/workflows/pr-commands.yaml index 71f335b3..2edd93f2 100644 --- a/.github/workflows/pr-commands.yaml +++ b/.github/workflows/pr-commands.yaml @@ -4,7 +4,9 @@ on: issue_comment: types: [created] -name: Commands +name: pr-commands.yaml + +permissions: read-all jobs: document: @@ -13,8 +15,10 @@ jobs: runs-on: ubuntu-latest env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + permissions: + contents: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/pr-fetch@v2 with: @@ -50,8 +54,10 @@ jobs: runs-on: ubuntu-latest env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + permissions: + contents: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/pr-fetch@v2 with: diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 2c5bb502..98822609 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -6,7 +6,9 @@ on: pull_request: branches: [main, master] -name: test-coverage +name: test-coverage.yaml + +permissions: read-all jobs: test-coverage: @@ -15,7 +17,7 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-r@v2 with: @@ -23,28 +25,37 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::covr + extra-packages: any::covr, any::xml2 needs: coverage - name: Test coverage run: | - covr::codecov( + cov <- covr::package_coverage( quiet = FALSE, clean = FALSE, - install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package") + install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") ) + covr::to_cobertura(cov) shell: Rscript {0} + - uses: codecov/codecov-action@v4 + with: + fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }} + file: ./cobertura.xml + plugin: noop + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} + - name: Show testthat output if: always() run: | ## -------------------------------------------------------------------- - find ${{ runner.temp }}/package -name 'testthat.Rout*' -exec cat '{}' \; || true + find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true shell: bash - name: Upload test results if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: coverage-test-failures path: ${{ runner.temp }}/package diff --git a/DESCRIPTION b/DESCRIPTION index e676913a..f49d5abf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,17 +17,17 @@ Depends: R (>= 3.6) Imports: cli (>= 3.0.0), - dials (>= 1.1.0), + dials (>= 1.3.0), dplyr (>= 1.0.9), flexclust (>= 1.3-6), foreach, generics (>= 0.1.2), glue (>= 1.6.2), hardhat (>= 1.0.0), - modelenv (>= 0.1.0), + modelenv (>= 0.2.0.9000), parsnip (>= 1.0.2), + philentropy (>= 0.9.0), prettyunits (>= 1.1.0), - Rfast (>= 2.0.6), rlang (>= 1.0.6), rsample (>= 1.0.0), stats, @@ -49,6 +49,8 @@ Suggests: rmarkdown, testthat (>= 3.0.0), workflows (>= 1.1.2) +Remotes: + tidymodels/modelenv Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins, patchwork, ggforce, tidyverse/tidytemplate Config/testthat/edition: 3 diff --git a/NEWS.md b/NEWS.md index 190f4f63..0456878e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # tidyclust (development version) +* The philentropy package is now used to calculate distances rather than Rfast. (#199) + # tidyclust 0.2.3 * Update to fix revdep issue for clustMixType. (#190) diff --git a/R/aaa.R b/R/aaa.R index 6865a2da..db456411 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -3,17 +3,47 @@ utils::globalVariables( c( - ".", "..object", ".cluster", ".iter_config", ".iter_model", - ".iter_preprocessor", ".msg_model", ".submodels", "call_info", "cluster", - "component", "component_id", "compute_intercept", "data", "dist", "engine", - "engine2", "exposed", "func", "id", "iteration", "lab", "name", "neighbor", - "new_data", "object", "orig_label", "original", "predictor_indicators", - "remove_intercept", "seed", "sil_width", "splits", "tunable", "type", - "value", "x", "y" + ".", + "..object", + ".cluster", + ".iter_config", + ".iter_model", + ".iter_preprocessor", + ".msg_model", + ".submodels", + "call_info", + "cluster", + "component", + "component_id", + "compute_intercept", + "data", + "dist", + "engine", + "engine2", + "exposed", + "func", + "id", + "iteration", + "lab", + "name", + "neighbor", + "new_data", + "object", + "orig_label", + "original", + "predictor_indicators", + "remove_intercept", + "seed", + "sil_width", + "splits", + "tunable", + "type", + "value", + "x", + "y" ) ) - release_bullets <- function() { c( "Run `knit_engine_docs()` and `devtools::document()` to update docs" diff --git a/R/append.R b/R/append.R index 7ca6d68f..b9519ffc 100644 --- a/R/append.R +++ b/R/append.R @@ -1,9 +1,11 @@ # https://github.com/tidymodels/tune/blob/main/R/pull.R#L136 -append_predictions <- function(collection, - predictions, - split, - control, - .config = NULL) { +append_predictions <- function( + collection, + predictions, + split, + control, + .config = NULL +) { if (!control$save_pred) { return(NULL) } @@ -27,14 +29,16 @@ append_predictions <- function(collection, dplyr::bind_rows(collection, predictions) } -append_metrics <- function(workflow, - collection, - predictions, - metrics, - param_names, - event_level, - split, - .config = NULL) { +append_metrics <- function( + workflow, + collection, + predictions, + metrics, + param_names, + event_level, + split, + .config = NULL +) { if (inherits(predictions, "try-error")) { return(collection) } @@ -54,20 +58,22 @@ append_metrics <- function(workflow, dplyr::bind_rows(collection, tmp_est) } -append_extracts <- function(collection, - workflow, - grid, - split, - ctrl, - .config = NULL) { +append_extracts <- function( + collection, + workflow, + grid, + split, + ctrl, + .config = NULL +) { extracts <- grid %>% - dplyr::bind_cols(labels(split)) %>% - dplyr::mutate( - .extracts = list( - extract_details(workflow, ctrl$extract) + dplyr::bind_cols(labels(split)) %>% + dplyr::mutate( + .extracts = list( + extract_details(workflow, ctrl$extract) + ) ) - ) if (!rlang::is_null(.config)) { extracts <- cbind(extracts, .config) diff --git a/R/arguments.R b/R/arguments.R index eb44bebd..5b88b0cf 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -8,10 +8,9 @@ check_eng_args <- function(args, obj, core_args) { if (length(common_args) > 0) { args <- args[!(names(args) %in% common_args)] common_args <- paste0(common_args, collapse = ", ") - rlang::warn(glue::glue( - "The following arguments cannot be manually modified ", - "and were removed: {common_args}." - )) + cli::cli_warn( + "The arguments {common_args} cannot be manually modified and were removed." + ) } args } @@ -25,11 +24,12 @@ make_x_call <- function(object, target) { } object$method$fit$args[[unname(data_args["x"])]] <- - switch(target, + switch( + target, none = rlang::expr(x), data.frame = rlang::expr(maybe_data_frame(x)), matrix = rlang::expr(maybe_matrix(x)), - rlang::abort(glue::glue("Invalid data type target: {target}.")) + cli::cli_abort("Invalid data type target: {target}.") ) fit_call <- make_call( @@ -75,7 +75,7 @@ make_form_call <- function(object, env = NULL) { set_args.cluster_spec <- function(object, ...) { the_dots <- enquos(...) if (length(the_dots) == 0) { - rlang::abort("Please pass at least one named argument.") + cli::cli_abort("Please pass at least one named argument.") } main_args <- names(object$args) new_args <- names(the_dots) @@ -101,7 +101,7 @@ set_args.cluster_spec <- function(object, ...) { #' @inheritParams parsnip::set_mode #' @return An updated [`cluster_spec`] object. #' @export -set_mode.cluster_spec <- function(object, mode) { +set_mode.cluster_spec <- function(object, mode, ...) { cls <- class(object)[1] if (rlang::is_missing(mode)) { spec_modes <- rlang::env_get( diff --git a/R/augment.R b/R/augment.R index ba2c8232..7dd2499b 100644 --- a/R/augment.R +++ b/R/augment.R @@ -31,7 +31,7 @@ augment.cluster_fit <- function(x, new_data, ...) { stats::predict(x, new_data = new_data) ) } else { - rlang::abort(paste("Unknown mode:", x$spec$mode)) + cli::cli_abort("Unknown mode: {x$spec$mode}") } as_tibble(ret) } diff --git a/R/cluster_spec.R b/R/cluster_spec.R index 7037dd85..cb718d44 100644 --- a/R/cluster_spec.R +++ b/R/cluster_spec.R @@ -10,11 +10,19 @@ #' @export #' @keywords internal new_cluster_spec <- function(cls, args, eng_args, mode, method, engine) { - modelenv::check_spec_mode_engine_val(model = cls, mode = mode, eng = engine) + modelenv::check_spec_mode_engine_val( + model = cls, + mode = mode, + eng = engine, + call = rlang::caller_env() + ) out <- list( - args = args, eng_args = eng_args, - mode = mode, method = method, engine = engine + args = args, + eng_args = eng_args, + mode = mode, + method = method, + engine = engine ) class(out) <- make_classes_tidyclust(cls) out <- modelenv::new_unsupervised_spec(out) diff --git a/R/compat-purrr.R b/R/compat-purrr.R index e60efc8c..1f042c69 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -79,11 +79,16 @@ imap <- function(.x, .f, ...) { pmap <- function(.l, .f, ...) { .f <- as.function(.f) args <- .rlang_purrr_args_recycle(.l) - do.call("mapply", c( - FUN = list(quote(.f)), - args, MoreArgs = quote(list(...)), - SIMPLIFY = FALSE, USE.NAMES = FALSE - )) + do.call( + "mapply", + c( + FUN = list(quote(.f)), + args, + MoreArgs = quote(list(...)), + SIMPLIFY = FALSE, + USE.NAMES = FALSE + ) + ) } .rlang_purrr_args_recycle <- function(args) { lengths <- map_int(args, length) diff --git a/R/control.R b/R/control.R index 331b8820..ba54b853 100644 --- a/R/control.R +++ b/R/control.R @@ -37,10 +37,10 @@ check_control <- function(x, call = rlang::caller_env()) { abs(x - round(x)) < tol } if (!int_check(x$verbosity)) { - rlang::abort("verbosity should be an integer.", call = call) + cli::cli_abort("verbosity should be an integer.", call = call) } if (!is.logical(x$catch)) { - rlang::abort("catch should be a logical.", call = call) + cli::cli_abort("catch should be a logical.", call = call) } x } diff --git a/R/convert_data.R b/R/convert_data.R index 96e42aa5..b502b92b 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -32,15 +32,19 @@ #' @inheritParams fit.cluster_spec #' @rdname convert_helpers #' @keywords internal -.convert_form_to_x_fit <- function(formula, - data, - ..., - na.action = na.omit, - indicators = "traditional", - composition = "data.frame", - remove_intercept = TRUE) { +.convert_form_to_x_fit <- function( + formula, + data, + ..., + na.action = na.omit, + indicators = "traditional", + composition = "data.frame", + remove_intercept = TRUE +) { if (!(composition %in% c("data.frame", "matrix"))) { - rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") + cli::cli_abort( + "{.arg composition} should be {.cls data.frame} or {.cls matrix}." + ) } ## Assemble model.frame call from call arguments @@ -57,7 +61,7 @@ w <- as.vector(model.weights(mod_frame)) if (!is.null(w) && !is.numeric(w)) { - rlang::abort("`weights` must be a numeric vector") + cli::cli_abort("The {.arg weights} must be a numeric vector.") } # TODO: Do we actually use the offset when fitting? @@ -122,15 +126,11 @@ check_form_dots <- function(x) { good_args <- c("subset", "weights") good_names <- names(x) %in% good_args if (any(!good_names)) { - rlang::abort( - glue::glue( - "These argument(s) cannot be used to create the data: ", - glue::glue_collapse( - glue::glue("`{names(x)[!good_names]}`"), - sep = ", " - ), - ". Possible arguments are: ", - glue::glue_collapse(glue::glue("`{good_args}`"), sep = ", ") + cli::cli_abort( + c( + "The argument{?s} {.code {names(x)[!good_names]}} cannot be used + to create the data.", + "i" = "Possible arguments are: {.code {good_args}}." ) ) } @@ -155,11 +155,9 @@ local_one_hot_contrasts <- function(frame = rlang::caller_env()) { #' @inheritParams .convert_form_to_x_fit #' @rdname convert_helpers #' @keywords internal -.convert_x_to_form_fit <- function(x, - weights = NULL, - remove_intercept = TRUE) { +.convert_x_to_form_fit <- function(x, weights = NULL, remove_intercept = TRUE) { if (is.vector(x)) { - rlang::abort("`x` cannot be a vector.") + cli::cli_abort("{.arg x} cannot be a vector.") } if (remove_intercept) { @@ -182,10 +180,10 @@ local_one_hot_contrasts <- function(frame = rlang::caller_env()) { if (!is.null(weights)) { if (!is.numeric(weights)) { - rlang::abort("`weights` must be a numeric vector") + cli::cli_abort("The {.arg weights} must be a numeric vector.") } if (length(weights) != nrow(x)) { - rlang::abort(glue::glue("`weights` should have {nrow(x)} elements")) + cli::cli_abort("{.arg weights} should have {nrow(x)} elements.") } } @@ -212,12 +210,16 @@ make_formula <- function(x, short = TRUE) { #' @inheritParams predict.cluster_fit #' @rdname convert_helpers #' @keywords internal -.convert_form_to_x_new <- function(object, - new_data, - na.action = stats::na.pass, - composition = "data.frame") { +.convert_form_to_x_new <- function( + object, + new_data, + na.action = stats::na.pass, + composition = "data.frame" +) { if (!(composition %in% c("data.frame", "matrix"))) { - rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") + cli::cli_abort( + "{.arg composition} should be either {.code data.frame} or {.code matrix}." + ) } mod_terms <- object$terms diff --git a/R/dials-params.R b/R/dials-params.R index 0b569eda..dc39702d 100644 --- a/R/dials-params.R +++ b/R/dials-params.R @@ -40,6 +40,12 @@ linkage_method <- function(values = values_linkage_method) { #' @rdname linkage_method #' @export values_linkage_method <- c( - "ward.D", "ward.D2", "single", "complete", "average", "mcquitty", "median", + "ward.D", + "ward.D2", + "single", + "complete", + "average", + "mcquitty", + "median", "centroid" ) diff --git a/R/engine_docs.R b/R/engine_docs.R index 0a8d37d6..e74164fa 100644 --- a/R/engine_docs.R +++ b/R/engine_docs.R @@ -19,17 +19,17 @@ knit_engine_docs <- function(pattern = NULL) { } outputs <- gsub("Rmd$", "md", files) - res <- map2(files, outputs, ~ try(knitr::knit(.x, .y), silent = TRUE)) - is_error <- map_lgl(res, ~ inherits(.x, "try-error")) + res <- map2(files, outputs, ~try(knitr::knit(.x, .y), silent = TRUE)) + is_error <- map_lgl(res, ~inherits(.x, "try-error")) if (any(is_error)) { # In some cases where there are issues, the md file is empty. errors <- res[which(is_error)] error_nms <- basename(files)[which(is_error)] errors <- - map_chr(errors, ~ cli::ansi_strip(as.character(.x))) %>% - map2_chr(error_nms, ~ paste0(.y, ": ", .x)) %>% - map_chr(~ gsub("Error in .f(.x[[i]], ...) :", "", .x, fixed = TRUE)) + map_chr(errors, ~cli::ansi_strip(as.character(.x))) %>% + map2_chr(error_nms, ~paste0(.y, ": ", .x)) %>% + map_chr(~gsub("Error in .f(.x[[i]], ...) :", "", .x, fixed = TRUE)) cat("There were failures duing knitting:\n\n") cat(errors) cat("\n\n") diff --git a/R/engines.R b/R/engines.R index 3a1004d6..dbfc513d 100644 --- a/R/engines.R +++ b/R/engines.R @@ -31,21 +31,26 @@ set_engine.cluster_spec <- function(object, engine, ...) { stop_missing_engine <- function(cls, call = rlang::caller_env()) { info <- modelenv::get_from_env(cls) %>% - dplyr::group_by(mode) %>% - dplyr::summarize( - msg = paste0( - unique(mode), " {", - paste0(unique(engine), collapse = ", "), - "}" - ), - .groups = "drop" - ) + dplyr::group_by(mode) %>% + dplyr::summarize( + msg = paste0( + unique(mode), + " {", + paste0(unique(engine), collapse = ", "), + "}" + ), + .groups = "drop" + ) if (nrow(info) == 0) { - rlang::abort(glue::glue("No known engines for `{cls}()`."), call = call) + cli::cli_abort("No known engines for {.fn {cls}}.", call = call) } - msg <- paste0(info$msg, collapse = ", ") - msg <- paste("Missing engine. Possible mode/engine combinations are:", msg) - rlang::abort(msg, call = call) + cli::cli_abort( + c( + "Missing engine.", + "i" = "Possible mode/engine combinations are: {info$msg}." + ), + call = call + ) } load_libs <- function(x, quiet, attach = FALSE) { @@ -85,11 +90,8 @@ check_installs <- function(x, call = rlang::caller_env()) { if (any(!is_inst)) { missing_pkg <- x$method$libs[!is_inst] missing_pkg <- paste0(missing_pkg, collapse = ", ") - rlang::abort( - glue::glue( - "This engine requires some package installs: ", - glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ") - ), + cli::cli_abort( + "This engine requires installing {.pkg {missing_pkg}}.", call = call ) } diff --git a/R/extract.R b/R/extract.R index 4672516e..f55defc3 100644 --- a/R/extract.R +++ b/R/extract.R @@ -52,5 +52,5 @@ extract_fit_engine.cluster_fit <- function(x, ...) { if (any(names(x) == "fit")) { return(x$fit) } - rlang::abort("Internal error: The model fit does not have an engine fit.") + cli::cli_abort("Internal error: The model fit does not have an engine fit.") } diff --git a/R/extract_cluster_assignment.R b/R/extract_cluster_assignment.R index af18ce35..f530744c 100644 --- a/R/extract_cluster_assignment.R +++ b/R/extract_cluster_assignment.R @@ -69,10 +69,10 @@ extract_cluster_assignment <- function(object, ...) { #' @export extract_cluster_assignment.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } @@ -111,28 +111,30 @@ extract_cluster_assignment.kmodes <- function(object, ...) { } #' @export -extract_cluster_assignment.hclust <- function(object, - ..., - call = rlang::caller_env(0)) { +extract_cluster_assignment.hclust <- function( + object, + ..., + call = rlang::caller_env(0) +) { # if k or h is passed in the dots, use those. Otherwise, use attributes # from original model specification args <- list(...) if (!is.null(args[["h"]])) { - rlang::abort( - paste( - "Using `h` argument is not supported.", - "Please use `cut_height` instead." + cli::cli_abort( + c( + "Using {.arg h} argument is not supported.", + "i" = "Please use {.arg cut_height} instead." ), call = call ) } if (!is.null(args[["k"]])) { - rlang::abort( - paste( - "Using `k` argument is not supported.", - "Please use `num_clusters` instead." + cli::cli_abort( + c( + "Using {.arg k} argument is not supported.", + "i" = "Please use {.arg num_clusters} instead." ), call = call ) @@ -147,8 +149,8 @@ extract_cluster_assignment.hclust <- function(object, } if (is.null(num_clusters) && is.null(cut_height)) { - rlang::abort( - "Please specify either `num_clusters` or `cut_height`.", + cli::cli_abort( + "Please specify either {.arg num_clusters} or {.arg cut_height}.", call = call ) } @@ -159,10 +161,12 @@ extract_cluster_assignment.hclust <- function(object, # ------------------------------------------------------------------------------ -cluster_assignment_tibble <- function(clusters, - n_clusters, - ..., - prefix = "Cluster_") { +cluster_assignment_tibble <- function( + clusters, + n_clusters, + ..., + prefix = "Cluster_" +) { reorder_clusts <- order(union(unique(clusters), seq_len(n_clusters))) names <- paste0(prefix, seq_len(n_clusters)) res <- names[reorder_clusts][clusters] diff --git a/R/extract_fit_summary.R b/R/extract_fit_summary.R index ce6526a5..905b5184 100644 --- a/R/extract_fit_summary.R +++ b/R/extract_fit_summary.R @@ -23,12 +23,15 @@ extract_fit_summary <- function(object, ...) { } #' @export -extract_fit_summary.cluster_spec <- function(object, ..., - call = rlang::caller_env(n = 0)) { - rlang::abort( - paste( +extract_fit_summary.cluster_spec <- function( + object, + ..., + call = rlang::caller_env(n = 0) +) { + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ), call = call ) @@ -68,9 +71,11 @@ extract_fit_summary.kmeans <- function(object, ..., prefix = "Cluster_") { } #' @export -extract_fit_summary.KMeansCluster <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.KMeansCluster <- function( + object, + ..., + prefix = "Cluster_" +) { names <- paste0(prefix, seq_len(nrow(object$centroids))) names <- factor(names) @@ -93,9 +98,7 @@ extract_fit_summary.KMeansCluster <- function(object, } #' @export -extract_fit_summary.kproto <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.kproto <- function(object, ..., prefix = "Cluster_") { names <- paste0(prefix, seq_len(nrow(object$centers))) names <- factor(names) @@ -118,9 +121,7 @@ extract_fit_summary.kproto <- function(object, } #' @export -extract_fit_summary.kmodes <- function(object, - ..., - prefix = "Cluster_") { +extract_fit_summary.kmodes <- function(object, ..., prefix = "Cluster_") { names <- paste0(prefix, seq_len(nrow(object$modes))) names <- factor(names) @@ -166,7 +167,13 @@ extract_fit_summary.hclust <- function(object, ...) { sse_within_total_total <- map2_dbl( by_clust$data, seq_len(n_clust), - ~ sum(Rfast::dista(centroids[.y, ], .x)) + ~sum( + philentropy::dist_many_many( + as.matrix(centroids[.y, ]), + as.matrix(.x), + method = "euclidean" + ) + ) ) list( @@ -174,7 +181,13 @@ extract_fit_summary.hclust <- function(object, ...) { centroids = centroids, n_members = unname(as.integer(table(clusts))), sse_within_total_total = sse_within_total_total, - sse_total = sum(Rfast::dista(t(overall_centroid), training_data)), + sse_total = sum( + philentropy::dist_many_many( + t(overall_centroid), + as.matrix(training_data), + method = "euclidean" + ) + ), orig_labels = NULL, cluster_assignments = clusts ) diff --git a/R/extract_parameter_set_dials.R b/R/extract_parameter_set_dials.R index a3241b6b..7fd4ff48 100644 --- a/R/extract_parameter_set_dials.R +++ b/R/extract_parameter_set_dials.R @@ -10,7 +10,7 @@ extract_parameter_set_dials.cluster_spec <- function(x, ...) { all_args, by = c("name", "source", "component") ) %>% - dplyr::mutate(object = map(call_info, eval_call_info)) + dplyr::mutate(object = map(call_info, eval_call_info)) dials::parameters_constr( res$name, @@ -36,11 +36,7 @@ eval_call_info <- function(x) { silent = TRUE ) if (inherits(res, "try-error")) { - rlang::abort( - glue::glue( - "Error when calling {x$fun}(): {as.character(res)}" - ) - ) + cli::cli_abort("Error when calling {.fn {x$fun}}: {as.character(res)}") } } else { res <- NA diff --git a/R/finalize.R b/R/finalize.R index ac42bf3e..5e9a2e81 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -22,7 +22,7 @@ #' @export finalize_model_tidyclust <- function(x, parameters) { if (!inherits(x, "cluster_spec")) { - rlang::abort("`x` should be a tidyclust model specification.") + cli::cli_abort("{.arg x} should be a tidyclust model specification.") } parsnip::check_final_param(parameters) pset <- hardhat::extract_parameter_set_dials(x) @@ -46,7 +46,7 @@ finalize_model_tidyclust <- function(x, parameters) { #' @export finalize_workflow_tidyclust <- function(x, parameters) { if (!inherits(x, "workflow")) { - rlang::abort("`x` should be a workflow") + cli::cli_abort("{.arg x} should be {.obj_type_friendly workflow}") } parsnip::check_final_param(parameters) diff --git a/R/fit.R b/R/fit.R index 51d55957..635f2d0d 100644 --- a/R/fit.R +++ b/R/fit.R @@ -85,13 +85,15 @@ #' @return A fitted [`cluster_fit`] object. #' @export #' @export fit.cluster_spec -fit.cluster_spec <- function(object, - formula, - data, - control = control_cluster(), - ...) { +fit.cluster_spec <- function( + object, + formula, + data, + control = control_cluster(), + ... +) { if (object$mode == "unknown") { - rlang::abort("Please set the mode in the model specification.") + cli::cli_abort("Please set the mode in the model specification.") } control <- parsnip::condense_control(control, control_cluster()) @@ -101,13 +103,14 @@ fit.cluster_spec <- function(object, eng_vals <- possible_engines(object) object$engine <- eng_vals[1] if (control$verbosity > 0) { - rlang::warn(glue::glue("Engine set to `{object$engine}`.")) + cli::cli_warn("Engine set to {.code {object$engine}}.") } } if (all(c("x", "y") %in% names(dots))) { - rlang::abort( - "`fit.cluster_spec()` is for the formula methods. Use `fit_xy()` instead." + cli::cli_abort( + "The {.fn fit.cluster_spec} function is for the formula methods. + Use {.fn fit_xy} instead." ) } cl <- match.call(expand.dots = TRUE) @@ -133,33 +136,31 @@ fit.cluster_spec <- function(object, # used here, `fit_interface_formula` will determine if a # translation has to be made if the model interface is x/y/ res <- - switch(interfaces, + switch( + interfaces, # homogeneous combinations: - formula_formula = - form_form( - object = object, - control = control, - env = eval_env - ), + formula_formula = form_form( + object = object, + control = control, + env = eval_env + ), # heterogenous combinations - formula_matrix = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - formula_data.frame = - form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - rlang::abort(glue::glue("{interfaces} is unknown.")) + formula_matrix = form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + formula_data.frame = form_x( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + cli::cli_abort("{interfaces} is unknown.") ) model_classes <- class(res$fit) class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") @@ -176,7 +177,7 @@ check_interface <- function(formula, data, cl, model) { if (form_interface) { return("formula") } - rlang::abort("Error when checking the interface.") + cli::cli_abort("Error when checking the interface.") } inher <- function(x, cls, cl) { @@ -184,16 +185,9 @@ inher <- function(x, cls, cl) { call <- match.call() obj <- deparse(call[["x"]]) if (length(cls) > 1) { - rlang::abort( - glue::glue( - "`{obj}` should be one of the following classes: ", - glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ") - ) - ) + cli::cli_abort("{.code {obj}} should be {.cls {cls}}.") } else { - rlang::abort( - glue::glue("`{obj}` should be a {cls} object") - ) + cli::cli_abort("{.code {obj}} should be {.obj_type_friendly {cls}}.") } } invisible(x) @@ -241,14 +235,14 @@ fit_xy.cluster_spec <- control <- parsnip::condense_control(control, control_cluster()) if (is.null(colnames(x))) { - rlang::abort("'x' should have column names.") + cli::cli_abort("{.arg x} should have column names.") } if (is.null(object$engine)) { eng_vals <- possible_engines(object) object$engine <- eng_vals[1] if (control$verbosity > 0) { - rlang::warn(glue::glue("Engine set to `{object$engine}`.")) + cli::cli_warn("Engine set to {.code {object$engine}}.") } } @@ -270,37 +264,35 @@ fit_xy.cluster_spec <- # used here, `fit_interface_formula` will determine if a # translation has to be made if the model interface is x/y/ res <- - switch(interfaces, + switch( + interfaces, # homogeneous combinations: matrix_matrix = , - data.frame_matrix = - x_x( - object = object, - env = eval_env, - control = control, - target = "matrix", - ... - ), + data.frame_matrix = x_x( + object = object, + env = eval_env, + control = control, + target = "matrix", + ... + ), data.frame_data.frame = , - matrix_data.frame = - x_x( - object = object, - env = eval_env, - control = control, - target = "data.frame", - ... - ), + matrix_data.frame = x_x( + object = object, + env = eval_env, + control = control, + target = "data.frame", + ... + ), # heterogenous combinations matrix_formula = , - data.frame_formula = - x_form( - object = object, - env = eval_env, - control = control, - ... - ), - rlang::abort(glue::glue("{interfaces} is unknown.")) + data.frame_formula = x_form( + object = object, + env = eval_env, + control = control, + ... + ), + cli::cli_abort("{interfaces} is unknown.") ) model_classes <- class(res$fit) class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") @@ -311,7 +303,7 @@ check_x_interface <- function(x, cl, model) { sparse_ok <- allow_sparse(model) sparse_x <- inherits(x, "dgCMatrix") if (!sparse_ok && sparse_x) { - rlang::abort( + cli::cli_abort( "Sparse matrices not supported by this model/engine combination." ) } @@ -336,7 +328,7 @@ check_x_interface <- function(x, cl, model) { if (df_interface) { return("data.frame") } - rlang::abort("Error when checking the interface") + cli::cli_abort("Error when checking the interface") } allow_sparse <- function(x) { diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 136a1bf1..fd52affd 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -41,7 +41,7 @@ form_form <- function(object, control, env, ...) { form_x <- function(object, control, env, target = "none", ...) { encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) indicators <- encoding_info %>% dplyr::pull(predictor_indicators) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) @@ -70,11 +70,11 @@ form_x <- function(object, control, env, target = "none", ...) { x_x <- function(object, env, control, target = "none", y = NULL, ...) { if (!is.null(y) && length(y) > 0) { - rlang::abort("Outcomes are not used in `cluster_spec` objects.") + cli::cli_abort("Outcomes are not used in {.cls cluster_spec} objects.") } encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) if (remove_intercept) { @@ -120,7 +120,7 @@ x_x <- function(object, env, control, target = "none", y = NULL, ...) { x_form <- function(object, env, control, ...) { encoding_info <- modelenv::get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode, engine == object$engine) + dplyr::filter(mode == object$mode, engine == object$engine) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) diff --git a/R/hier_clust.R b/R/hier_clust.R index 1d5e898c..762d0075 100644 --- a/R/hier_clust.R +++ b/R/hier_clust.R @@ -56,11 +56,13 @@ #' hier_clust() #' @export hier_clust <- - function(mode = "partition", - engine = "stats", - num_clusters = NULL, - cut_height = NULL, - linkage_method = "complete") { + function( + mode = "partition", + engine = "stats", + num_clusters = NULL, + cut_height = NULL, + linkage_method = "complete" + ) { args <- list( num_clusters = enquo(num_clusters), cut_height = enquo(cut_height), @@ -95,15 +97,19 @@ print.hier_clust <- function(x, ...) { #' @method update hier_clust #' @rdname tidyclust_update #' @export -update.hier_clust <- function(object, - parameters = NULL, - num_clusters = NULL, - cut_height = NULL, - linkage_method = NULL, - fresh = FALSE, ...) { +update.hier_clust <- function( + object, + parameters = NULL, + num_clusters = NULL, + cut_height = NULL, + linkage_method = NULL, + fresh = FALSE, + ... +) { eng_args <- parsnip::update_engine_parameters( object$eng_args, - fresh = fresh, ... + fresh = fresh, + ... ) if (!is.null(parameters)) { @@ -150,7 +156,7 @@ check_args.hier_clust <- function(object) { args <- lapply(object$args, rlang::eval_tidy) if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0)) { - rlang::abort("The number of centers should be >= 0.") + cli::cli_abort("The number of centers should be >= 0.") } invisible(object) @@ -182,12 +188,16 @@ translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) { #' @return A dendrogram #' @keywords internal #' @export -.hier_clust_fit_stats <- function(x, - num_clusters = NULL, - cut_height = NULL, - linkage_method = NULL, - dist_fun = Rfast::Dist) { - dmat <- dist_fun(x) +.hier_clust_fit_stats <- function( + x, + num_clusters = NULL, + cut_height = NULL, + linkage_method = NULL, + dist_fun = philentropy::distance +) { + suppressMessages( + dmat <- dist_fun(x) + ) res <- stats::hclust(stats::as.dist(dmat), method = linkage_method) attr(res, "num_clusters") <- num_clusters attr(res, "cut_height") <- cut_height diff --git a/R/hier_clust_data.R b/R/hier_clust_data.R index 695b7ffe..2c9010a8 100644 --- a/R/hier_clust_data.R +++ b/R/hier_clust_data.R @@ -81,11 +81,10 @@ make_hier_clust <- function() { pre = NULL, post = NULL, func = c(fun = ".hier_clust_predict_stats"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) } diff --git a/R/k_means.R b/R/k_means.R index 3595827b..024e41c3 100644 --- a/R/k_means.R +++ b/R/k_means.R @@ -39,9 +39,7 @@ #' k_means() #' @export k_means <- - function(mode = "partition", - engine = "stats", - num_clusters = NULL) { + function(mode = "partition", engine = "stats", num_clusters = NULL) { args <- list( num_clusters = enquo(num_clusters) ) @@ -80,13 +78,17 @@ translate_tidyclust.k_means <- function(x, engine = x$engine, ...) { #' @method update k_means #' @rdname tidyclust_update #' @export -update.k_means <- function(object, - parameters = NULL, - num_clusters = NULL, - fresh = FALSE, ...) { +update.k_means <- function( + object, + parameters = NULL, + num_clusters = NULL, + fresh = FALSE, + ... +) { eng_args <- parsnip::update_engine_parameters( object$eng_args, - fresh = fresh, ... + fresh = fresh, + ... ) if (!is.null(parameters)) { @@ -131,7 +133,7 @@ check_args.k_means <- function(object) { args <- lapply(object$args, rlang::eval_tidy) if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0)) { - rlang::abort("The number of centers should be >= 0.") + cli::cli_abort("The number of centers should be >= 0.") } invisible(object) @@ -170,20 +172,22 @@ check_args.k_means <- function(object) { #' obs_per_cluster, between.SS_DIV_total.SS #' @keywords internal #' @export -.k_means_fit_ClusterR <- function(data, - clusters, - num_init = 1, - max_iters = 100, - initializer = "kmeans++", - fuzzy = FALSE, - verbose = FALSE, - CENTROIDS = NULL, - tol = 1e-04, - tol_optimal_init = 0.3, - seed = 1) { +.k_means_fit_ClusterR <- function( + data, + clusters, + num_init = 1, + max_iters = 100, + initializer = "kmeans++", + fuzzy = FALSE, + verbose = FALSE, + CENTROIDS = NULL, + tol = 1e-04, + tol_optimal_init = 0.3, + seed = 1 +) { if (is.null(clusters)) { - rlang::abort( - "Please specify `num_clust` to be able to fit specification.", + cli::cli_abort( + "Please specify {.arg num_clust} to be able to fit specification.", call = call("fit") ) } @@ -225,8 +229,8 @@ check_args.k_means <- function(object) { #' @export .k_means_fit_stats <- function(data, centers = NULL, ...) { if (is.null(centers)) { - rlang::abort( - "Please specify `num_clust` to be able to fit specification.", + cli::cli_abort( + "Please specify {.arg num_clust} to be able to fit specification.", call = call("fit") ) } @@ -259,8 +263,8 @@ check_args.k_means <- function(object) { c( "Engine `clustMixType` requires both numeric and categorical \\ predictors.", - "x" = "Only numeric predictors where used.", - "i" = "Try using the `stats` engine with \\ + "x" = "Only numeric predictors where used.", + "i" = "Try using the `stats` engine with \\ {.code mod %>% set_engine(\"stats\")}." ), call = call("fit") @@ -271,8 +275,8 @@ check_args.k_means <- function(object) { c( "Engine `clustMixType` requires both numeric and categorical \\ predictors.", - "x" = "Only categorical predictors where used.", - "i" = "Try using the `klaR` engine with \\ + "x" = "Only categorical predictors where used.", + "i" = "Try using the `klaR` engine with \\ {.code mod %>% set_engine(\"klaR\")}." ), call = call("fit") diff --git a/R/k_means_data.R b/R/k_means_data.R index 73f6be6d..09eae346 100644 --- a/R/k_means_data.R +++ b/R/k_means_data.R @@ -64,11 +64,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_stats"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) @@ -131,15 +130,14 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_ClusterR"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) -# ---------------------------------------------------------------------------- + # ---------------------------------------------------------------------------- modelenv::set_model_engine("k_means", "partition", "clustMixType") modelenv::set_dependency( @@ -198,11 +196,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_clustMixType"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) @@ -265,11 +262,10 @@ make_k_means <- function() { pre = NULL, post = NULL, func = c(fun = ".k_means_predict_klaR"), - args = - list( - object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) - ) + args = list( + object = rlang::expr(object$fit), + new_data = rlang::expr(new_data) + ) ) ) } diff --git a/R/load_ns.R b/R/load_ns.R index 715053df..1bd0a25a 100644 --- a/R/load_ns.R +++ b/R/load_ns.R @@ -28,8 +28,7 @@ load_namespace <- function(x) { did_load <- map_lgl(x, requireNamespace, quietly = TRUE) if (any(!did_load)) { bad <- x[!did_load] - msg <- paste0("'", bad, "'", collapse = ", ") - rlang::abort(paste("These packages could not be loaded:", msg)) + cli::cli_abort("The package{?s} {.pkg {bad}} could not be loaded.") } } @@ -37,6 +36,17 @@ load_namespace <- function(x) { } infra_pkgs <- c( - "tune", "recipes", "tidyclust", "yardstick", "purrr", "dplyr", "tibble", - "dials", "rsample", "workflows", "tidyr", "rlang", "vctrs" + "tune", + "recipes", + "tidyclust", + "yardstick", + "purrr", + "dplyr", + "tibble", + "dials", + "rsample", + "workflows", + "tidyr", + "rlang", + "vctrs" ) diff --git a/R/metric-aaa.R b/R/metric-aaa.R index 9adfea01..d1fe5831 100644 --- a/R/metric-aaa.R +++ b/R/metric-aaa.R @@ -17,7 +17,7 @@ #' @export new_cluster_metric <- function(fn, direction) { if (!is.function(fn)) { - rlang::abort("`fn` must be a function.") + cli::cli_abort("{.arg fn} must be a function.") } direction <- rlang::arg_match( @@ -61,17 +61,17 @@ cluster_metric_set <- function(...) { if (fn_cls == "cluster_metric") { make_cluster_metric_function(fns) } else { - rlang::abort(paste0( - "Internal error: `validate_function_class()` should have ", - "errored on unknown classes." - )) + cli::cli_abort( + "Internal error: {.fn validate_function_class} should have errored on + unknown classes." + ) } } validate_not_empty <- function(x) { if (rlang::is_empty(x)) { - rlang::abort( - "`cluster_metric_set()` requires at least 1 function supplied to `...`." + cli::cli_abort( + "{.fn cluster_metric_set} requires at least 1 function supplied to {.arg ...}." ) } } @@ -82,10 +82,10 @@ validate_inputs_are_functions <- function(fns) { if (!all_fns) { not_fn <- which(!is_fun_vec) not_fn <- paste(not_fn, collapse = ", ") - rlang::abort( - glue::glue( - "All inputs to `cluster_metric_set()` must be functions. ", - "These inputs are not: ({not_fn})." + cli::cli_abort( + c( + "All inputs to {.fn cluster_metric_set} must be functions.", + "i" = "These inputs are not: {not_fn}." ) ) } @@ -94,11 +94,9 @@ validate_inputs_are_functions <- function(fns) { get_quo_label <- function(quo) { out <- rlang::as_label(quo) if (length(out) != 1L) { - rlang::abort( - glue::glue( - "Internal error: ", - "`as_label(quo)` resulted in a character vector of length > 1." - ) + cli::cli_abort( + "Internal error: {.code as_label(quo)} resulted in a character vector + of length > 1." ) } is_namespaced <- grepl("::", out, fixed = TRUE) @@ -111,14 +109,14 @@ get_quo_label <- function(quo) { validate_function_typo <- function(fns, call = rlang::caller_env()) { if (any(map_lgl(fns, identical, silhouette))) { - rlang::abort( - "`silhouette` is not a cluster metric. Did you mean `silhouette_avg`?", + cli::cli_abort( + "The value {.val silhouette} is not a cluster metric. Did you mean {.code silhouette_avg}?", call = call ) } if (any(map_lgl(fns, identical, sse_within))) { - rlang::abort( - "`sse_within_total` is not a cluster metric. Did you mean `sse_within_total`?", + cli::cli_abort( + "{.arg sse_within_total} is not a cluster metric. Did you mean {.code sse_within_total}?", call = call ) } @@ -164,12 +162,10 @@ validate_function_class <- function(fns) { fn_names = fn_bad_names, USE.NAMES = FALSE ) - fn_pastable <- paste0(fn_pastable, collapse = "\n") - rlang::abort( - paste0( - "\nThe combination of metric functions must be:\n", - "- only clustering metrics\n", - "The following metric function types are being mixed:\n", + cli::cli_abort( + c( + "The combination of metric functions must be only clustering metrics.", + "i" = "The following metric function types are being mixed:", fn_pastable ) ) @@ -183,8 +179,11 @@ make_cluster_metric_function <- function(fns) { ) calls <- lapply(fns, rlang::call2, !!!call_args) metric_list <- mapply( - FUN = eval_safely, calls, names(calls), - SIMPLIFY = FALSE, USE.NAMES = FALSE + FUN = eval_safely, + calls, + names(calls), + SIMPLIFY = FALSE, + USE.NAMES = FALSE ) dplyr::bind_rows(metric_list) } @@ -197,11 +196,14 @@ make_cluster_metric_function <- function(fns) { } eval_safely <- function(expr, expr_nm, data = NULL, env = rlang::caller_env()) { - tryCatch(expr = { - rlang::eval_tidy(expr, data = data, env = env) - }, error = function(e) { - rlang::abort(paste0("In metric: `", expr_nm, "`\n", conditionMessage(e))) - }) + tryCatch( + expr = { + rlang::eval_tidy(expr, data = data, env = env) + }, + error = function(e) { + cli::cli_abort("In metric: {.code {expr_nm}}\n{conditionMessage(e)}") + } + ) } #' @export diff --git a/R/metric-helpers.R b/R/metric-helpers.R index 57c0d139..b6d8d736 100644 --- a/R/metric-helpers.R +++ b/R/metric-helpers.R @@ -8,11 +8,15 @@ #' @param dist_fun A custom distance functions. #' #' @return A list -prep_data_dist <- function(object, new_data = NULL, - dists = NULL, dist_fun = Rfast::Dist) { +prep_data_dist <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance +) { # Sihouettes requires a distance matrix if (is.null(new_data) && is.null(dists)) { - rlang::abort( + cli::cli_abort( "Must supply either a dataset or distance matrix to compute silhouettes." ) } @@ -27,9 +31,9 @@ prep_data_dist <- function(object, new_data = NULL, # If they supplied distance, check that it matches the data dimension if (!is.null(dists)) { if (!is.null(new_data) && nrow(new_data) != attr(dists, "Size")) { - rlang::abort("Dimensions of dataset and distance matrix must match.") + cli::cli_abort("Dimensions of dataset and distance matrix must match.") } else if (is.null(new_data) && length(clusters) != attr(dists, "Size")) { - rlang::abort( + cli::cli_abort( "Dimensions of training dataset and distance matrix must match." ) } @@ -42,14 +46,18 @@ prep_data_dist <- function(object, new_data = NULL, # Calculate distances including optionally supplied params if (is.null(dists)) { - dists <- dist_fun(new_data) + suppressMessages( + dists <- dist_fun(new_data) + ) } - return(list( - clusters = clusters, - data = new_data, - dists = dists - )) + return( + list( + clusters = clusters, + data = new_data, + dists = dists + ) + ) } #' Computes distance from observations to centroids @@ -57,11 +65,19 @@ prep_data_dist <- function(object, new_data = NULL, #' @param new_data A data frame #' @param centroids A data frame where each row is a centroid. #' @param dist_fun A function for computing matrix-to-matrix distances. Defaults -#' to `Rfast::dista()` -get_centroid_dists <- function(new_data, centroids, dist_fun = Rfast::dista) { +#' to +#' `function(x, y) philentropy::dist_many_many(x, y, method = "euclidean")`. +get_centroid_dists <- function( + new_data, + centroids, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + } +) { if (ncol(new_data) != ncol(centroids)) { - rlang::abort("Centroids must have same columns as data.") } - dist_fun(centroids, new_data) + suppressMessages( + dist_fun(as.matrix(centroids), as.matrix(new_data)) + ) } diff --git a/R/metric-silhouette.R b/R/metric-silhouette.R index 2e823a50..3b4a8ea2 100644 --- a/R/metric-silhouette.R +++ b/R/metric-silhouette.R @@ -23,13 +23,17 @@ #' #' silhouette(kmeans_fit, dists = dists) #' @export -silhouette <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist) { +silhouette <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance +) { if (inherits(object, "cluster_spec")) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } @@ -43,7 +47,8 @@ silhouette <- function(object, new_data = NULL, dists = NULL, if (!inherits(sil, "silhouette")) { res <- tibble::tibble( cluster = preproc$clusters, - neighbor = factor(rep(NA_character_, length(preproc$clusters)), + neighbor = factor( + rep(NA_character_, length(preproc$clusters)), levels = levels(preproc$clusters) ), sil_width = NA_real_ @@ -69,7 +74,6 @@ silhouette <- function(object, new_data = NULL, dists = NULL, #' @param dist_fun A function for calculating distances between observations. #' Defaults to Euclidean distance on processed data. #' @param ... Other arguments passed to methods. -#' #' @details Not to be confused with [silhouette()] that returns a tibble #' with silhouette for each observation. #' @@ -103,20 +107,25 @@ silhouette_avg <- new_cluster_metric( #' @export #' @rdname silhouette_avg silhouette_avg.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } #' @export #' @rdname silhouette_avg -silhouette_avg.cluster_fit <- function(object, new_data = NULL, dists = NULL, - dist_fun = NULL, ...) { +silhouette_avg.cluster_fit <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { - dist_fun <- Rfast::Dist + dist_fun <- philentropy::distance } res <- silhouette_avg_impl(object, new_data, dists, dist_fun, ...) @@ -134,12 +143,22 @@ silhouette_avg.workflow <- silhouette_avg.cluster_fit #' @export #' @rdname silhouette_avg -silhouette_avg_vec <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { +silhouette_avg_vec <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance, + ... +) { silhouette_avg_impl(object, new_data, dists, dist_fun, ...) } -silhouette_avg_impl <- function(object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, ...) { +silhouette_avg_impl <- function( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance, + ... +) { mean(silhouette(object, new_data, dists, dist_fun, ...)$sil_width) } diff --git a/R/metric-sse.R b/R/metric-sse.R index 84aa0ebb..1760c317 100644 --- a/R/metric-sse.R +++ b/R/metric-sse.R @@ -19,12 +19,18 @@ #' #' sse_within(kmeans_fit) #' @export -sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) { +sse_within <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + } +) { if (inherits(object, "cluster_spec")) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } @@ -43,14 +49,21 @@ sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) { n_members = summ$n_members ) } else { - dist_to_centroids <- dist_fun(summ$centroids, new_data) + suppressMessages( + dist_to_centroids <- dist_fun( + as.matrix(summ$centroids), + as.matrix(new_data) + ) + ) res <- dist_to_centroids %>% tibble::as_tibble(.name_repair = "minimal") %>% - map(~ c( - .cluster = which.min(.x), - dist = min(.x)^2 - )) %>% + map( + ~c( + .cluster = which.min(.x), + dist = min(.x)^2 + ) + ) %>% dplyr::bind_rows() %>% dplyr::mutate( .cluster = factor(paste0("Cluster_", .cluster)) @@ -102,20 +115,26 @@ sse_within_total <- new_cluster_metric( #' @export #' @rdname sse_within_total sse_within_total.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } #' @export #' @rdname sse_within_total -sse_within_total.cluster_fit <- function(object, new_data = NULL, - dist_fun = NULL, ...) { +sse_within_total.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { - dist_fun <- Rfast::dista + dist_fun <- function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + } } res <- sse_within_total_impl(object, new_data, dist_fun, ...) @@ -133,13 +152,25 @@ sse_within_total.workflow <- sse_within_total.cluster_fit #' @export #' @rdname sse_within_total -sse_within_total_vec <- function(object, new_data = NULL, - dist_fun = Rfast::dista, ...) { +sse_within_total_vec <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { sse_within_total_impl(object, new_data, dist_fun, ...) } -sse_within_total_impl <- function(object, new_data = NULL, - dist_fun = Rfast::dista, ...) { +sse_within_total_impl <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { sum(sse_within(object, new_data, dist_fun, ...)$wss, na.rm = TRUE) } @@ -177,20 +208,26 @@ sse_total <- new_cluster_metric( #' @export #' @rdname sse_total sse_total.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } #' @export #' @rdname sse_total -sse_total.cluster_fit <- function(object, new_data = NULL, dist_fun = NULL, - ...) { +sse_total.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { - dist_fun <- Rfast::dista + dist_fun <- function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + } } res <- sse_total_impl(object, new_data, dist_fun, ...) @@ -208,12 +245,25 @@ sse_total.workflow <- sse_total.cluster_fit #' @export #' @rdname sse_total -sse_total_vec <- function(object, new_data = NULL, dist_fun = Rfast::dista, ...) { +sse_total_vec <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { sse_total_impl(object, new_data, dist_fun, ...) } -sse_total_impl <- function(object, new_data = NULL, dist_fun = Rfast::dista, - ...) { +sse_total_impl <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { # Preprocess data before computing distances if appropriate if (inherits(object, "workflow") && !is.null(new_data)) { new_data <- extract_post_preprocessor(object, new_data) @@ -226,7 +276,10 @@ sse_total_impl <- function(object, new_data = NULL, dist_fun = Rfast::dista, } else { overall_mean <- colSums(summ$centroids * summ$n_members) / sum(summ$n_members) - tot <- dist_fun(t(as.matrix(overall_mean)), new_data)^2 %>% sum() + suppressMessages( + tot <- dist_fun(t(as.matrix(overall_mean)), as.matrix(new_data))^2 %>% + sum() + ) } return(tot) @@ -266,20 +319,26 @@ sse_ratio <- new_cluster_metric( #' @export #' @rdname sse_ratio sse_ratio.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } #' @export #' @rdname sse_ratio -sse_ratio.cluster_fit <- function(object, new_data = NULL, - dist_fun = NULL, ...) { +sse_ratio.cluster_fit <- function( + object, + new_data = NULL, + dist_fun = NULL, + ... +) { if (is.null(dist_fun)) { - dist_fun <- Rfast::dista + dist_fun <- function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + } } res <- sse_ratio_impl(object, new_data, dist_fun, ...) @@ -296,17 +355,25 @@ sse_ratio.workflow <- sse_ratio.cluster_fit #' @export #' @rdname sse_ratio -sse_ratio_vec <- function(object, - new_data = NULL, - dist_fun = Rfast::dista, - ...) { +sse_ratio_vec <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { sse_ratio_impl(object, new_data, dist_fun, ...) } -sse_ratio_impl <- function(object, - new_data = NULL, - dist_fun = Rfast::dista, - ...) { +sse_ratio_impl <- function( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = "euclidean") + }, + ... +) { sse_within_total_vec(object, new_data, dist_fun) / sse_total_vec(object, new_data, dist_fun) } diff --git a/R/misc.R b/R/misc.R index a79caffc..1cae8c73 100644 --- a/R/misc.R +++ b/R/misc.R @@ -12,13 +12,12 @@ check_args.default <- function(object) { check_spec_pred_type <- function(object, type) { if (!spec_has_pred_type(object, type)) { possible_preds <- names(object$spec$method$pred) - rlang::abort(c( - glue::glue("No {type} prediction method available for this model."), - glue::glue( - "Value for `type` should be one of: ", - glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", ") + cli::cli_abort( + c( + "No {type} prediction method available for this model.", + "i" = "{.arg type} should be one of {.val {possible_preds}}." ) - )) + ) } invisible(NULL) } diff --git a/R/model_object_docs.R b/R/model_object_docs.R index c851467a..8f7d973d 100644 --- a/R/model_object_docs.R +++ b/R/model_object_docs.R @@ -185,4 +185,3 @@ NULL #' @rdname cluster_fit #' @name cluster_fit NULL - diff --git a/R/predict.R b/R/predict.R index 0451a3e8..595a81e4 100644 --- a/R/predict.R +++ b/R/predict.R @@ -88,13 +88,17 @@ #' @method predict cluster_fit #' @export predict.cluster_fit #' @export -predict.cluster_fit <- function(object, - new_data, - type = NULL, - opts = list(), - ...) { +predict.cluster_fit <- function( + object, + new_data, + type = NULL, + opts = list(), + ... +) { if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn( + "Model fit failed; cannot make predictions." + ) return(NULL) } @@ -103,34 +107,28 @@ predict.cluster_fit <- function(object, type <- check_pred_type(object, type) - res <- switch(type, + res <- switch( + type, cluster = predict_cluster(object = object, new_data = new_data, ...), raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), - rlang::abort(glue::glue("I don't know about type = '{type}'")) + cli::cli_abort("I don't know about type = {.val {type}}") ) - res <- switch(type, - cluster = format_cluster(res), - res - ) + res <- switch(type, cluster = format_cluster(res), res) res } check_pred_type <- function(object, type, ...) { if (is.null(type)) { type <- - switch(object$spec$mode, + switch( + object$spec$mode, partition = "cluster", - rlang::abort("`type` should be 'cluster'.") + cli::cli_abort("The {.arg type} argument should be {.val cluster}.") ) } if (!(type %in% pred_types)) { - rlang::abort( - glue::glue( - "`type` should be one of: ", - glue::glue_collapse(pred_types, sep = ", ", last = " and ") - ) - ) + cli::cli_abort("{.arg type} should be {.or {pred_types}}.") } type } @@ -154,13 +152,14 @@ prepare_data <- function(object, new_data) { remove_intercept <- modelenv::get_encoding(class(object$spec)[1]) %>% - dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% - dplyr::pull(remove_intercept) + dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>% + dplyr::pull(remove_intercept) if (remove_intercept && any(grepl("Intercept", names(new_data)))) { new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)")) } - switch(fit_interface, + switch( + fit_interface, none = new_data, data.frame = as.data.frame(new_data), matrix = as.matrix(new_data), @@ -180,10 +179,10 @@ make_pred_call <- function(x) { #' @export predict.cluster_spec <- function(object, ...) { - rlang::abort( - paste( + cli::cli_abort( + c( "This function requires a fitted model.", - "Please use `fit()` on your cluster specification." + "i" = "Please use {.fn fit} on your cluster specification." ) ) } diff --git a/R/predict_cluster.R b/R/predict_cluster.R index ceee4988..75646b5e 100644 --- a/R/predict_cluster.R +++ b/R/predict_cluster.R @@ -23,7 +23,7 @@ predict_cluster.cluster_fit <- function(object, new_data, ...) { check_spec_pred_type(object, "cluster") if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } diff --git a/R/predict_helpers.R b/R/predict_helpers.R index f4780ae4..2ad5dabd 100644 --- a/R/predict_helpers.R +++ b/R/predict_helpers.R @@ -18,15 +18,23 @@ make_predictions <- function(x, prefix, n_clusters) { make_predictions(clusters, prefix, n_clusters) } -.k_means_predict_clustMixType <- function(object, new_data, prefix = "Cluster_") { +.k_means_predict_clustMixType <- function( + object, + new_data, + prefix = "Cluster_" +) { clusters <- predict(object, new_data)$cluster n_clusters <- length(object$size) make_predictions(clusters, prefix, n_clusters) } -.k_means_predict_klaR <- function(object, new_data, prefix = "Cluster_", - ties = c("first", "last", "random")) { +.k_means_predict_klaR <- function( + object, + new_data, + prefix = "Cluster_", + ties = c("first", "last", "random") +) { ties <- rlang::arg_match(ties) modes <- object$modes @@ -42,7 +50,6 @@ make_predictions <- function(x, prefix, n_clusters) { which_min <- which(misses == min(misses)) - if (length(which_min) == 1) { clusters[i] <- which_min } else { @@ -58,7 +65,12 @@ make_predictions <- function(x, prefix, n_clusters) { make_predictions(clusters, prefix, n_modes) } -.hier_clust_predict_stats <- function(object, new_data, ..., prefix = "Cluster_") { +.hier_clust_predict_stats <- function( + object, + new_data, + ..., + prefix = "Cluster_" +) { linkage_method <- object$method new_data <- as.matrix(new_data) @@ -75,7 +87,8 @@ make_predictions <- function(x, prefix, n_clusters) { ## complete, single, average, and median linkage_methods are basically the ## same idea, just different summary distance to cluster - cluster_dist_fun <- switch(linkage_method, + cluster_dist_fun <- switch( + linkage_method, "single" = min, "complete" = max, "average" = mean, @@ -83,7 +96,11 @@ make_predictions <- function(x, prefix, n_clusters) { ) # need this to be obs on rows, dist to new data on cols - dists_new <- Rfast::dista(xnew = training_data, x = new_data, trans = TRUE) + dists_new <- philentropy::dist_many_many( + training_data, + new_data, + method = "euclidean" + ) cluster_dists <- dplyr::bind_cols(data.frame(dists_new), clusters) %>% dplyr::group_by(.cluster) %>% @@ -96,7 +113,12 @@ make_predictions <- function(x, prefix, n_clusters) { ## Centroid linkage_method, dist to center cluster_centers <- extract_centroids(object) %>% dplyr::select(-.cluster) - dists_means <- Rfast::dista(new_data, cluster_centers) + + dists_means <- philentropy::dist_many_many( + new_data, + cluster_centers, + method = "euclidean" + ) pred_clusts_num <- apply(dists_means, 1, which.min) } else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) { @@ -111,7 +133,7 @@ make_predictions <- function(x, prefix, n_clusters) { d_means <- map( seq_len(n_clust), - ~ t( + ~t( t(training_data[clusters$.cluster == cluster_names[.x], ]) - cluster_centers[.x, ] ) @@ -122,8 +144,10 @@ make_predictions <- function(x, prefix, n_clusters) { function(new_obs) { map( seq_len(n_clust), - ~ t(t(training_data[clusters$.cluster == cluster_names[.x], ]) - - new_data[new_obs, ]) + ~t( + t(training_data[clusters$.cluster == cluster_names[.x], ]) - + new_data[new_obs, ] + ) ) } ) @@ -134,18 +158,17 @@ make_predictions <- function(x, prefix, n_clusters) { d_new_list, function(v) { map2_dbl( - d_means, v, - ~ sum((n * .x + .y)^2 / (n + 1)^2 - .x^2) + d_means, + v, + ~sum((n * .x + .y)^2 / (n + 1)^2 - .x^2) ) } ) pred_clusts_num <- map_dbl(change_in_ess, which.min) } else { - rlang::abort( - glue::glue( - "linkage_method {linkage_method} is not supported for prediction." - ) + cli::cli_abort( + "linkage_method {.val {linkage_method}} is not supported for prediction." ) } pred_clusts <- unique(clusters$.cluster)[pred_clusts_num] diff --git a/R/predict_raw.R b/R/predict_raw.R index 6cf59f9e..4d64b59c 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -16,7 +16,7 @@ predict_raw.cluster_fit <- function(object, new_data, opts = list(), ...) { check_spec_pred_type(object, "raw") if (inherits(object$fit, "try-error")) { - rlang::warn("Cluster fit failed; cannot make predictions.") + cli::cli_warn("Cluster fit failed; cannot make predictions.") return(NULL) } diff --git a/R/print.R b/R/print.R index 90cf9f4e..ee8ad397 100644 --- a/R/print.R +++ b/R/print.R @@ -5,7 +5,8 @@ print.cluster_fit <- function(x, ...) { cat("tidyclust cluster object\n\n") if (!is.na(x$elapsed[["elapsed"]])) { cat( - "Fit time: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]), + "Fit time: ", + prettyunits::pretty_sec(x$elapsed[["elapsed"]]), "\n" ) } diff --git a/R/pull.R b/R/pull.R index 3761dfe0..7dd33013 100644 --- a/R/pull.R +++ b/R/pull.R @@ -32,20 +32,20 @@ pulley <- function(resamples, res, col) { if (all(map_lgl(res, inherits, "simpleError"))) { res <- resamples %>% - dplyr::mutate(col = map(splits, ~NULL)) %>% - stats::setNames(c(names(resamples), col)) + dplyr::mutate(col = map(splits, ~NULL)) %>% + stats::setNames(c(names(resamples), col)) return(res) } id_cols <- grep("^id", names(resamples), value = TRUE) resamples <- dplyr::arrange(resamples, !!!rlang::syms(id_cols)) - pulled_vals <- dplyr::bind_rows(map(res, ~ .x[[col]])) + pulled_vals <- dplyr::bind_rows(map(res, ~.x[[col]])) if (nrow(pulled_vals) == 0) { res <- resamples %>% - dplyr::mutate(col = map(splits, ~NULL)) %>% - stats::setNames(c(names(resamples), col)) + dplyr::mutate(col = map(splits, ~NULL)) %>% + stats::setNames(c(names(resamples), col)) return(res) } diff --git a/R/reconcile_clusterings.R b/R/reconcile_clusterings.R index 41d1ed33..f14ea26c 100644 --- a/R/reconcile_clusterings.R +++ b/R/reconcile_clusterings.R @@ -31,18 +31,17 @@ #' factor2 <- c("Dog", "Dog", "Cat", "Dog", "Fish", "Parrot") #' reconcile_clusterings_mapping(factor1, factor2, one_to_one = FALSE) #' @export -reconcile_clusterings_mapping <- function(primary, - alternative, - one_to_one = TRUE, - optimize = "accuracy") { +reconcile_clusterings_mapping <- function( + primary, + alternative, + one_to_one = TRUE, + optimize = "accuracy" +) { rlang::check_installed("RcppHungarian") if (length(primary) != length(alternative)) { - rlang::abort( - glue::glue( - "`primary` ({length(primary)}) ", - "and `alternative` ({length(alternative)}) ", - "must be the same length." - ) + cli::cli_abort( + "{.arg primary} ({length(primary)}) and {.arg alternative} ({length(alternative)}) + must be the same length." ) } @@ -53,18 +52,14 @@ reconcile_clusterings_mapping <- function(primary, nclust_2 <- length(levels(clusters_2)) if (one_to_one && nclust_1 != nclust_2) { - rlang::abort( - glue::glue( - "For one-to-one matching, must have the same number of clusters in", - "primary and alt." - ) + cli::cli_abort( + "For one-to-one matching, must have the same number of clusters in + primary and alt." ) } else if (nclust_1 > nclust_2) { - rlang::abort( - glue::glue( - "Primary clustering must have equal or fewer clusters to alternate", - "clustering." - ) + cli::cli_abort( + "Primary clustering must have equal or fewer clusters to alternate + clustering." ) } diff --git a/R/required_pkgs.R b/R/required_pkgs.R index c87309fd..2566fbc4 100644 --- a/R/required_pkgs.R +++ b/R/required_pkgs.R @@ -2,7 +2,7 @@ #' @export required_pkgs.cluster_spec <- function(x, infra = TRUE, ...) { if (is.null(x$engine)) { - rlang::abort("Please set an engine.") + cli::cli_abort("Please set an engine.") } get_pkgs(x, infra) } @@ -16,7 +16,7 @@ get_pkgs <- function(x, infra) { cls <- class(x)[1] pkgs <- modelenv::get_from_env(paste0(cls, "_pkgs")) %>% - dplyr::filter(engine == x$engine) + dplyr::filter(engine == x$engine) res <- pkgs$pkg[[1]] if (length(res) == 0) { res <- character(0) diff --git a/R/translate.R b/R/translate.R index 65d69ee8..d6e05a97 100644 --- a/R/translate.R +++ b/R/translate.R @@ -40,14 +40,14 @@ translate_tidyclust <- function(x, ...) { translate_tidyclust.default <- function(x, engine = x$engine, ...) { check_empty_ellipse_tidyclust(...) if (is.null(engine)) { - rlang::abort("Please set an engine.") + cli::cli_abort("Please set an engine.") } mod_name <- specific_model(x) x$engine <- engine if (x$mode == "unknown") { - rlang::abort("Model code depends on the mode; please specify one.") + cli::cli_abort("Model code depends on the mode. Please specify one.") } modelenv::check_spec_mode_engine_val( @@ -103,20 +103,20 @@ get_cluster_spec <- function(model, mode, engine) { res <- list() res$libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) %>% - dplyr::filter(engine == !!engine) %>% - .[["pkg"]] %>% - .[[1]] + dplyr::filter(engine == !!engine) %>% + .[["pkg"]] %>% + .[[1]] res$fit <- rlang::env_get(m_env, paste0(model, "_fit")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::pull(value) %>% - .[[1]] + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + .[[1]] pred_code <- rlang::env_get(m_env, paste0(model, "_predict")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::select(-engine, -mode) + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) res$pred <- pred_code[["value"]] names(res$pred) <- pred_code$type @@ -139,7 +139,7 @@ deharmonize <- function(args, key) { parsn <- tibble::tibble(exposed = names(args), order = seq_along(args)) merged <- dplyr::left_join(parsn, key, by = "exposed") %>% - dplyr::arrange(order) + dplyr::arrange(order) # TODO correct for bad merge? names(args) <- merged$original @@ -149,8 +149,8 @@ deharmonize <- function(args, key) { check_empty_ellipse_tidyclust <- function(...) { terms <- quos(...) if (!rlang::is_empty(terms)) { - rlang::abort( - "Please pass other arguments to the model function via `set_engine()`." + cli::cli_abort( + "Please pass other arguments to the model function via {.fn set_engine}." ) } terms diff --git a/R/tunable.R b/R/tunable.R index da49bd95..534cd692 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -5,37 +5,33 @@ tunable.cluster_spec <- function(x, ...) { mod_env <- rlang::ns_env("modelenv")$modelenv if (is.null(x$engine)) { - rlang::abort( - "Please declare an engine first using `set_engine()`.", - call. = FALSE + cli::cli_abort( + "Please declare an engine first using {.fn set_engine}.", + call = FALSE ) } arg_name <- paste0(mod_type(x), "_args") if (!(any(arg_name == names(mod_env)))) { - rlang::abort( - paste( - "The `tidyclust` model database doesn't know about the arguments for ", - "model `", mod_type(x), "`. Was it registered?", - sep = "" - ), - call. = FALSE + cli::cli_abort( + "The {.pkg tidyclust} model database doesn't know about the arguments for + model {.code {mod_type(x)}}. Was it registered?" ) } arg_vals <- mod_env[[arg_name]] %>% - dplyr::filter(engine == x$engine) %>% - dplyr::select(name = exposed, call_info = func) %>% - dplyr::full_join( - tibble::tibble(name = c(names(x$args), names(x$eng_args))), - by = "name" - ) %>% - dplyr::mutate( - source = "cluster_spec", - component = mod_type(x), - component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") - ) + dplyr::filter(engine == x$engine) %>% + dplyr::select(name = exposed, call_info = func) %>% + dplyr::full_join( + tibble::tibble(name = c(names(x$args), names(x$eng_args))), + by = "name" + ) %>% + dplyr::mutate( + source = "cluster_spec", + component = mod_type(x), + component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") + ) if (nrow(arg_vals) > 0) { has_info <- map_lgl(arg_vals$call_info, is.null) diff --git a/R/tune_args.R b/R/tune_args.R index e0cac86b..42807b6e 100644 --- a/R/tune_args.R +++ b/R/tune_args.R @@ -77,10 +77,10 @@ find_tune_id <- function(x) { } if (sum(tunable_elems == "", na.rm = TRUE) > 1) { - rlang::abort( - glue::glue( - "Only one tunable value is currently allowed per argument. ", - "The current argument has: `{paste0(deparse(x), collapse = '')}`." + cli::cli_abort( + c( + "Only one tunable value is currently allowed per argument.", + "i" = "The current argument has: {.code {paste0(deparse(x), collapse = '')}}." ) ) } @@ -124,24 +124,20 @@ tune_id <- function(x) { NA_character_ } -tune_tbl <- function(name = character(), - tunable = logical(), - id = character(), - source = character(), - component = character(), - component_id = character(), - full = FALSE) { +tune_tbl <- function( + name = character(), + tunable = logical(), + id = character(), + source = character(), + component = character(), + component_id = character(), + full = FALSE +) { complete_id <- id[!is.na(id)] dups <- duplicated(complete_id) if (any(dups)) { - rlang::abort( - paste( - "There are duplicate `id` values listed in [tune()]: ", - paste0("'", unique(complete_id[dups]), "'", collapse = ", "), - ".", - sep = "" - ), - call. = FALSE + cli::cli_abort( + "There are duplicate {.code id} values listed in [{.fn tune}]: {.val {unique(complete_id[dups])}}." ) } diff --git a/R/tune_cluster.R b/R/tune_cluster.R index 0cf06310..d46db5ca 100644 --- a/R/tune_cluster.R +++ b/R/tune_cluster.R @@ -62,24 +62,27 @@ tune_cluster <- function(object, ...) { #' @export tune_cluster.default <- function(object, ...) { - msg <- paste0( - "The first argument to [tune_cluster()] should be either ", - "a model or workflow." + cli::cli_abort( + "The first argument to {.fn tune_cluster} should be either a model or workflow." ) - rlang::abort(msg) } #' @export #' @rdname tune_cluster -tune_cluster.cluster_spec <- function(object, preprocessor, resamples, ..., - param_info = NULL, grid = 10, - metrics = NULL, - control = tune::control_grid()) { +tune_cluster.cluster_spec <- function( + object, + preprocessor, + resamples, + ..., + param_info = NULL, + grid = 10, + metrics = NULL, + control = tune::control_grid() +) { if (rlang::is_missing(preprocessor) || !tune::is_preprocessor(preprocessor)) { - rlang::abort(paste( - "To tune a model spec, you must preprocess", - "with a formula or recipe" - )) + cli::cli_abort( + "To tune a model spec, you must preprocess with a formula or recipe." + ) } tune::empty_ellipses(...) @@ -106,9 +109,15 @@ tune_cluster.cluster_spec <- function(object, preprocessor, resamples, ..., #' @export #' @rdname tune_cluster -tune_cluster.workflow <- function(object, resamples, ..., param_info = NULL, - grid = 10, metrics = NULL, - control = tune::control_grid()) { +tune_cluster.workflow <- function( + object, + resamples, + ..., + param_info = NULL, + grid = 10, + metrics = NULL, + control = tune::control_grid() +) { tune::empty_ellipses(...) control <- parsnip::condense_control(control, tune::control_grid()) @@ -116,7 +125,7 @@ tune_cluster.workflow <- function(object, resamples, ..., param_info = NULL, # Disallow `NULL` grids in `tune_cluster()`, as this is the special signal # used when no tuning is required if (is.null(grid)) { - rlang::abort(grid_msg) + cli::cli_abort(grid_msg) } tune_cluster_workflow( @@ -131,13 +140,15 @@ tune_cluster.workflow <- function(object, resamples, ..., param_info = NULL, # ------------------------------------------------------------------------------ -tune_cluster_workflow <- function(workflow, - resamples, - grid = 10, - metrics = NULL, - pset = NULL, - control = NULL, - rng = TRUE) { +tune_cluster_workflow <- function( + workflow, + resamples, + grid = 10, + metrics = NULL, + pset = NULL, + control = NULL, + rng = TRUE +) { tune::check_rset(resamples) metrics <- check_metrics(metrics, workflow) @@ -171,7 +182,12 @@ tune_cluster_workflow <- function(workflow, ) if (is_cataclysmic(resamples)) { - rlang::warn("All models failed. See the `.notes` column.") + cli::cli_warn( + c( + "All models failed.", + "i" = "See the {.code .notes} column." + ) + ) } workflow <- set_workflow(workflow, control) @@ -185,12 +201,14 @@ tune_cluster_workflow <- function(workflow, ) } -tune_cluster_loop <- function(resamples, - grid, - workflow, - metrics, - control, - rng) { +tune_cluster_loop <- function( + resamples, + grid, + workflow, + metrics, + control, + rng +) { `%op%` <- get_operator(control$allow_par, workflow) `%:%` <- foreach::`%:%` @@ -222,69 +240,73 @@ tune_cluster_loop <- function(resamples, # created by `eval()`. This causes the handler to run much too early. By evaluating in # a local environment, we prevent `defer()`/`on.exit()` from finding the short-lived # context of `%op%`. Instead it looks all the way up here to register the handler. - + results <- local({ suppressPackageStartupMessages( foreach::foreach( - split = splits, - seed = seeds, - .packages = packages, - .errorhandling = "pass" - ) %op% { - # Extract internal function from tune namespace - tune_cluster_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - tune_cluster_loop_iter_safely( - split = split, - grid_info = grid_info, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - ) - }) + split = splits, + seed = seeds, + .packages = packages, + .errorhandling = "pass" + ) %op% + { + # Extract internal function from tune namespace + tune_cluster_loop_iter_safely <- utils::getFromNamespace( + x = "tune_cluster_loop_iter_safely", + ns = "tidyclust" + ) + + tune_cluster_loop_iter_safely( + split = split, + grid_info = grid_info, + workflow = workflow, + metrics = metrics, + control = control, + seed = seed + ) + } + ) + }) } else if (identical(parallel_over, "everything")) { seeds <- generate_seeds(rng, n_resamples * n_grid_info) - results <- local(suppressPackageStartupMessages( - foreach::foreach( - iteration = iterations, - split = splits, - .packages = packages, - .errorhandling = "pass" - ) %:% + results <- local( + suppressPackageStartupMessages( foreach::foreach( - row = rows, - seed = slice_seeds(seeds, iteration, n_grid_info), + iteration = iterations, + split = splits, .packages = packages, - .errorhandling = "pass", - .combine = iter_combine - ) %op% { - # Extract internal function from tidyclust namespace - tune_grid_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - grid_info_row <- vctrs::vec_slice(grid_info, row) - - tune_grid_loop_iter_safely( - split = split, - grid_info = grid_info_row, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - )) + .errorhandling = "pass" + ) %:% + foreach::foreach( + row = rows, + seed = slice_seeds(seeds, iteration, n_grid_info), + .packages = packages, + .errorhandling = "pass", + .combine = iter_combine + ) %op% + { + # Extract internal function from tidyclust namespace + tune_grid_loop_iter_safely <- utils::getFromNamespace( + x = "tune_cluster_loop_iter_safely", + ns = "tidyclust" + ) + + grid_info_row <- vctrs::vec_slice(grid_info, row) + + tune_grid_loop_iter_safely( + split = split, + grid_info = grid_info_row, + workflow = workflow, + metrics = metrics, + control = control, + seed = seed + ) + } + ) + ) } else { - rlang::abort("Internal error: Invalid `parallel_over`.") + cli::cli_abort("Internal error: Invalid {.arg parallel_over}.") } resamples <- pull_metrics(resamples, results, control) @@ -311,7 +333,8 @@ compute_grid_info <- function(workflow, grid) { if (any_parameters_preprocessor) { compute_grid_info_model_and_preprocessor( workflow, - grid, parameters_model + grid, + parameters_model ) } else { compute_grid_info_model(workflow, grid, parameters_model) @@ -320,23 +343,24 @@ compute_grid_info <- function(workflow, grid) { if (any_parameters_preprocessor) { compute_grid_info_preprocessor(workflow, grid, parameters_model) } else { - rlang::abort( - paste0( - "Internal error: ", - "`workflow` should have some tunable parameters ", - "if `grid` is not `NULL`." + cli::cli_abort( + c( + "Internal error: {.code workflow} should have some tunable parameters + if {.code grid} is not {.code NULL}." ) ) } } } -tune_cluster_loop_iter <- function(split, - grid_info, - workflow, - metrics, - control, - seed) { +tune_cluster_loop_iter <- function( + split, + grid_info, + workflow, + metrics, + control, + seed +) { load_pkgs(workflow) load_namespace(control$pkgs) @@ -541,12 +565,14 @@ tune_cluster_loop_iter <- function(split, ) } -tune_cluster_loop_iter_safely <- function(split, - grid_info, - workflow, - metrics, - control, - seed) { +tune_cluster_loop_iter_safely <- function( + split, + grid_info, + workflow, + metrics, + control, + seed +) { tune_cluster_loop_iter_wrapper <- super_safely(tune_cluster_loop_iter) time <- proc.time() @@ -617,7 +643,8 @@ super_safely <- function(fn) { expr = tryCatch( expr = list( result = fn(...), - error = NULL, warnings = warnings + error = NULL, + warnings = warnings ), error = handle_error ), @@ -636,37 +663,45 @@ compute_grid_info_model <- function(workflow, grid, parameters_model) { msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L) msgs_preprocessor <- rep(msgs_preprocessor, times = n_fit_models) msgs_model <- new_msgs_model( - i = seq_fit_models, n = n_fit_models, + i = seq_fit_models, + n = n_fit_models, msgs_preprocessor = msgs_preprocessor ) iter_configs <- compute_config_ids(out, "Preprocessor1") out <- tibble::add_column( - .data = out, .iter_preprocessor = 1L, + .data = out, + .iter_preprocessor = 1L, .before = 1L ) out <- tibble::add_column( - .data = out, .msg_preprocessor = msgs_preprocessor, + .data = out, + .msg_preprocessor = msgs_preprocessor, .after = ".iter_preprocessor" ) out <- tibble::add_column( - .data = out, .iter_model = seq_fit_models, + .data = out, + .iter_model = seq_fit_models, .after = ".msg_preprocessor" ) out <- tibble::add_column( - .data = out, .iter_config = iter_configs, + .data = out, + .iter_config = iter_configs, .after = ".iter_model" ) out <- tibble::add_column( - .data = out, .msg_model = msgs_model, + .data = out, + .msg_model = msgs_model, .after = ".iter_config" ) out } # https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L484 -compute_grid_info_model_and_preprocessor <- function(workflow, - grid, - parameters_model) { +compute_grid_info_model_and_preprocessor <- function( + workflow, + grid, + parameters_model +) { parameter_names_model <- parameters_model[["id"]] # Nest model parameters, keep preprocessor parameters outside @@ -751,9 +786,7 @@ compute_grid_info_model_and_preprocessor <- function(workflow, } # https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L359 -compute_grid_info_preprocessor <- function(workflow, - grid, - parameters_model) { +compute_grid_info_preprocessor <- function(workflow, grid, parameters_model) { out <- grid n_preprocessors <- nrow(out) @@ -825,19 +858,17 @@ check_metrics <- function(x, object) { mode <- extract_spec_parsnip(object)$mode if (is.null(x)) { - switch(mode, + switch( + mode, partition = { x <- cluster_metric_set(sse_within_total, sse_total) }, unknown = { - rlang::abort( - paste0( - "Internal error: ", - "`check_installs()` should have caught an `unknown` mode." - ) + cli::cli_abort( + "Internal error: {.fn check_installs} should have caught an {.code unknown} mode." ) }, - rlang::abort("Unknown `mode` for parsnip model.") + cli::cli_abort("Unknown {.arg mode} for tidyclust model.") ) return(x) @@ -846,21 +877,20 @@ check_metrics <- function(x, object) { is_cluster_metric_set <- inherits(x, "cluster_metric_set") if (!is_cluster_metric_set) { - rlang::abort( - paste0( - "The `metrics` argument should be the results of ", - "[cluster_metric_set()]." - ) + cli::cli_abort( + "The {.arg metrics} argument should be the results of {.fn cluster_metric_set}." ) } x } # https://github.com/tidymodels/tune/blob/main/R/checks.R#L144 -check_parameters <- function(workflow, - pset = NULL, - data, - grid_names = character(0)) { +check_parameters <- function( + workflow, + pset = NULL, + data, + grid_names = character(0) +) { if (is.null(pset)) { pset <- hardhat::extract_parameter_set_dials(workflow) } @@ -874,11 +904,11 @@ check_parameters <- function(workflow, if (needs_finalization(pset, grid_names)) { if (tune_recipe) { - rlang::abort( - paste( - "Some tuning parameters require finalization but there are recipe", - "parameters that require tuning. Please use `parameters()` to", - "finalize the parameter ranges." + cli::cli_abort( + c( + "Some tuning parameters require finalization but there are recipe + parameters that require tuning.", + "i" = "Please use {.fn parameters} to finalize the parameter ranges." ) ) } @@ -913,15 +943,17 @@ needs_finalization <- function(x, nms = character(0)) { # https://github.com/tidymodels/tune/blob/main/R/checks.R#L274 check_workflow <- function(x, pset = NULL, check_dials = FALSE) { if (!inherits(x, "workflow")) { - rlang::abort("The `object` argument should be a 'workflow' object.") + cli::cli_abort( + "The {.arg object} argument should be a {.cls workflow} object." + ) } if (!has_preprocessor(x)) { - rlang::abort("A formula, recipe, or variables preprocessor is required.") + cli::cli_abort("A formula, recipe, or variables preprocessor is required.") } if (!has_spec(x)) { - rlang::abort("A tidyclust model is required.") + cli::cli_abort("A tidyclust model is required.") } if (check_dials) { @@ -934,10 +966,9 @@ check_workflow <- function(x, pset = NULL, check_dials = FALSE) { incompl <- dials::has_unknowns(pset$object) if (any(incompl)) { - rlang::abort(paste0( - "The workflow has arguments whose ranges are not finalized: ", - paste0("'", pset$id[incompl], "'", collapse = ", ") - )) + cli::cli_abort( + "The workflow has arguments whose ranges are not finalized: {.arg {pset$id[incompl]}}." + ) } } @@ -952,11 +983,9 @@ check_param_objects <- function(pset) { params <- map_lgl(pset$object, inherits, "param") if (!all(params)) { - rlang::abort(paste0( - "The workflow has arguments to be tuned that are missing some ", - "parameter objects: ", - paste0("'", pset$id[!params], "'", collapse = ", ") - )) + cli::cli_abort( + "The workflow has arguments to be tuned that are missing parameter objects: {.arg {pset$id[!params]}}." + ) } invisible(pset) } @@ -975,12 +1004,13 @@ check_grid <- function(grid, workflow, pset = NULL) { } if (nrow(pset) == 0L) { - msg <- paste0( - "No tuning parameters have been detected, ", - "performance will be evaluated using the resamples with no tuning. ", - "Did you want to [tune()] parameters?" + cli::cli_warn( + c( + "No tuning parameters have been detected, performance will be evaluated using + the resamples with no tuning.", + "i" = "Did you want to {.fn tune} parameters?" + ) ) - rlang::warn(msg) # Return `NULL` as the new `grid`, like what is used in `fit_resamples()` return(NULL) @@ -988,12 +1018,12 @@ check_grid <- function(grid, workflow, pset = NULL) { if (!is.numeric(grid)) { if (!is.data.frame(grid)) { - rlang::abort(grid_msg) + cli::cli_abort(grid_msg) } grid_distinct <- dplyr::distinct(grid) if (!identical(nrow(grid_distinct), nrow(grid))) { - rlang::warn( + cli::cli_warn( "Duplicate rows in grid of tuning combinations found and removed." ) } @@ -1014,33 +1044,29 @@ check_grid <- function(grid, workflow, pset = NULL) { extra_grid_params <- glue::single_quote(extra_grid_params) extra_grid_params <- glue::glue_collapse(extra_grid_params, sep = ", ") - msg <- glue::glue( - "The provided `grid` has the following parameter columns that have ", - "not been marked for tuning by `tune()`: {extra_grid_params}." + cli::cli_abort( + "The provided {.arg grid} has parameter column{?s} {extra_grid_params} + that {?has/have} not been marked for tuning by {.fn tune}." ) - - rlang::abort(msg) } if (length(extra_tune_params) != 0L) { extra_tune_params <- glue::single_quote(extra_tune_params) extra_tune_params <- glue::glue_collapse(extra_tune_params, sep = ", ") - msg <- glue::glue( - "The provided `grid` is missing the following parameter columns that ", - "have been marked for tuning by `tune()`: {extra_tune_params}." + cli::cli_abort( + "The provided {.arg grid} is missing parameter column{?s} {.val {extra_tune_params}} + that {?has/have} been marked for tuning by {.fn tune}." ) - - rlang::abort(msg) } } else { grid <- as.integer(grid[1]) if (grid < 1) { - rlang::abort(grid_msg) + cli::cli_abort(grid_msg) } check_workflow(workflow, pset = pset, check_dials = TRUE) - grid <- dials::grid_latin_hypercube(pset, size = grid) + grid <- dials::grid_space_filling(pset, size = grid) grid <- dplyr::distinct(grid) } diff --git a/R/tune_helpers.R b/R/tune_helpers.R index 86597404..9f75615d 100644 --- a/R/tune_helpers.R +++ b/R/tune_helpers.R @@ -6,13 +6,20 @@ new_bare_tibble <- function(x, ..., class = character()) { } is_cataclysmic <- function(x) { - is_err <- map_lgl(x$.metrics, inherits, c( - "simpleError", - "error" - )) + is_err <- map_lgl( + x$.metrics, + inherits, + c( + "simpleError", + "error" + ) + ) if (any(!is_err)) { - is_good <- map_lgl(x$.metrics[!is_err], ~ tibble::is_tibble(.x) && - nrow(.x) > 0) + is_good <- map_lgl( + x$.metrics[!is_err], + ~tibble::is_tibble(.x) && + nrow(.x) > 0 + ) is_err[!is_err] <- !is_good } all(is_err) @@ -24,19 +31,11 @@ set_workflow <- function(workflow, control) { if (!is.null(workflow$pre$actions$recipe)) { w_size <- utils::object.size(workflow$pre$actions$recipe) if (w_size / 1024^2 > 5) { - msg <- paste0( - "The workflow being saved contains a recipe, which is ", - format(w_size, units = "Mb", digits = 2), - " in memory. If this was not intentional, please set the control ", - "setting `save_workflow = FALSE`." + cli::cli_inform( + "The workflow being saved contains a recipe, which is {format(w_size, units = 'Mb', + digits = 2)} in memory. If this was not intentional, please set the control + setting {.code save_workflow = FALSE}." ) - cols <- get_tidyclust_colors() - msg <- strwrap(msg, prefix = paste0( - cols$symbol$info(cli::symbol$info), - " " - )) - msg <- cols$message$info(paste0(msg, collapse = "\n")) - rlang::inform(msg) } } workflow @@ -46,8 +45,14 @@ set_workflow <- function(workflow, control) { } # https://github.com/tidymodels/tune/blob/main/R/tune_results.R -new_tune_results <- function(x, parameters, metrics, - rset_info, ..., class = character()) { +new_tune_results <- function( + x, + parameters, + metrics, + rset_info, + ..., + class = character() +) { new_bare_tibble( x = x, parameters = parameters, @@ -92,8 +97,11 @@ new_grid_info_resamples <- function() { ) iter_config <- list("Preprocessor1_Model1") out <- tibble::tibble( - .iter_preprocessor = 1L, .msg_preprocessor = msgs_preprocessor, - .iter_model = 1L, .iter_config = iter_config, .msg_model = msgs_model, + .iter_preprocessor = 1L, + .msg_preprocessor = msgs_preprocessor, + .iter_model = 1L, + .iter_config = iter_config, + .msg_model = msgs_model, .submodels = list(list()) ) out @@ -153,7 +161,7 @@ min_grid.cluster_spec <- function(x, grid, ...) { blank_submodels <- function(grid) { grid %>% dplyr::mutate( - .submodels = map(seq_along(nrow(grid)), ~ list()) + .submodels = map(seq_along(nrow(grid)), ~list()) ) %>% dplyr::mutate_if(is.factor, as.character) } @@ -218,9 +226,7 @@ catcher <- function(expr) { signals <<- append(signals, list(cnd)) rlang::cnd_muffle(cnd) } - res <- try(withCallingHandlers(warning = add_cond, expr), - silent = TRUE - ) + res <- try(withCallingHandlers(warning = add_cond, expr), silent = TRUE) list(res = res, signals = signals) } @@ -232,16 +238,17 @@ siren <- function(x, type = "info") { symb <- dplyr::case_when( type == "warning" ~ tidyclust_color$symbol$warning("!"), type == "go" ~ tidyclust_color$symbol$go(cli::symbol$pointer), - type == "danger" ~ tidyclust_color$symbol$danger("x"), type == - "success" ~ tidyclust_color$symbol$success(tidyclust_symbol$success), + type == "danger" ~ tidyclust_color$symbol$danger("x"), + type == "success" ~ + tidyclust_color$symbol$success(tidyclust_symbol$success), type == "info" ~ tidyclust_color$symbol$info("i") ) msg <- dplyr::case_when( type == "warning" ~ tidyclust_color$message$warning(msg), - type == "go" ~ tidyclust_color$message$go(msg), type == "danger" ~ - tidyclust_color$message$danger(msg), type == "success" ~ - tidyclust_color$message$success(msg), type == "info" ~ - tidyclust_color$message$info(msg) + type == "go" ~ tidyclust_color$message$go(msg), + type == "danger" ~ tidyclust_color$message$danger(msg), + type == "success" ~ tidyclust_color$message$success(msg), + type == "info" ~ tidyclust_color$message$info(msg) ) if (inherits(msg, "character")) { msg <- as.character(msg) @@ -254,15 +261,17 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) { control2$verbose <- TRUE wrn <- res$signals if (length(wrn) > 0) { - wrn_msg <- map_chr(wrn, ~ .x$message) + wrn_msg <- map_chr(wrn, ~.x$message) wrn_msg <- unique(wrn_msg) wrn_msg <- paste(wrn_msg, collapse = ", ") wrn_msg <- tibble::tibble( - location = loc, type = "warning", + location = loc, + type = "warning", note = wrn_msg ) notes <- dplyr::bind_rows(notes, wrn_msg) - wrn_msg <- glue::glue_collapse(paste0(loc, ": ", wrn_msg$note), + wrn_msg <- glue::glue_collapse( + paste0(loc, ": ", wrn_msg$note), width = options()$width - 5 ) tune_log(control2, split, wrn_msg, type = "warning") @@ -271,11 +280,13 @@ log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) { err_msg <- as.character(attr(res$res, "condition")) err_msg <- gsub("\n$", "", err_msg) err_msg <- tibble::tibble( - location = loc, type = "error", + location = loc, + type = "error", note = err_msg ) notes <- dplyr::bind_rows(notes, err_msg) - err_msg <- glue::glue_collapse(paste0(loc, ": ", err_msg$note), + err_msg <- glue::glue_collapse( + paste0(loc, ": ", err_msg$note), width = options()$width - 5 ) tune_log(control2, split, err_msg, type = "danger") @@ -309,7 +320,7 @@ merge.cluster_spec <- function(x, y, ...) { merger <- function(x, y, ...) { if (!is.data.frame(y)) { - rlang::abort("The second argument should be a data frame.") + cli::cli_abort("The second argument should be a data frame.") } pset <- hardhat::extract_parameter_set_dials(x) if (nrow(pset) == 0) { @@ -319,7 +330,7 @@ merger <- function(x, y, ...) { grid_name <- colnames(y) if (inherits(x, "recipe")) { updater <- update_recipe - step_ids <- map_chr(x$steps, ~ .x$id) + step_ids <- map_chr(x$steps, ~.x$id) } else { updater <- update_model step_ids <- NULL @@ -332,7 +343,7 @@ merger <- function(x, y, ...) { dplyr::mutate( ..object = map( seq_along(nrow(y)), - ~ updater(y[.x, ], x, pset, step_ids, grid_name) + ~updater(y[.x, ], x, pset, step_ids, grid_name) ) ) %>% dplyr::select(x = ..object) @@ -344,7 +355,7 @@ update_model <- function(grid, object, pset, step_id, nms, ...) { param_info <- pset %>% dplyr::filter(id == i & source == "cluster_spec") if (nrow(param_info) > 1) { # TODO figure this out and write a better message - rlang::abort("There are too many things.") + cli::cli_abort("There are too many things.") } if (nrow(param_info) == 1) { if (param_info$component_id == "main") { @@ -384,7 +395,7 @@ catch_and_log_fit <- function(expr, ..., notes) { return(result) } if (!is_workflow(result)) { - rlang::abort("Internal error: Model result is not a workflow!") + cli::cli_abort("Internal error: Model result is not a workflow!") } fit <- result$fit$fit$fit if (is_failure(fit)) { @@ -427,7 +438,7 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL) { ) } - rlang::abort(msg) + cli::cli_abort(msg) } # Determine the type of prediction that is required @@ -440,8 +451,8 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL) { # Regular predictions tmp_res <- stats::predict(model, x_vals, type = type_iter) %>% - dplyr::mutate(.row = orig_rows) %>% - cbind(grid, row.names = NULL) + dplyr::mutate(.row = orig_rows) %>% + cbind(grid, row.names = NULL) if (!is.null(submodels)) { submod_length <- lengths(submodels) @@ -570,16 +581,18 @@ slice_seeds <- function(x, i, n) { iter_combine <- function(...) { results <- list(...) - metrics <- map(results, ~ .x[[".metrics"]]) - extracts <- map(results, ~ .x[[".extracts"]]) - predictions <- map(results, ~ .x[[".predictions"]]) - notes <- map(results, ~ .x[[".notes"]]) + metrics <- map(results, ~.x[[".metrics"]]) + extracts <- map(results, ~.x[[".extracts"]]) + predictions <- map(results, ~.x[[".predictions"]]) + notes <- map(results, ~.x[[".notes"]]) metrics <- vctrs::vec_c(!!!metrics) extracts <- vctrs::vec_c(!!!extracts) predictions <- vctrs::vec_c(!!!predictions) notes <- vctrs::vec_c(!!!notes) list( - .metrics = metrics, .extracts = extracts, .predictions = predictions, + .metrics = metrics, + .extracts = extracts, + .predictions = predictions, .notes = notes ) } diff --git a/dev/cross_val_kmeans.R b/dev/cross_val_kmeans.R index 55f972ab..eae5e935 100644 --- a/dev/cross_val_kmeans.R +++ b/dev/cross_val_kmeans.R @@ -17,12 +17,10 @@ res <- data.frame( ) for (k in 2:10) { - km <- k_means(k = k) %>% set_engine("stats") for (i in 1:5) { - tmp_train <- training(cvs$splits[[i]]) tmp_test <- testing(cvs$splits[[i]]) @@ -36,11 +34,8 @@ for (k in 2:10) { sil <- km_fit %>% silhouette_avg(tmp_test) - res <- rbind(res, - c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) - + res <- rbind(res, c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2)) } - } res %>% @@ -63,14 +58,12 @@ res <- data.frame( ) for (k in 2:10) { - km <- k_means(k = k) %>% set_engine("stats") full_fit <- km %>% fit(~., data = ir) for (i in 1:10) { - tmp_train <- training(cvs$splits[[i]]) tmp_test <- testing(cvs$splits[[i]]) @@ -87,11 +80,11 @@ for (k in 2:10) { acc <- accuracy(thing, clusters_1, clusters_2) f1 <- f_meas(thing, clusters_1, clusters_2) - res <- rbind(res, - c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate)) - + res <- rbind( + res, + c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate) + ) } - } res %>% diff --git a/dev/test_hc.R b/dev/test_hc.R index 006ef3d2..963d5c0e 100644 --- a/dev/test_hc.R +++ b/dev/test_hc.R @@ -1,14 +1,14 @@ library(tidyverse) library(celery) -ir <- iris[,-5] +ir <- iris[, -5] hclust(dist(ir)) bob <- hclust_fit(ir) hc <- hier_clust(k = 3) %>% - fit(~ ., data = ir) + fit(~., data = ir) km <- k_means(k = 3) %>% fit(~., data = ir) @@ -20,7 +20,7 @@ thing <- tibble( ) thing %>% - count(hc_c,truth) + count(hc_c, truth) cutree(hc$fit, k = 3) diff --git a/man/dot-hier_clust_fit_stats.Rd b/man/dot-hier_clust_fit_stats.Rd index 372e7cea..08d33f3f 100644 --- a/man/dot-hier_clust_fit_stats.Rd +++ b/man/dot-hier_clust_fit_stats.Rd @@ -9,7 +9,7 @@ num_clusters = NULL, cut_height = NULL, linkage_method = NULL, - dist_fun = Rfast::Dist + dist_fun = philentropy::distance ) } \arguments{ diff --git a/man/get_centroid_dists.Rd b/man/get_centroid_dists.Rd index 267a98a9..887923e6 100644 --- a/man/get_centroid_dists.Rd +++ b/man/get_centroid_dists.Rd @@ -4,7 +4,14 @@ \alias{get_centroid_dists} \title{Computes distance from observations to centroids} \usage{ -get_centroid_dists(new_data, centroids, dist_fun = Rfast::dista) +get_centroid_dists( + new_data, + centroids, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = + "euclidean") + } +) } \arguments{ \item{new_data}{A data frame} @@ -12,7 +19,8 @@ get_centroid_dists(new_data, centroids, dist_fun = Rfast::dista) \item{centroids}{A data frame where each row is a centroid.} \item{dist_fun}{A function for computing matrix-to-matrix distances. Defaults -to \code{Rfast::dista()}} +to +\code{function(x, y) philentropy::dist_many_many(x, y, method = "euclidean")}.} } \description{ Computes distance from observations to centroids diff --git a/man/prep_data_dist.Rd b/man/prep_data_dist.Rd index 163c0ea8..229bfd86 100644 --- a/man/prep_data_dist.Rd +++ b/man/prep_data_dist.Rd @@ -4,7 +4,12 @@ \alias{prep_data_dist} \title{Prepares data and distance matrices for metric calculation} \usage{ -prep_data_dist(object, new_data = NULL, dists = NULL, dist_fun = Rfast::Dist) +prep_data_dist( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance +) } \arguments{ \item{object}{A fitted \code{\link{cluster_spec}} object.} diff --git a/man/set_args.cluster_spec.Rd b/man/set_args.cluster_spec.Rd index a9143d23..d36e8a90 100644 --- a/man/set_args.cluster_spec.Rd +++ b/man/set_args.cluster_spec.Rd @@ -7,7 +7,7 @@ \method{set_args}{cluster_spec}(object, ...) } \arguments{ -\item{object}{A model specification.} +\item{object}{A \link[parsnip:model_spec]{model specification}.} \item{...}{One or more named model arguments.} } diff --git a/man/set_engine.cluster_spec.Rd b/man/set_engine.cluster_spec.Rd index cfd14127..f0600ff0 100644 --- a/man/set_engine.cluster_spec.Rd +++ b/man/set_engine.cluster_spec.Rd @@ -7,7 +7,7 @@ \method{set_engine}{cluster_spec}(object, engine, ...) } \arguments{ -\item{object}{A model specification.} +\item{object}{A \link[parsnip:model_spec]{model specification}.} \item{engine}{A character string for the software that should be used to fit the model. This is highly dependent on the type diff --git a/man/set_mode.cluster_spec.Rd b/man/set_mode.cluster_spec.Rd index 5c620021..03d6b7d9 100644 --- a/man/set_mode.cluster_spec.Rd +++ b/man/set_mode.cluster_spec.Rd @@ -4,13 +4,15 @@ \alias{set_mode.cluster_spec} \title{Change mode of a cluster specification} \usage{ -\method{set_mode}{cluster_spec}(object, mode) +\method{set_mode}{cluster_spec}(object, mode, ...) } \arguments{ -\item{object}{A model specification.} +\item{object}{A \link[parsnip:model_spec]{model specification}.} \item{mode}{A character string for the model type (e.g. "classification" or "regression")} + +\item{...}{One or more named model arguments.} } \value{ An updated \code{\link{cluster_spec}} object. diff --git a/man/silhouette.Rd b/man/silhouette.Rd index 6840ad88..7a76d610 100644 --- a/man/silhouette.Rd +++ b/man/silhouette.Rd @@ -4,7 +4,12 @@ \alias{silhouette} \title{Measures silhouette between clusters} \usage{ -silhouette(object, new_data = NULL, dists = NULL, dist_fun = Rfast::Dist) +silhouette( + object, + new_data = NULL, + dists = NULL, + dist_fun = philentropy::distance +) } \arguments{ \item{object}{A fitted tidyclust model} diff --git a/man/silhouette_avg.Rd b/man/silhouette_avg.Rd index 8036d677..b8d8ed43 100644 --- a/man/silhouette_avg.Rd +++ b/man/silhouette_avg.Rd @@ -20,7 +20,7 @@ silhouette_avg_vec( object, new_data = NULL, dists = NULL, - dist_fun = Rfast::Dist, + dist_fun = philentropy::distance, ... ) } diff --git a/man/sse_ratio.Rd b/man/sse_ratio.Rd index 2adaceea..4e14ba0a 100644 --- a/man/sse_ratio.Rd +++ b/man/sse_ratio.Rd @@ -16,7 +16,15 @@ sse_ratio(object, ...) \method{sse_ratio}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) -sse_ratio_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +sse_ratio_vec( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = + "euclidean") + }, + ... +) } \arguments{ \item{object}{A fitted kmeans tidyclust model} diff --git a/man/sse_total.Rd b/man/sse_total.Rd index bea3c468..805b5e91 100644 --- a/man/sse_total.Rd +++ b/man/sse_total.Rd @@ -16,7 +16,15 @@ sse_total(object, ...) \method{sse_total}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) -sse_total_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +sse_total_vec( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = + "euclidean") + }, + ... +) } \arguments{ \item{object}{A fitted kmeans tidyclust model} diff --git a/man/sse_within.Rd b/man/sse_within.Rd index da4b3ecd..c79fb0df 100644 --- a/man/sse_within.Rd +++ b/man/sse_within.Rd @@ -4,7 +4,14 @@ \alias{sse_within} \title{Calculates Sum of Squared Error in each cluster} \usage{ -sse_within(object, new_data = NULL, dist_fun = Rfast::dista) +sse_within( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = + "euclidean") + } +) } \arguments{ \item{object}{A fitted kmeans tidyclust model} diff --git a/man/sse_within_total.Rd b/man/sse_within_total.Rd index 9483cb09..6baf5d88 100644 --- a/man/sse_within_total.Rd +++ b/man/sse_within_total.Rd @@ -16,7 +16,15 @@ sse_within_total(object, ...) \method{sse_within_total}{workflow}(object, new_data = NULL, dist_fun = NULL, ...) -sse_within_total_vec(object, new_data = NULL, dist_fun = Rfast::dista, ...) +sse_within_total_vec( + object, + new_data = NULL, + dist_fun = function(x, y) { + philentropy::dist_many_many(x, y, method = + "euclidean") + }, + ... +) } \arguments{ \item{object}{A fitted kmeans tidyclust model} diff --git a/tests/testthat/_snaps/arguments.md b/tests/testthat/_snaps/arguments.md index a23f6fa5..e1f2b21c 100644 --- a/tests/testthat/_snaps/arguments.md +++ b/tests/testthat/_snaps/arguments.md @@ -12,29 +12,30 @@ k_means() %>% set_mode() Condition Error in `modelenv::stop_incompatible_mode()`: - ! Available modes for model type k_means are: 'unknown', 'partition' + x Available modes for model type k_means are: + * "unknown" and "partition". --- Code k_means() %>% set_mode(2) Condition - Error in `modelenv::check_spec_mode_engine_val()`: - ! '2' is not a known mode for model `k_means()`. + Error in `set_mode()`: + ! 2 is not a known mode for model `k_means()`. --- Code k_means() %>% set_mode("haberdashery") Condition - Error in `modelenv::check_spec_mode_engine_val()`: - ! 'haberdashery' is not a known mode for model `k_means()`. + Error in `set_mode()`: + ! "haberdashery" is not a known mode for model `k_means()`. # can't set a mode that isn't allowed by the model spec Code set_mode(k_means(), "classification") Condition - Error in `modelenv::check_spec_mode_engine_val()`: - ! 'classification' is not a known mode for model `k_means()`. + Error in `set_mode()`: + ! "classification" is not a known mode for model `k_means()`. diff --git a/tests/testthat/_snaps/cluster_metric_set.md b/tests/testthat/_snaps/cluster_metric_set.md index a671956a..1b97e4fa 100644 --- a/tests/testthat/_snaps/cluster_metric_set.md +++ b/tests/testthat/_snaps/cluster_metric_set.md @@ -4,8 +4,7 @@ my_metrics(kmeans_fit) Condition Error in `value[[3L]]()`: - ! In metric: `silhouette_avg` - Must supply either a dataset or distance matrix to compute silhouettes. + ! In metric: `silhouette_avg` Must supply either a dataset or distance matrix to compute silhouettes. # cluster_metric_set error with wrong input @@ -13,10 +12,8 @@ cluster_metric_set(mean) Condition Error in `validate_function_class()`: - ! - The combination of metric functions must be: - - only clustering metrics - The following metric function types are being mixed: + ! The combination of metric functions must be only clustering metrics. + i The following metric function types are being mixed: - other (mean ) --- @@ -25,10 +22,8 @@ cluster_metric_set(sse_ratio, mean) Condition Error in `validate_function_class()`: - ! - The combination of metric functions must be: - - only clustering metrics - The following metric function types are being mixed: + ! The combination of metric functions must be only clustering metrics. + i The following metric function types are being mixed: - cluster (sse_ratio) - other (mean ) @@ -38,7 +33,7 @@ cluster_metric_set(silhouette) Condition Error in `cluster_metric_set()`: - ! `silhouette` is not a cluster metric. Did you mean `silhouette_avg`? + ! The value "silhouette" is not a cluster metric. Did you mean `silhouette_avg`? --- diff --git a/tests/testthat/_snaps/engines.md b/tests/testthat/_snaps/engines.md index d5685e00..43ba20b2 100644 --- a/tests/testthat/_snaps/engines.md +++ b/tests/testthat/_snaps/engines.md @@ -4,5 +4,6 @@ set_engine(k_means()) Condition Error in `set_engine()`: - ! Missing engine. Possible mode/engine combinations are: partition {stats, ClusterR, clustMixType, klaR} + ! Missing engine. + i Possible mode/engine combinations are: partition {stats, ClusterR, clustMixType, klaR}. diff --git a/tests/testthat/_snaps/extract_centroids.md b/tests/testthat/_snaps/extract_centroids.md index 3003671c..c2b75e8c 100644 --- a/tests/testthat/_snaps/extract_centroids.md +++ b/tests/testthat/_snaps/extract_centroids.md @@ -4,7 +4,8 @@ extract_centroids(spec) Condition Error in `extract_centroids()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. # extract_centroids() errors for hier_clust() with missing args @@ -20,7 +21,8 @@ hclust_fit %>% extract_centroids(k = 3) Condition Error in `extract_centroids()`: - ! Using `k` argument is not supported. Please use `num_clusters` instead. + ! Using `k` argument is not supported. + i Please use `num_clusters` instead. # extract_centroids() errors for hier_clust() with h arg @@ -28,5 +30,6 @@ hclust_fit %>% extract_centroids(h = 3) Condition Error in `extract_centroids()`: - ! Using `h` argument is not supported. Please use `cut_height` instead. + ! Using `h` argument is not supported. + i Please use `cut_height` instead. diff --git a/tests/testthat/_snaps/extract_cluster_assignment.md b/tests/testthat/_snaps/extract_cluster_assignment.md index d1381c37..fdb378cb 100644 --- a/tests/testthat/_snaps/extract_cluster_assignment.md +++ b/tests/testthat/_snaps/extract_cluster_assignment.md @@ -4,7 +4,8 @@ extract_cluster_assignment(spec) Condition Error in `extract_cluster_assignment()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. # extract_cluster_assignment() errors for hier_clust() with missing args @@ -20,7 +21,8 @@ hclust_fit %>% extract_cluster_assignment(k = 3) Condition Error in `extract_cluster_assignment()`: - ! Using `k` argument is not supported. Please use `num_clusters` instead. + ! Using `k` argument is not supported. + i Please use `num_clusters` instead. # extract_cluster_assignment() errors for hier_clust() with h arg @@ -28,5 +30,6 @@ hclust_fit %>% extract_cluster_assignment(h = 3) Condition Error in `extract_cluster_assignment()`: - ! Using `h` argument is not supported. Please use `cut_height` instead. + ! Using `h` argument is not supported. + i Please use `cut_height` instead. diff --git a/tests/testthat/_snaps/extract_fit_summary.md b/tests/testthat/_snaps/extract_fit_summary.md index 42b14a91..2828d5b7 100644 --- a/tests/testthat/_snaps/extract_fit_summary.md +++ b/tests/testthat/_snaps/extract_fit_summary.md @@ -4,5 +4,6 @@ extract_fit_summary(spec) Condition Error in `extract_fit_summary()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. diff --git a/tests/testthat/_snaps/fiting.md b/tests/testthat/_snaps/fiting.md index 73345116..534d0770 100644 --- a/tests/testthat/_snaps/fiting.md +++ b/tests/testthat/_snaps/fiting.md @@ -4,7 +4,7 @@ k_means(num_clusters = 5) %>% fit_xy(mtcars, y = mtcars$mpg) Condition Error in `x_x()`: - ! Outcomes are not used in `cluster_spec` objects. + ! Outcomes are not used in objects. --- @@ -12,5 +12,5 @@ workflows::workflow(mpg ~ ., km) %>% fit(mtcars) Condition Error in `x_x()`: - ! Outcomes are not used in `cluster_spec` objects. + ! Outcomes are not used in objects. diff --git a/tests/testthat/_snaps/hier_clust.md b/tests/testthat/_snaps/hier_clust.md index 16bbb373..e6f80bf5 100644 --- a/tests/testthat/_snaps/hier_clust.md +++ b/tests/testthat/_snaps/hier_clust.md @@ -3,8 +3,8 @@ Code hier_clust(mode = "bogus") Condition - Error in `modelenv::check_spec_mode_engine_val()`: - ! 'bogus' is not a known mode for model `hier_clust()`. + Error in `hier_clust()`: + ! "bogus" is not a known mode for model `hier_clust()`. --- diff --git a/tests/testthat/_snaps/k_means.md b/tests/testthat/_snaps/k_means.md index fd1dbd65..ed769b86 100644 --- a/tests/testthat/_snaps/k_means.md +++ b/tests/testthat/_snaps/k_means.md @@ -3,8 +3,8 @@ Code k_means(mode = "bogus") Condition - Error in `modelenv::check_spec_mode_engine_val()`: - ! 'bogus' is not a known mode for model `k_means()`. + Error in `k_means()`: + ! "bogus" is not a known mode for model `k_means()`. --- diff --git a/tests/testthat/_snaps/metric-silhouette.md b/tests/testthat/_snaps/metric-silhouette.md index 155ad10b..a76c1ce4 100644 --- a/tests/testthat/_snaps/metric-silhouette.md +++ b/tests/testthat/_snaps/metric-silhouette.md @@ -4,7 +4,8 @@ silhouette(spec) Condition Error in `silhouette()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. # silhouette_avg() errors for cluster spec @@ -12,5 +13,6 @@ silhouette_avg(spec) Condition Error in `silhouette_avg()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. diff --git a/tests/testthat/_snaps/metric-sse.md b/tests/testthat/_snaps/metric-sse.md index c8a3b91e..fbbceb28 100644 --- a/tests/testthat/_snaps/metric-sse.md +++ b/tests/testthat/_snaps/metric-sse.md @@ -28,5 +28,6 @@ sse_ratio(spec) Condition Error in `sse_ratio()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. diff --git a/tests/testthat/_snaps/predict.md b/tests/testthat/_snaps/predict.md index 4fd978b6..f5c91b90 100644 --- a/tests/testthat/_snaps/predict.md +++ b/tests/testthat/_snaps/predict.md @@ -4,7 +4,8 @@ predict(spec) Condition Error in `predict()`: - ! This function requires a fitted model. Please use `fit()` on your cluster specification. + ! This function requires a fitted model. + i Please use `fit()` on your cluster specification. # predict() errors for hier_clust() with missing args @@ -20,7 +21,8 @@ hclust_fit %>% predict(mtcars, k = 3) Condition Error in `predict()`: - ! Using `k` argument is not supported. Please use `num_clusters` instead. + ! Using `k` argument is not supported. + i Please use `num_clusters` instead. # predict() errors for hier_clust() with h arg @@ -28,5 +30,6 @@ hclust_fit %>% predict(mtcars, h = 3) Condition Error in `predict()`: - ! Using `h` argument is not supported. Please use `cut_height` instead. + ! Using `h` argument is not supported. + i Please use `cut_height` instead. diff --git a/tests/testthat/_snaps/reconcile_clusterings.md b/tests/testthat/_snaps/reconcile_clusterings.md index 77bd9932..24904d06 100644 --- a/tests/testthat/_snaps/reconcile_clusterings.md +++ b/tests/testthat/_snaps/reconcile_clusterings.md @@ -5,7 +5,7 @@ alt_cluster_assignment, one_to_one = TRUE) Condition Error in `reconcile_clusterings_mapping()`: - ! For one-to-one matching, must have the same number of clusters inprimary and alt. + ! For one-to-one matching, must have the same number of clusters in primary and alt. # reconciliation errors for uneven lengths diff --git a/tests/testthat/_snaps/tune_cluster.md b/tests/testthat/_snaps/tune_cluster.md index 9fa7c69f..64be851a 100644 --- a/tests/testthat/_snaps/tune_cluster.md +++ b/tests/testthat/_snaps/tune_cluster.md @@ -85,7 +85,8 @@ ! The following predi... Condition Warning: - All models failed. See the `.notes` column. + All models failed. + i See the `.notes` column. # argument order gives errors for recipes @@ -94,7 +95,7 @@ rsample::vfold_cv(mtcars, v = 2)) Condition Error in `tune_cluster()`: - ! The first argument to [tune_cluster()] should be either a model or workflow. + ! The first argument to `tune_cluster()` should be either a model or workflow. # argument order gives errors for formula @@ -103,7 +104,7 @@ mtcars, v = 2)) Condition Error in `tune_cluster()`: - ! The first argument to [tune_cluster()] should be either a model or workflow. + ! The first argument to `tune_cluster()` should be either a model or workflow. # ellipses with tune_cluster @@ -118,8 +119,8 @@ # A tibble: 2 x 4 splits id .metrics .notes - 1 Fold1 - 2 Fold2 + 1 Fold1 + 2 Fold2 # select_best() and show_best() works diff --git a/tests/testthat/_snaps/workflows.md b/tests/testthat/_snaps/workflows.md index 01de672d..e8898e3f 100644 --- a/tests/testthat/_snaps/workflows.md +++ b/tests/testthat/_snaps/workflows.md @@ -4,7 +4,7 @@ fit(wf_spec, data = mtcars) Condition Error in `x_x()`: - ! Outcomes are not used in `cluster_spec` objects. + ! Outcomes are not used in objects. # integrates with workflows::add_formula() @@ -12,7 +12,7 @@ fit(wf_spec, data = mtcars) Condition Error in `x_x()`: - ! Outcomes are not used in `cluster_spec` objects. + ! Outcomes are not used in objects. # integrates with workflows::add_recipe() @@ -20,5 +20,5 @@ fit(wf_spec, data = mtcars) Condition Error in `x_x()`: - ! Outcomes are not used in `cluster_spec` objects. + ! Outcomes are not used in objects. diff --git a/tests/testthat/helper-tidyclust-package.R b/tests/testthat/helper-tidyclust-package.R index 6635a112..aad0010a 100644 --- a/tests/testthat/helper-tidyclust-package.R +++ b/tests/testthat/helper-tidyclust-package.R @@ -1,14 +1,18 @@ -new_rng_snapshots <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 +new_rng_snapshots <- utils::compareVersion( + "3.6.0", + as.character(getRversion()) +) > + 0 helper_objects_tidyclust <- function() { rec_tune_1 <- recipes::recipe(~., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) %>% - recipes::step_pca(recipes::all_predictors(), num_comp = tune()) + recipes::step_normalize(recipes::all_predictors()) %>% + recipes::step_pca(recipes::all_predictors(), num_comp = tune()) rec_no_tune_1 <- recipes::recipe(~., data = mtcars) %>% - recipes::step_normalize(recipes::all_predictors()) + recipes::step_normalize(recipes::all_predictors()) kmeans_mod_no_tune <- k_means(num_clusters = 2) diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index 48cf3951..0b09c75e 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -9,8 +9,18 @@ test_that("partition models", { expect_equal( colnames(augment(reg_form, head(mtcars))), c( - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster" + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb", + ".pred_cluster" ) ) expect_equal(nrow(augment(reg_form, head(mtcars))), 6) @@ -18,8 +28,18 @@ test_that("partition models", { expect_equal( colnames(augment(reg_xy, head(mtcars))), c( - "mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", - "gear", "carb", ".pred_cluster" + "mpg", + "cyl", + "disp", + "hp", + "drat", + "wt", + "qsec", + "vs", + "am", + "gear", + "carb", + ".pred_cluster" ) ) expect_equal(nrow(augment(reg_xy, head(mtcars))), 6) diff --git a/tests/testthat/test-cluster_metric_set.R b/tests/testthat/test-cluster_metric_set.R index bbfabe85..02c36ab3 100644 --- a/tests/testthat/test-cluster_metric_set.R +++ b/tests/testthat/test-cluster_metric_set.R @@ -4,13 +4,23 @@ test_that("cluster_metric_set works", { kmeans_fit <- fit(kmeans_spec, ~., mtcars) - my_metrics <- cluster_metric_set(sse_ratio, sse_total, sse_within_total, silhouette_avg) + my_metrics <- cluster_metric_set( + sse_ratio, + sse_total, + sse_within_total, + silhouette_avg + ) exp_res <- tibble::tibble( .metric = c("sse_ratio", "sse_total", "sse_within_total", "silhouette_avg"), .estimator = "standard", .estimate = vapply( - list(sse_ratio_vec, sse_total_vec, sse_within_total_vec, silhouette_avg_vec), + list( + sse_ratio_vec, + sse_total_vec, + sse_within_total_vec, + silhouette_avg_vec + ), function(x) x(kmeans_fit, new_data = mtcars), FUN.VALUE = numeric(1) ) diff --git a/tests/testthat/test-extract_centroids.R b/tests/testthat/test-extract_centroids.R index ea5ada38..ebbc8920 100644 --- a/tests/testthat/test-extract_centroids.R +++ b/tests/testthat/test-extract_centroids.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in extract_centroids()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_centroids(spec, prefix = "C_") diff --git a/tests/testthat/test-extract_cluster_assignment.R b/tests/testthat/test-extract_cluster_assignment.R index 7cca9a1f..bf2bb5d0 100644 --- a/tests/testthat/test-extract_cluster_assignment.R +++ b/tests/testthat/test-extract_cluster_assignment.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in extract_cluster_assignment()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_cluster_assignment(spec, prefix = "C_") diff --git a/tests/testthat/test-extract_fit_summary.R b/tests/testthat/test-extract_fit_summary.R index e0de8ae2..a931be38 100644 --- a/tests/testthat/test-extract_fit_summary.R +++ b/tests/testthat/test-extract_fit_summary.R @@ -66,7 +66,7 @@ test_that("extract_fit_summary() errors for cluster spec", { test_that("prefix is passed in extract_fit_summary()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- extract_fit_summary(spec, prefix = "C_") diff --git a/tests/testthat/test-fiting.R b/tests/testthat/test-fiting.R index 3d58f22e..8dc1e5a3 100644 --- a/tests/testthat/test-fiting.R +++ b/tests/testthat/test-fiting.R @@ -1,7 +1,6 @@ test_that("fit and fit_xy errors if outcome is provided", { - expect_error( - k_means(num_clusters = 5) %>% fit_xy(mtcars), - regexp = NA + expect_no_error( + k_means(num_clusters = 5) %>% fit_xy(mtcars) ) expect_snapshot( @@ -12,9 +11,8 @@ test_that("fit and fit_xy errors if outcome is provided", { km <- k_means(num_clusters = 5) - expect_error( - workflows::workflow(~., km) %>% fit(mtcars), - regexp = NA + expect_no_error( + workflows::workflow(~., km) %>% fit(mtcars) ) expect_snapshot( error = TRUE, diff --git a/tests/testthat/test-hier_clust-stats.R b/tests/testthat/test-hier_clust-stats.R index 36c9585c..850d5734 100644 --- a/tests/testthat/test-hier_clust-stats.R +++ b/tests/testthat/test-hier_clust-stats.R @@ -53,8 +53,15 @@ test_that("extract_centroids() works", { expect_identical( colnames(centroids), - c(".cluster", "Sepal.Length", "Sepal.Width", "Petal.Length", - "Petal.Width", "Speciesversicolor", "Speciesvirginica") + c( + ".cluster", + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Speciesversicolor", + "Speciesvirginica" + ) ) expect_identical( @@ -73,7 +80,9 @@ test_that("extract_cluster_assignment() works", { clusters <- extract_cluster_assignment(res) expected <- vctrs::vec_cbind( - tibble::tibble(.cluster = factor(paste0("Cluster_", cutree(res$fit, k = 3)))) + tibble::tibble( + .cluster = factor(paste0("Cluster_", cutree(res$fit, k = 3))) + ) ) expect_identical( diff --git a/tests/testthat/test-k_means-clustMixType.R b/tests/testthat/test-k_means-clustMixType.R index 8928ff8e..e07a90e9 100644 --- a/tests/testthat/test-k_means-clustMixType.R +++ b/tests/testthat/test-k_means-clustMixType.R @@ -115,4 +115,3 @@ test_that("modifies errors about suggested other models", { fit(~., data = data.frame(letters, LETTERS)) ) }) - diff --git a/tests/testthat/test-k_means-clusterR.R b/tests/testthat/test-k_means-clusterR.R index 49d2d75c..b630e596 100644 --- a/tests/testthat/test-k_means-clusterR.R +++ b/tests/testthat/test-k_means-clusterR.R @@ -27,8 +27,12 @@ test_that("predicting", { expect_identical( preds, - tibble::tibble(.pred_cluster = factor(paste0("Cluster_", c(1, 1, 1, 2, 2)), - paste0("Cluster_", 1:3))) + tibble::tibble( + .pred_cluster = factor( + paste0("Cluster_", c(1, 1, 1, 2, 2)), + paste0("Cluster_", 1:3) + ) + ) ) }) diff --git a/tests/testthat/test-k_means-klaR.R b/tests/testthat/test-k_means-klaR.R index e205458a..4c554129 100644 --- a/tests/testthat/test-k_means-klaR.R +++ b/tests/testthat/test-k_means-klaR.R @@ -46,8 +46,12 @@ test_that("predicting", { expect_identical( preds, - tibble::tibble(.pred_cluster = factor(paste0("Cluster_", c(1, 1, 1, 1, 2)), - paste0("Cluster_", 1:3))) + tibble::tibble( + .pred_cluster = factor( + paste0("Cluster_", c(1, 1, 1, 1, 2)), + paste0("Cluster_", 1:3) + ) + ) ) }) @@ -89,14 +93,12 @@ test_that("predicting ties argument works", { expect_identical( predict(res, data.frame(x = "C", y = "C"), ties = "first"), - tibble::tibble(.pred_cluster = factor("Cluster_1", - paste0("Cluster_", 1:2))) + tibble::tibble(.pred_cluster = factor("Cluster_1", paste0("Cluster_", 1:2))) ) expect_identical( predict(res, data.frame(x = "C", y = "C"), ties = "last"), - tibble::tibble(.pred_cluster = factor("Cluster_2", - paste0("Cluster_", 1:2))) + tibble::tibble(.pred_cluster = factor("Cluster_2", paste0("Cluster_", 1:2))) ) }) diff --git a/tests/testthat/test-k_means.R b/tests/testthat/test-k_means.R index 9b530602..c9db68da 100644 --- a/tests/testthat/test-k_means.R +++ b/tests/testthat/test-k_means.R @@ -161,7 +161,7 @@ test_that("reordering is done correctly for ClusterR k_means", { expect_identical( summ$n_members, unname(as.integer(table(summ$cluster_assignments))) - ) + ) }) test_that("errors if `num_clust` isn't specified", { @@ -169,13 +169,13 @@ test_that("errors if `num_clust` isn't specified", { error = TRUE, k_means() %>% set_engine("stats") %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) ) expect_snapshot( error = TRUE, k_means() %>% set_engine("ClusterR") %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) ) }) diff --git a/tests/testthat/test-k_means_diagnostics.R b/tests/testthat/test-k_means_diagnostics.R index 98f355e0..dba6415f 100644 --- a/tests/testthat/test-k_means_diagnostics.R +++ b/tests/testthat/test-k_means_diagnostics.R @@ -16,7 +16,8 @@ test_that("kmeans sse metrics work", { clusters = 3 ) - expect_equal(sse_within(kmeans_fit_stats)$wss, + expect_equal( + sse_within(kmeans_fit_stats)$wss, c(42877.103, 76954.010, 7654.146), # hard coded because of order tolerance = 0.005 ) @@ -37,7 +38,8 @@ test_that("kmeans sse metrics work", { tolerance = 0.005 ) - expect_equal(sse_within(kmeans_fit_ClusterR)$wss, + expect_equal( + sse_within(kmeans_fit_ClusterR)$wss, c(42877.103, 56041.432, 4665.041), # hard coded because of order tolerance = 0.005 ) @@ -66,7 +68,8 @@ test_that("kmeans sse metrics work on new data", { new_data <- mtcars[1:4, ] - expect_equal(sse_within(kmeans_fit_stats, new_data)$wss, + expect_equal( + sse_within(kmeans_fit_stats, new_data)$wss, c(2799.21, 12855.17), tolerance = 0.005 ) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 349868e3..9f2003f6 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -56,7 +56,7 @@ test_that("passed arguments overwrites model arguments", { test_that("prefix is passed in predict()", { spec <- tidyclust::k_means(num_clusters = 4) %>% - fit(~ ., data = mtcars) + fit(~., data = mtcars) res <- predict(spec, mtcars, prefix = "C_") diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index a368abae..0fd09bae 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -1,8 +1,8 @@ test_that("partition predictions", { kmeans_fit <- k_means(num_clusters = 3, mode = "partition") %>% - set_engine("stats") %>% - fit(~., data = mtcars) + set_engine("stats") %>% + fit(~., data = mtcars) expect_true(tibble::is_tibble(predict(kmeans_fit, new_data = mtcars))) expect_true( diff --git a/tests/testthat/test-reconcile_clusterings.R b/tests/testthat/test-reconcile_clusterings.R index 2b5d60f2..6ead4c6e 100644 --- a/tests/testthat/test-reconcile_clusterings.R +++ b/tests/testthat/test-reconcile_clusterings.R @@ -1,7 +1,11 @@ test_that("reconciliation works with one-to-one", { primary_cluster_assignment <- c( - "Apple", "Apple", "Carrot", "Carrot", - "Banana", "Banana" + "Apple", + "Apple", + "Carrot", + "Carrot", + "Banana", + "Banana" ) alt_cluster_assignment <- c("Dog", "Dog", "Cat", "Dog", "Fish", "Fish") @@ -18,8 +22,12 @@ test_that("reconciliation works with one-to-one", { test_that("reconciliation works with uneven numbers", { primary_cluster_assignment <- c( - "Apple", "Apple", "Carrot", "Carrot", - "Banana", "Banana" + "Apple", + "Apple", + "Carrot", + "Carrot", + "Banana", + "Banana" ) alt_cluster_assignment <- c("Dog", "Dog", "Cat", "Dog", "Parrot", "Fish") diff --git a/tests/testthat/test-tune_cluster.R b/tests/testthat/test-tune_cluster.R index 166daa02..922a69ae 100644 --- a/tests/testthat/test-tune_cluster.R +++ b/tests/testthat/test-tune_cluster.R @@ -163,7 +163,11 @@ test_that("tune model and recipe", { expect_equal( colnames(res$.metrics[[1]]), c( - "num_clusters", "num_comp", ".metric", ".estimator", ".estimate", + "num_clusters", + "num_comp", + ".metric", + ".estimator", + ".estimate", ".config" ) ) @@ -233,7 +237,11 @@ test_that('tune model and recipe (parallel_over = "everything")', { expect_equal( colnames(res$.metrics[[1]]), c( - "num_clusters", "num_comp", ".metric", ".estimator", ".estimate", + "num_clusters", + "num_comp", + ".metric", + ".estimator", + ".estimate", ".config" ) ) @@ -258,9 +266,12 @@ test_that("tune model only - failure in formula is caught elegantly", { ~z, resamples = data_folds, grid = cars_grid, - control = tune::control_grid(extract = function(x) { - 1 - }, save_pred = TRUE) + control = tune::control_grid( + extract = function(x) { + 1 + }, + save_pred = TRUE + ) ) ) diff --git a/tests/testthat/test-workflows.R b/tests/testthat/test-workflows.R index a0b9a5bc..d5277475 100644 --- a/tests/testthat/test-workflows.R +++ b/tests/testthat/test-workflows.R @@ -27,7 +27,7 @@ test_that("integrates with workflows::add_formula()", { kmeans_spec <- k_means(num_clusters = 2) wf_spec <- workflows::workflow() %>% - workflows::add_formula(~ .) %>% + workflows::add_formula(~.) %>% workflows::add_model(kmeans_spec) expect_no_error( @@ -51,7 +51,7 @@ test_that("integrates with workflows::add_recipe()", { kmeans_spec <- k_means(num_clusters = 2) wf_spec <- workflows::workflow() %>% - workflows::add_recipe(recipes::recipe(~ ., data = mtcars)) %>% + workflows::add_recipe(recipes::recipe(~., data = mtcars)) %>% workflows::add_model(kmeans_spec) expect_no_error( diff --git a/tidyclust.Rproj b/tidyclust.Rproj index 7f1b52b6..ee85d219 100644 --- a/tidyclust.Rproj +++ b/tidyclust.Rproj @@ -1,4 +1,5 @@ Version: 1.0 +ProjectId: 7aa6e982-3f23-4734-8da7-2ab97b3d66a5 RestoreWorkspace: No SaveWorkspace: No diff --git a/vignettes/articles/k_means.Rmd b/vignettes/articles/k_means.Rmd index bbd514db..9b718e8e 100644 --- a/vignettes/articles/k_means.Rmd +++ b/vignettes/articles/k_means.Rmd @@ -291,11 +291,11 @@ matrix (i.e., all pairwise distances between observations). ```{r} my_dist_1 <- function(x) { - Rfast::Dist(x, method = "manhattan") + philentropy::distance(x, method = "manhattan") } my_dist_2 <- function(x, y) { - Rfast::dista(x, y, method = "manhattan") + philentropy::dist_many_many(x, y, method = "manhattan") } kmeans_fit %>% sse_ratio(dist_fun = my_dist_2) @@ -404,7 +404,7 @@ pens %>% ```{r, echo = FALSE} #| fig-alt: "scatter chart. bill_length_mm along the x-axis, bill_depth_mm along the y-axis. 3 vague cluster appears in the point cloud. Point are colored according to how close they were to the color points." -closest_center <- Rfast::dista(as.matrix(pens), as.matrix(pens[init, ])) %>% +closest_center <- philentropy::dist_many_many(as.matrix(pens), as.matrix(pens[init, ]), method = "euclidean") %>% apply(1, which.min) pens %>%