diff --git a/.gitignore b/.gitignore
index ce37a082..114b666f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,7 @@
-.Rproj.user
-.Rhistory
-.Rdata
-.httr-oauth
-.DS_Store
-docs
-hex sticker/
+.Rproj.user
+.Rhistory
+.Rdata
+.httr-oauth
+.DS_Store
+docs
+hex sticker/
diff --git a/DESCRIPTION b/DESCRIPTION
index 1211cb41..11338e59 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,57 +1,57 @@
-Package: tidyclust
-Title: What the Package Does (One Line, Title Case)
-Version: 0.0.0.9000
-Authors@R: c(
- person("Emil", "Hvitfeldt", , "emilhhvitfeldt@gmail.com", role = c("aut", "cre"),
- comment = c(ORCID = "0000-0002-0679-1945")),
- person("Kelly", "Bodwin", , "kelly@bodwin.us", role = "aut"),
- person("RStudio", role = c("cph", "fnd"))
- )
-Description: What the package does (one paragraph).
-License: MIT + file LICENSE
-URL: https://github.com/EmilHvitfeldt/tidyclust
-BugReports: https://github.com/EmilHvitfeldt/tidyclust/issues
-Imports:
- cli,
- dials,
- dplyr,
- forcats,
- foreach,
- generics,
- glue,
- hardhat (>= 0.1.6.9001),
- magrittr,
- parsnip,
- prettyunits,
- RcppHungarian,
- rlang,
- rsample,
- stats,
- tibble,
- tidyr,
- tune,
- utils,
- vctrs
-Suggests:
- cluster,
- ClusterR,
- covr,
- flexclust,
- janitor,
- knitr,
- modeldata,
- recipes,
- rmarkdown,
- Rfast,
- testthat (>= 3.0.0),
- workflows
-VignetteBuilder:
- knitr
-Remotes:
- tidymodels/parsnip,
- tidymodels/workflows@celery
-Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins
-Config/testthat/edition: 3
-Encoding: UTF-8
-Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.0.9000
+Package: tidyclust
+Title: What the Package Does (One Line, Title Case)
+Version: 0.0.0.9000
+Authors@R: c(
+ person("Emil", "Hvitfeldt", , "emilhhvitfeldt@gmail.com", role = c("aut", "cre"),
+ comment = c(ORCID = "0000-0002-0679-1945")),
+ person("Kelly", "Bodwin", , "kelly@bodwin.us", role = "aut"),
+ person("RStudio", role = c("cph", "fnd"))
+ )
+Description: What the package does (one paragraph).
+License: MIT + file LICENSE
+URL: https://github.com/EmilHvitfeldt/tidyclust
+BugReports: https://github.com/EmilHvitfeldt/tidyclust/issues
+Imports:
+ cli,
+ dials,
+ dplyr,
+ forcats,
+ foreach,
+ generics,
+ glue,
+ hardhat (>= 0.1.6.9001),
+ magrittr,
+ parsnip,
+ prettyunits,
+ RcppHungarian,
+ rlang,
+ rsample,
+ stats,
+ tibble,
+ tidyr,
+ tune,
+ utils,
+ vctrs
+Suggests:
+ cluster,
+ ClusterR,
+ covr,
+ flexclust,
+ janitor,
+ knitr,
+ modeldata,
+ recipes,
+ rmarkdown,
+ Rfast,
+ testthat (>= 3.0.0),
+ workflows
+VignetteBuilder:
+ knitr
+Remotes:
+ tidymodels/parsnip,
+ tidymodels/workflows@celery
+Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins
+Config/testthat/edition: 3
+Encoding: UTF-8
+Roxygen: list(markdown = TRUE)
+RoxygenNote: 7.2.0.9000
diff --git a/NAMESPACE b/NAMESPACE
index 77354f66..8dca7bb0 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,173 +1,173 @@
-# Generated by roxygen2: do not edit by hand
-
-S3method(as_tibble,cluster_metric_set)
-S3method(augment,cluster_fit)
-S3method(avg_silhouette,cluster_fit)
-S3method(avg_silhouette,workflow)
-S3method(extract_cluster_assignment,KMeansCluster)
-S3method(extract_cluster_assignment,cluster_fit)
-S3method(extract_cluster_assignment,hclust)
-S3method(extract_cluster_assignment,kmeans)
-S3method(extract_cluster_assignment,workflow)
-S3method(extract_fit_summary,KMeansCluster)
-S3method(extract_fit_summary,cluster_fit)
-S3method(extract_fit_summary,hclust)
-S3method(extract_fit_summary,kmeans)
-S3method(extract_fit_summary,workflow)
-S3method(extract_parameter_set_dials,cluster_spec)
-S3method(fit,cluster_spec)
-S3method(fit_xy,cluster_spec)
-S3method(glance,cluster_fit)
-S3method(load_pkgs,cluster_spec)
-S3method(min_grid,cluster_spec)
-S3method(predict,cluster_fit)
-S3method(predict_cluster,cluster_fit)
-S3method(predict_raw,cluster_fit)
-S3method(print,cluster_fit)
-S3method(print,cluster_metric_set)
-S3method(print,cluster_spec)
-S3method(print,control_cluster)
-S3method(print,hier_clust)
-S3method(print,k_means)
-S3method(sse_ratio,cluster_fit)
-S3method(sse_ratio,workflow)
-S3method(tidy,cluster_fit)
-S3method(tot_sse,cluster_fit)
-S3method(tot_sse,workflow)
-S3method(tot_wss,cluster_fit)
-S3method(tot_wss,workflow)
-S3method(translate_tidyclust,default)
-S3method(translate_tidyclust,hier_clust)
-S3method(translate_tidyclust,k_means)
-S3method(tune_cluster,cluster_spec)
-S3method(tune_cluster,default)
-S3method(tune_cluster,workflow)
-S3method(update,k_means)
-export("%>%")
-export(.convert_form_to_x_fit)
-export(.convert_form_to_x_new)
-export(.convert_x_to_form_fit)
-export(.convert_x_to_form_new)
-export(ClusterR_kmeans_fit)
-export(augment)
-export(avg_silhouette)
-export(avg_silhouette_vec)
-export(check_empty_ellipse_tidyclust)
-export(check_model_doesnt_exist_tidyclust)
-export(check_model_exists_tidyclust)
-export(cluster_metric_set)
-export(control_cluster)
-export(enrichment)
-export(extract_centroids)
-export(extract_cluster_assignment)
-export(extract_fit_parsnip)
-export(extract_fit_summary)
-export(extract_parameter_set_dials)
-export(extract_preprocessor)
-export(extract_spec_parsnip)
-export(finalize_model_tidyclust)
-export(finalize_workflow_tidyclust)
-export(fit)
-export(fit.cluster_spec)
-export(fit_xy)
-export(fit_xy.cluster_spec)
-export(get_dependency_tidyclust)
-export(get_encoding_tidyclust)
-export(get_fit_tidyclust)
-export(get_from_env_tidyclust)
-export(get_model_env_tidyclust)
-export(get_pred_type_tidyclust)
-export(glance)
-export(hclust_fit)
-export(hier_clust)
-export(k_means)
-export(load_pkgs)
-export(make_classes_tidyclust)
-export(min_grid)
-export(new_cluster_metric)
-export(new_cluster_spec)
-export(num_clusters)
-export(predict.cluster_fit)
-export(predict_cluster)
-export(predict_cluster.cluster_fit)
-export(predict_raw)
-export(predict_raw.cluster_fit)
-export(prepare_data)
-export(reconcile_clusterings)
-export(required_pkgs)
-export(set_args)
-export(set_args.cluster_spec)
-export(set_dependency_tidyclust)
-export(set_encoding_tidyclust)
-export(set_engine)
-export(set_engine.cluster_spec)
-export(set_env_val_tidyclust)
-export(set_fit_tidyclust)
-export(set_mode)
-export(set_mode.cluster_spec)
-export(set_model_arg_tidyclust)
-export(set_model_engine_tidyclust)
-export(set_model_mode_tidyclust)
-export(set_new_model_tidyclust)
-export(set_pred_tidyclust)
-export(show_model_info_tidyclust)
-export(silhouettes)
-export(sse_ratio)
-export(sse_ratio_vec)
-export(tidy)
-export(tot_sse)
-export(tot_sse_vec)
-export(tot_wss)
-export(tot_wss_vec)
-export(translate_tidyclust)
-export(translate_tidyclust.default)
-export(tune)
-export(tune_cluster)
-export(within_cluster_sse)
-importFrom(dplyr,bind_cols)
-importFrom(generics,augment)
-importFrom(generics,fit)
-importFrom(generics,fit_xy)
-importFrom(generics,glance)
-importFrom(generics,min_grid)
-importFrom(generics,required_pkgs)
-importFrom(generics,tidy)
-importFrom(hardhat,extract_fit_parsnip)
-importFrom(hardhat,extract_parameter_set_dials)
-importFrom(hardhat,extract_preprocessor)
-importFrom(hardhat,extract_spec_parsnip)
-importFrom(hardhat,tune)
-importFrom(magrittr,"%>%")
-importFrom(parsnip,make_call)
-importFrom(parsnip,maybe_data_frame)
-importFrom(parsnip,maybe_matrix)
-importFrom(parsnip,model_printer)
-importFrom(parsnip,null_value)
-importFrom(parsnip,predict_raw)
-importFrom(parsnip,set_args)
-importFrom(parsnip,set_engine)
-importFrom(parsnip,set_mode)
-importFrom(parsnip,show_call)
-importFrom(rlang,"%||%")
-importFrom(rlang,abort)
-importFrom(rlang,as_function)
-importFrom(rlang,enquo)
-importFrom(rlang,enquos)
-importFrom(rlang,get_expr)
-importFrom(rlang,global_env)
-importFrom(rlang,is_logical)
-importFrom(rlang,is_true)
-importFrom(rlang,missing_arg)
-importFrom(rlang,quos)
-importFrom(rlang,set_names)
-importFrom(rlang,sym)
-importFrom(stats,.getXlevels)
-importFrom(stats,as.formula)
-importFrom(stats,model.frame)
-importFrom(stats,model.matrix)
-importFrom(stats,model.offset)
-importFrom(stats,model.weights)
-importFrom(stats,na.omit)
-importFrom(tibble,as_tibble)
-importFrom(tune,load_pkgs)
-importFrom(utils,capture.output)
+# Generated by roxygen2: do not edit by hand
+
+S3method(as_tibble,cluster_metric_set)
+S3method(augment,cluster_fit)
+S3method(avg_silhouette,cluster_fit)
+S3method(avg_silhouette,workflow)
+S3method(extract_cluster_assignment,KMeansCluster)
+S3method(extract_cluster_assignment,cluster_fit)
+S3method(extract_cluster_assignment,hclust)
+S3method(extract_cluster_assignment,kmeans)
+S3method(extract_cluster_assignment,workflow)
+S3method(extract_fit_summary,KMeansCluster)
+S3method(extract_fit_summary,cluster_fit)
+S3method(extract_fit_summary,hclust)
+S3method(extract_fit_summary,kmeans)
+S3method(extract_fit_summary,workflow)
+S3method(extract_parameter_set_dials,cluster_spec)
+S3method(fit,cluster_spec)
+S3method(fit_xy,cluster_spec)
+S3method(glance,cluster_fit)
+S3method(load_pkgs,cluster_spec)
+S3method(min_grid,cluster_spec)
+S3method(predict,cluster_fit)
+S3method(predict_cluster,cluster_fit)
+S3method(predict_raw,cluster_fit)
+S3method(print,cluster_fit)
+S3method(print,cluster_metric_set)
+S3method(print,cluster_spec)
+S3method(print,control_cluster)
+S3method(print,hier_clust)
+S3method(print,k_means)
+S3method(set_args,cluster_spec)
+S3method(set_engine,cluster_spec)
+S3method(set_mode,cluster_spec)
+S3method(sse_ratio,cluster_fit)
+S3method(sse_ratio,workflow)
+S3method(tidy,cluster_fit)
+S3method(tot_sse,cluster_fit)
+S3method(tot_sse,workflow)
+S3method(tot_wss,cluster_fit)
+S3method(tot_wss,workflow)
+S3method(translate_tidyclust,default)
+S3method(translate_tidyclust,hier_clust)
+S3method(translate_tidyclust,k_means)
+S3method(tune_cluster,cluster_spec)
+S3method(tune_cluster,default)
+S3method(tune_cluster,workflow)
+S3method(update,k_means)
+export("%>%")
+export(.convert_form_to_x_fit)
+export(.convert_form_to_x_new)
+export(.convert_x_to_form_fit)
+export(.convert_x_to_form_new)
+export(ClusterR_kmeans_fit)
+export(augment)
+export(avg_silhouette)
+export(avg_silhouette_vec)
+export(check_empty_ellipse_tidyclust)
+export(check_model_doesnt_exist_tidyclust)
+export(check_model_exists_tidyclust)
+export(cluster_metric_set)
+export(control_cluster)
+export(enrichment)
+export(extract_centroids)
+export(extract_cluster_assignment)
+export(extract_fit_parsnip)
+export(extract_fit_summary)
+export(extract_parameter_set_dials)
+export(extract_preprocessor)
+export(extract_spec_parsnip)
+export(finalize_model_tidyclust)
+export(finalize_workflow_tidyclust)
+export(fit)
+export(fit.cluster_spec)
+export(fit_xy)
+export(fit_xy.cluster_spec)
+export(get_dependency_tidyclust)
+export(get_encoding_tidyclust)
+export(get_fit_tidyclust)
+export(get_from_env_tidyclust)
+export(get_model_env_tidyclust)
+export(get_pred_type_tidyclust)
+export(glance)
+export(hclust_fit)
+export(hier_clust)
+export(k_means)
+export(load_pkgs)
+export(make_classes_tidyclust)
+export(min_grid)
+export(new_cluster_metric)
+export(new_cluster_spec)
+export(num_clusters)
+export(predict.cluster_fit)
+export(predict_cluster)
+export(predict_cluster.cluster_fit)
+export(predict_raw)
+export(predict_raw.cluster_fit)
+export(prepare_data)
+export(reconcile_clusterings)
+export(required_pkgs)
+export(set_args)
+export(set_dependency_tidyclust)
+export(set_encoding_tidyclust)
+export(set_engine)
+export(set_env_val_tidyclust)
+export(set_fit_tidyclust)
+export(set_mode)
+export(set_model_arg_tidyclust)
+export(set_model_engine_tidyclust)
+export(set_model_mode_tidyclust)
+export(set_new_model_tidyclust)
+export(set_pred_tidyclust)
+export(show_model_info_tidyclust)
+export(silhouettes)
+export(sse_ratio)
+export(sse_ratio_vec)
+export(tidy)
+export(tot_sse)
+export(tot_sse_vec)
+export(tot_wss)
+export(tot_wss_vec)
+export(translate_tidyclust)
+export(translate_tidyclust.default)
+export(tune)
+export(tune_cluster)
+export(within_cluster_sse)
+importFrom(dplyr,bind_cols)
+importFrom(generics,augment)
+importFrom(generics,fit)
+importFrom(generics,fit_xy)
+importFrom(generics,glance)
+importFrom(generics,min_grid)
+importFrom(generics,required_pkgs)
+importFrom(generics,tidy)
+importFrom(hardhat,extract_fit_parsnip)
+importFrom(hardhat,extract_parameter_set_dials)
+importFrom(hardhat,extract_preprocessor)
+importFrom(hardhat,extract_spec_parsnip)
+importFrom(hardhat,tune)
+importFrom(magrittr,"%>%")
+importFrom(parsnip,make_call)
+importFrom(parsnip,maybe_data_frame)
+importFrom(parsnip,maybe_matrix)
+importFrom(parsnip,model_printer)
+importFrom(parsnip,null_value)
+importFrom(parsnip,predict_raw)
+importFrom(parsnip,set_args)
+importFrom(parsnip,set_engine)
+importFrom(parsnip,set_mode)
+importFrom(parsnip,show_call)
+importFrom(rlang,"%||%")
+importFrom(rlang,abort)
+importFrom(rlang,as_function)
+importFrom(rlang,enquo)
+importFrom(rlang,enquos)
+importFrom(rlang,get_expr)
+importFrom(rlang,global_env)
+importFrom(rlang,is_logical)
+importFrom(rlang,is_true)
+importFrom(rlang,missing_arg)
+importFrom(rlang,quos)
+importFrom(rlang,set_names)
+importFrom(rlang,sym)
+importFrom(stats,.getXlevels)
+importFrom(stats,as.formula)
+importFrom(stats,model.frame)
+importFrom(stats,model.matrix)
+importFrom(stats,model.offset)
+importFrom(stats,model.weights)
+importFrom(stats,na.omit)
+importFrom(tibble,as_tibble)
+importFrom(tune,load_pkgs)
+importFrom(utils,capture.output)
diff --git a/R/arguments.R b/R/arguments.R
index 52279b47..62e9bef2 100644
--- a/R/arguments.R
+++ b/R/arguments.R
@@ -1,107 +1,107 @@
-check_eng_args <- function(args, obj, core_args) {
- # Make sure that we are not trying to modify an argument that
- # is explicitly protected in the method metadata or arg_key
- protected_args <- unique(c(obj$protect, core_args))
- common_args <- intersect(protected_args, names(args))
- if (length(common_args) > 0) {
- args <- args[!(names(args) %in% common_args)]
- common_args <- paste0(common_args, collapse = ", ")
- rlang::warn(glue::glue(
- "The following arguments cannot be manually modified ",
- "and were removed: {common_args}."
- ))
- }
- args
-}
-
-make_x_call <- function(object, target) {
- fit_args <- object$method$fit$args
-
- # Get the arguments related to data:
- if (is.null(object$method$fit$data)) {
- data_args <- c(x = "x")
- } else {
- data_args <- object$method$fit$data
- }
-
- object$method$fit$args[[unname(data_args["x"])]] <-
- switch(target,
- none = rlang::expr(x),
- data.frame = rlang::expr(maybe_data_frame(x)),
- matrix = rlang::expr(maybe_matrix(x)),
- rlang::abort(glue::glue("Invalid data type target: {target}."))
- )
-
- fit_call <- make_call(
- fun = object$method$fit$func["fun"],
- ns = object$method$fit$func["pkg"],
- object$method$fit$args
- )
-
- fit_call
-}
-
-make_form_call <- function(object, env = NULL) {
- fit_args <- object$method$fit$args
-
- # Get the arguments related to data:
- if (is.null(object$method$fit$data)) {
- data_args <- c(formula = "formula", data = "data")
- } else {
- data_args <- object$method$fit$data
- }
-
- # add data arguments
- for (i in seq_along(data_args)) {
- fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i])
- }
-
- # sub in actual formula
- fit_args[[unname(data_args["formula"])]] <- env$formula
-
- fit_call <- make_call(
- fun = object$method$fit$func["fun"],
- ns = object$method$fit$func["pkg"],
- fit_args
- )
- fit_call
-}
-
-#' @export
-set_args.cluster_spec <- function(object, ...) {
- the_dots <- enquos(...)
- if (length(the_dots) == 0)
- rlang::abort("Please pass at least one named argument.")
- main_args <- names(object$args)
- new_args <- names(the_dots)
- for (i in new_args) {
- if (any(main_args == i)) {
- object$args[[i]] <- the_dots[[i]]
- } else {
- object$eng_args[[i]] <- the_dots[[i]]
- }
- }
- new_cluster_spec(
- cls = class(object)[1],
- args = object$args,
- eng_args = object$eng_args,
- mode = object$mode,
- method = NULL,
- engine = object$engine
- )
-}
-
-#' @export
-set_mode.cluster_spec <- function(object, mode) {
- cls <- class(object)[1]
- if (rlang::is_missing(mode)) {
- spec_modes <- rlang::env_get(
- get_model_env_tidyclust(),
- paste0(cls, "_modes")
- )
- stop_incompatible_mode(spec_modes, cls = cls)
- }
- check_spec_mode_engine_val(cls, object$engine, mode)
- object$mode <- mode
- object
-}
+check_eng_args <- function(args, obj, core_args) {
+ # Make sure that we are not trying to modify an argument that
+ # is explicitly protected in the method metadata or arg_key
+ protected_args <- unique(c(obj$protect, core_args))
+ common_args <- intersect(protected_args, names(args))
+ if (length(common_args) > 0) {
+ args <- args[!(names(args) %in% common_args)]
+ common_args <- paste0(common_args, collapse = ", ")
+ rlang::warn(glue::glue(
+ "The following arguments cannot be manually modified ",
+ "and were removed: {common_args}."
+ ))
+ }
+ args
+}
+
+make_x_call <- function(object, target) {
+ fit_args <- object$method$fit$args
+
+ # Get the arguments related to data:
+ if (is.null(object$method$fit$data)) {
+ data_args <- c(x = "x")
+ } else {
+ data_args <- object$method$fit$data
+ }
+
+ object$method$fit$args[[unname(data_args["x"])]] <-
+ switch(target,
+ none = rlang::expr(x),
+ data.frame = rlang::expr(maybe_data_frame(x)),
+ matrix = rlang::expr(maybe_matrix(x)),
+ rlang::abort(glue::glue("Invalid data type target: {target}."))
+ )
+
+ fit_call <- make_call(
+ fun = object$method$fit$func["fun"],
+ ns = object$method$fit$func["pkg"],
+ object$method$fit$args
+ )
+
+ fit_call
+}
+
+make_form_call <- function(object, env = NULL) {
+ fit_args <- object$method$fit$args
+
+ # Get the arguments related to data:
+ if (is.null(object$method$fit$data)) {
+ data_args <- c(formula = "formula", data = "data")
+ } else {
+ data_args <- object$method$fit$data
+ }
+
+ # add data arguments
+ for (i in seq_along(data_args)) {
+ fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i])
+ }
+
+ # sub in actual formula
+ fit_args[[unname(data_args["formula"])]] <- env$formula
+
+ fit_call <- make_call(
+ fun = object$method$fit$func["fun"],
+ ns = object$method$fit$func["pkg"],
+ fit_args
+ )
+ fit_call
+}
+
+#' @export
+set_args.cluster_spec <- function(object, ...) {
+ the_dots <- enquos(...)
+ if (length(the_dots) == 0)
+ rlang::abort("Please pass at least one named argument.")
+ main_args <- names(object$args)
+ new_args <- names(the_dots)
+ for (i in new_args) {
+ if (any(main_args == i)) {
+ object$args[[i]] <- the_dots[[i]]
+ } else {
+ object$eng_args[[i]] <- the_dots[[i]]
+ }
+ }
+ new_cluster_spec(
+ cls = class(object)[1],
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
+}
+
+#' @export
+set_mode.cluster_spec <- function(object, mode) {
+ cls <- class(object)[1]
+ if (rlang::is_missing(mode)) {
+ spec_modes <- rlang::env_get(
+ get_model_env_tidyclust(),
+ paste0(cls, "_modes")
+ )
+ stop_incompatible_mode(spec_modes, cls = cls)
+ }
+ check_spec_mode_engine_val(cls, object$engine, mode)
+ object$mode <- mode
+ object
+}
diff --git a/R/augment.R b/R/augment.R
index c8d7299d..05b422e4 100644
--- a/R/augment.R
+++ b/R/augment.R
@@ -1,33 +1,33 @@
-#' Augment data with predictions
-#'
-#' `augment()` will add column(s) for predictions to the given data.
-#'
-#' For partition models, a `.pred_cluster` column is added.
-#'
-#' @param x A `cluster_fit` object produced by [fit.cluster_spec()] or
-#' [fit_xy.cluster_spec()] .
-#' @param new_data A data frame or matrix.
-#' @param ... Not currently used.
-#' @rdname augment
-#' @export
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' kmeans_fit %>%
-#' augment(new_data = mtcars)
-augment.cluster_fit <- function(x, new_data, ...) {
- ret <- new_data
- if (x$spec$mode == "partition") {
- check_spec_pred_type(x, "cluster")
- ret <- dplyr::bind_cols(
- ret,
- stats::predict(x, new_data = new_data)
- )
- } else {
- rlang::abort(paste("Unknown mode:", x$spec$mode))
- }
- as_tibble(ret)
-}
+#' Augment data with predictions
+#'
+#' `augment()` will add column(s) for predictions to the given data.
+#'
+#' For partition models, a `.pred_cluster` column is added.
+#'
+#' @param x A `cluster_fit` object produced by [fit.cluster_spec()] or
+#' [fit_xy.cluster_spec()] .
+#' @param new_data A data frame or matrix.
+#' @param ... Not currently used.
+#' @rdname augment
+#' @export
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' kmeans_fit %>%
+#' augment(new_data = mtcars)
+augment.cluster_fit <- function(x, new_data, ...) {
+ ret <- new_data
+ if (x$spec$mode == "partition") {
+ check_spec_pred_type(x, "cluster")
+ ret <- dplyr::bind_cols(
+ ret,
+ stats::predict(x, new_data = new_data)
+ )
+ } else {
+ rlang::abort(paste("Unknown mode:", x$spec$mode))
+ }
+ as_tibble(ret)
+}
diff --git a/R/dials.R b/R/dials.R
index 9b4baf42..d4f91ffd 100644
--- a/R/dials.R
+++ b/R/dials.R
@@ -1,16 +1,16 @@
-#' Number of Clusters
-#'
-#' @inheritParams dials::Laplace
-#' @examples
-#' num_clusters()
-#' @export
-num_clusters <- function(range = c(1L, 10L), trans = NULL) {
- dials::new_quant_param(
- type = "integer",
- range = range,
- inclusive = c(TRUE, TRUE),
- trans = trans,
- label = c(num_clusters = "# Clusters"),
- finalize = NULL
- )
-}
+#' Number of Clusters
+#'
+#' @inheritParams dials::Laplace
+#' @examples
+#' num_clusters()
+#' @export
+num_clusters <- function(range = c(1L, 10L), trans = NULL) {
+ dials::new_quant_param(
+ type = "integer",
+ range = range,
+ inclusive = c(TRUE, TRUE),
+ trans = trans,
+ label = c(num_clusters = "# Clusters"),
+ finalize = NULL
+ )
+}
diff --git a/R/engines.R b/R/engines.R
index bfb10c00..5d6d4cab 100644
--- a/R/engines.R
+++ b/R/engines.R
@@ -1,86 +1,86 @@
-#' @export
-set_engine.cluster_spec <- function(object, engine, ...) {
- mod_type <- class(object)[1]
-
- if (rlang::is_missing(engine)) {
- stop_missing_engine(mod_type)
- }
- object$engine <- engine
- check_spec_mode_engine_val(mod_type, object$engine, object$mode)
-
- new_cluster_spec(
- cls = mod_type,
- args = object$args,
- eng_args = enquos(...),
- mode = object$mode,
- method = NULL,
- engine = object$engine
- )
-}
-
-stop_missing_engine <- function(cls) {
- info <-
- get_from_env_tidyclust(cls) %>%
- dplyr::group_by(mode) %>%
- dplyr::summarize(
- msg = paste0(
- unique(mode), " {",
- paste0(unique(engine), collapse = ", "),
- "}"
- ),
- .groups = "drop"
- )
- if (nrow(info) == 0) {
- rlang::abort(paste0("No known engines for `", cls, "()`."))
- }
- msg <- paste0(info$msg, collapse = ", ")
- msg <- paste("Missing engine. Possible mode/engine combinations are:", msg)
- rlang::abort(msg)
-}
-
-load_libs <- function(x, quiet, attach = FALSE) {
- for (pkg in x$method$libs) {
- if (!attach) {
- suppressPackageStartupMessages(requireNamespace(pkg, quietly = quiet))
- } else {
- library(pkg, character.only = TRUE)
- }
- }
- invisible(x)
-}
-
-specific_model <- function(x) {
- cls <- class(x)
- cls[cls != "cluster_spec"]
-}
-
-possible_engines <- function(object, ...) {
- m_env <- get_model_env_tidyclust()
- engs <- rlang::env_get(m_env, specific_model(object))
- unique(engs$engine)
-}
-
-shhhh <- function(x) {
- suppressPackageStartupMessages(requireNamespace(x, quietly = TRUE))
-}
-
-is_installed <- function(pkg) {
- res <- try(shhhh(pkg), silent = TRUE)
- res
-}
-
-check_installs <- function(x) {
- if (length(x$method$libs) > 0) {
- is_inst <- map_lgl(x$method$libs, is_installed)
- if (any(!is_inst)) {
- missing_pkg <- x$method$libs[!is_inst]
- missing_pkg <- paste0(missing_pkg, collapse = ", ")
- rlang::abort(
- glue::glue(
- "This engine requires some package installs: ",
- glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ")
- )
- )
- }
- }
-}
+#' @export
+set_engine.cluster_spec <- function(object, engine, ...) {
+ mod_type <- class(object)[1]
+
+ if (rlang::is_missing(engine)) {
+ stop_missing_engine(mod_type)
+ }
+ object$engine <- engine
+ check_spec_mode_engine_val(mod_type, object$engine, object$mode)
+
+ new_cluster_spec(
+ cls = mod_type,
+ args = object$args,
+ eng_args = enquos(...),
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
+}
+
+stop_missing_engine <- function(cls) {
+ info <-
+ get_from_env_tidyclust(cls) %>%
+ dplyr::group_by(mode) %>%
+ dplyr::summarize(
+ msg = paste0(
+ unique(mode), " {",
+ paste0(unique(engine), collapse = ", "),
+ "}"
+ ),
+ .groups = "drop"
+ )
+ if (nrow(info) == 0) {
+ rlang::abort(paste0("No known engines for `", cls, "()`."))
+ }
+ msg <- paste0(info$msg, collapse = ", ")
+ msg <- paste("Missing engine. Possible mode/engine combinations are:", msg)
+ rlang::abort(msg)
+}
+
+load_libs <- function(x, quiet, attach = FALSE) {
+ for (pkg in x$method$libs) {
+ if (!attach) {
+ suppressPackageStartupMessages(requireNamespace(pkg, quietly = quiet))
+ } else {
+ library(pkg, character.only = TRUE)
+ }
+ }
+ invisible(x)
+}
+
+specific_model <- function(x) {
+ cls <- class(x)
+ cls[cls != "cluster_spec"]
+}
+
+possible_engines <- function(object, ...) {
+ m_env <- get_model_env_tidyclust()
+ engs <- rlang::env_get(m_env, specific_model(object))
+ unique(engs$engine)
+}
+
+shhhh <- function(x) {
+ suppressPackageStartupMessages(requireNamespace(x, quietly = TRUE))
+}
+
+is_installed <- function(pkg) {
+ res <- try(shhhh(pkg), silent = TRUE)
+ res
+}
+
+check_installs <- function(x) {
+ if (length(x$method$libs) > 0) {
+ is_inst <- map_lgl(x$method$libs, is_installed)
+ if (any(!is_inst)) {
+ missing_pkg <- x$method$libs[!is_inst]
+ missing_pkg <- paste0(missing_pkg, collapse = ", ")
+ rlang::abort(
+ glue::glue(
+ "This engine requires some package installs: ",
+ glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ")
+ )
+ )
+ }
+ }
+}
diff --git a/R/extract_assignment.R b/R/extract_assignment.R
index 5050532e..7f0655c1 100644
--- a/R/extract_assignment.R
+++ b/R/extract_assignment.R
@@ -1,61 +1,61 @@
-#' Extract cluster assignments from model
-#'
-#' @param object An cluster_spec object.
-#' @param ... Other arguments passed to methods.
-#'
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' kmeans_fit %>%
-#' extract_cluster_assignment()
-#' @export
-extract_cluster_assignment <- function(object, ...) {
- UseMethod("extract_cluster_assignment")
-}
-
-#' @export
-extract_cluster_assignment.cluster_fit <- function(object, ...) {
- extract_cluster_assignment(object$fit, ...)
-}
-
-#' @export
-extract_cluster_assignment.workflow <- function(object, ...) {
- extract_cluster_assignment(object$fit$fit$fit)
-}
-
-#' @export
-extract_cluster_assignment.kmeans <- function(object, ...) {
- cluster_assignment_tibble(object$cluster, length(object$size))
-}
-
-#' @export
-extract_cluster_assignment.KMeansCluster <- function(object, ...) {
- cluster_assignment_tibble(object$clusters, length(object$obs_per_cluster))
-}
-
-#' @export
-extract_cluster_assignment.hclust <- function(object, ...) {
-
- # if k or h is passed in the dots, use those. Otherwise, use attributes
- # from original model specification
- args <- list(...)
- if (!("k" %in% names(args) | "cut_height" %in% names(args))) {
- k <- attr(object, "k")
- cut_height <- attr(object, "cut_height")
- }
- clusters <- stats::cutree(object, k, h = cut_height)
- cluster_assignment_tibble(clusters, length(unique(clusters)))
-}
-
-# ------------------------------------------------------------------------------
-
-cluster_assignment_tibble <- function(clusters, n_clusters) {
- reorder_clusts <- order(unique(clusters))
- names <- paste0("Cluster_", 1:n_clusters)
- res <- names[reorder_clusts][clusters]
-
- tibble::tibble(.cluster = factor(res))
-}
+#' Extract cluster assignments from model
+#'
+#' @param object An cluster_spec object.
+#' @param ... Other arguments passed to methods.
+#'
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' kmeans_fit %>%
+#' extract_cluster_assignment()
+#' @export
+extract_cluster_assignment <- function(object, ...) {
+ UseMethod("extract_cluster_assignment")
+}
+
+#' @export
+extract_cluster_assignment.cluster_fit <- function(object, ...) {
+ extract_cluster_assignment(object$fit, ...)
+}
+
+#' @export
+extract_cluster_assignment.workflow <- function(object, ...) {
+ extract_cluster_assignment(object$fit$fit$fit)
+}
+
+#' @export
+extract_cluster_assignment.kmeans <- function(object, ...) {
+ cluster_assignment_tibble(object$cluster, length(object$size))
+}
+
+#' @export
+extract_cluster_assignment.KMeansCluster <- function(object, ...) {
+ cluster_assignment_tibble(object$clusters, length(object$obs_per_cluster))
+}
+
+#' @export
+extract_cluster_assignment.hclust <- function(object, ...) {
+
+ # if k or h is passed in the dots, use those. Otherwise, use attributes
+ # from original model specification
+ args <- list(...)
+ if (!("k" %in% names(args) | "cut_height" %in% names(args))) {
+ k <- attr(object, "k")
+ cut_height <- attr(object, "cut_height")
+ }
+ clusters <- stats::cutree(object, k, h = cut_height)
+ cluster_assignment_tibble(clusters, length(unique(clusters)))
+}
+
+# ------------------------------------------------------------------------------
+
+cluster_assignment_tibble <- function(clusters, n_clusters) {
+ reorder_clusts <- order(unique(clusters))
+ names <- paste0("Cluster_", 1:n_clusters)
+ res <- names[reorder_clusts][clusters]
+
+ tibble::tibble(.cluster = factor(res))
+}
diff --git a/R/extract_characterization.R b/R/extract_characterization.R
index 20a216f6..b4c61477 100644
--- a/R/extract_characterization.R
+++ b/R/extract_characterization.R
@@ -1,20 +1,20 @@
-#' Extract clusters from model
-#'
-#' @param object An cluster_spec object.
-#' @param ... Other arguments passed to methods.
-#'
-#' @examples
-#' set.seed(1234)
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' kmeans_fit %>%
-#' extract_centroids()
-#' @export
-extract_centroids <- function(object, ...) {
- summ <- extract_fit_summary(object)
- clusters <- tibble::tibble(.cluster = summ$cluster_names)
- bind_cols(clusters, summ$centroids)
-}
+#' Extract clusters from model
+#'
+#' @param object An cluster_spec object.
+#' @param ... Other arguments passed to methods.
+#'
+#' @examples
+#' set.seed(1234)
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' kmeans_fit %>%
+#' extract_centroids()
+#' @export
+extract_centroids <- function(object, ...) {
+ summ <- extract_fit_summary(object)
+ clusters <- tibble::tibble(.cluster = summ$cluster_names)
+ bind_cols(clusters, summ$centroids)
+}
diff --git a/R/extract_summary.R b/R/extract_summary.R
index 96d2b70a..feaac179 100644
--- a/R/extract_summary.R
+++ b/R/extract_summary.R
@@ -1,98 +1,98 @@
-
-#' S3 method to get fitted model summary info depending on engine
-#'
-#' @param object a fitted cluster_spec object
-#' @param ... other arguments passed to methods
-#'
-#' @return A list with various summary elements
-#'
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' kmeans_fit %>%
-#' extract_fit_summary()
-#' @export
-extract_fit_summary <- function(object, ...) {
- UseMethod("extract_fit_summary")
-}
-
-#' @export
-extract_fit_summary.cluster_fit <- function(object, ...) {
- extract_fit_summary(object$fit)
-}
-
-#' @export
-extract_fit_summary.workflow <- function(object, ...) {
- extract_fit_summary(object$fit$fit$fit)
-}
-
-#' @export
-extract_fit_summary.kmeans <- function(object, ...) {
- reorder_clusts <- order(unique(object$cluster))
- names <- paste0("Cluster_", seq_len(nrow(object$centers)))
-
- list(
- cluster_names = names,
- centroids = tibble::as_tibble(object$centers[reorder_clusts, , drop = FALSE]),
- n_members = object$size[reorder_clusts],
- within_sse = object$withinss[reorder_clusts],
- tot_sse = object$totss,
- orig_labels = unname(object$cluster),
- cluster_assignments = names[reorder_clusts][object$cluster]
- )
-}
-
-#' @export
-extract_fit_summary.KMeansCluster <- function(object, ...) {
- reorder_clusts <- order(unique(object$cluster))
- names <- paste0("Cluster_", seq_len(nrow(object$centroids)))
-
- list(
- cluster_names = names,
- centroids = tibble::as_tibble(object$centroids[reorder_clusts, , drop = FALSE]),
- n_members = object$obs_per_cluster[reorder_clusts],
- within_sse = object$WCSS_per_cluster[reorder_clusts],
- tot_sse = object$total_SSE,
- orig_labels = object$clusters,
- cluster_assignments = names[reorder_clusts][object$clusters]
- )
-}
-
-#' @export
-extract_fit_summary.hclust <- function(object, ...) {
-
- clusts <- extract_cluster_assignment(object, ...)$.cluster
- n_clust <- dplyr::n_distinct(clusts)
-
- training_data <- attr(object, "training_data")
-
- overall_centroid <- colMeans(training_data)
-
- by_clust <- training_data %>%
- tibble::as_tibble() %>%
- dplyr::mutate(
- .cluster = clusts
- ) %>%
- dplyr::group_by(.cluster) %>%
- tidyr::nest()
-
- centroids <- by_clust$data %>%
- purrr::map_dfr(~ .x %>% dplyr::summarize_all(mean))
-
- within_sse <- by_clust$data %>%
- purrr::map2_dbl(1:n_clust,
- ~ sum(Rfast::dista(centroids[.y,], .x)))
-
- list(
- cluster_names = unique(clusts),
- centroids = centroids,
- n_members = unname(table(clusts)),
- within_sse = within_sse,
- tot_sse = sum(Rfast::dista(t(overall_centroid), training_data)),
- orig_labels = NULL,
- cluster_assignments = clusts
- )
-}
+
+#' S3 method to get fitted model summary info depending on engine
+#'
+#' @param object a fitted cluster_spec object
+#' @param ... other arguments passed to methods
+#'
+#' @return A list with various summary elements
+#'
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' kmeans_fit %>%
+#' extract_fit_summary()
+#' @export
+extract_fit_summary <- function(object, ...) {
+ UseMethod("extract_fit_summary")
+}
+
+#' @export
+extract_fit_summary.cluster_fit <- function(object, ...) {
+ extract_fit_summary(object$fit)
+}
+
+#' @export
+extract_fit_summary.workflow <- function(object, ...) {
+ extract_fit_summary(object$fit$fit$fit)
+}
+
+#' @export
+extract_fit_summary.kmeans <- function(object, ...) {
+ reorder_clusts <- order(unique(object$cluster))
+ names <- paste0("Cluster_", seq_len(nrow(object$centers)))
+
+ list(
+ cluster_names = names,
+ centroids = tibble::as_tibble(object$centers[reorder_clusts, , drop = FALSE]),
+ n_members = object$size[reorder_clusts],
+ within_sse = object$withinss[reorder_clusts],
+ tot_sse = object$totss,
+ orig_labels = unname(object$cluster),
+ cluster_assignments = names[reorder_clusts][object$cluster]
+ )
+}
+
+#' @export
+extract_fit_summary.KMeansCluster <- function(object, ...) {
+ reorder_clusts <- order(unique(object$cluster))
+ names <- paste0("Cluster_", seq_len(nrow(object$centroids)))
+
+ list(
+ cluster_names = names,
+ centroids = tibble::as_tibble(object$centroids[reorder_clusts, , drop = FALSE]),
+ n_members = object$obs_per_cluster[reorder_clusts],
+ within_sse = object$WCSS_per_cluster[reorder_clusts],
+ tot_sse = object$total_SSE,
+ orig_labels = object$clusters,
+ cluster_assignments = names[reorder_clusts][object$clusters]
+ )
+}
+
+#' @export
+extract_fit_summary.hclust <- function(object, ...) {
+
+ clusts <- extract_cluster_assignment(object, ...)$.cluster
+ n_clust <- dplyr::n_distinct(clusts)
+
+ training_data <- attr(object, "training_data")
+
+ overall_centroid <- colMeans(training_data)
+
+ by_clust <- training_data %>%
+ tibble::as_tibble() %>%
+ dplyr::mutate(
+ .cluster = clusts
+ ) %>%
+ dplyr::group_by(.cluster) %>%
+ tidyr::nest()
+
+ centroids <- by_clust$data %>%
+ purrr::map_dfr(~ .x %>% dplyr::summarize_all(mean))
+
+ within_sse <- by_clust$data %>%
+ purrr::map2_dbl(1:n_clust,
+ ~ sum(Rfast::dista(centroids[.y,], .x)))
+
+ list(
+ cluster_names = unique(clusts),
+ centroids = centroids,
+ n_members = unname(table(clusts)),
+ within_sse = within_sse,
+ tot_sse = sum(Rfast::dista(t(overall_centroid), training_data)),
+ orig_labels = NULL,
+ cluster_assignments = clusts
+ )
+}
diff --git a/R/finalize.R b/R/finalize.R
index 8da34a4e..12d279f7 100644
--- a/R/finalize.R
+++ b/R/finalize.R
@@ -1,62 +1,62 @@
-#' Splice final parameters into objects
-#'
-#' The `finalize_*` functions take a list or tibble of tuning parameter values and
-#' update objects with those values.
-#'
-#' @param x A recipe, `parsnip` model specification, or workflow.
-#' @param parameters A list or 1-row tibble of parameter values. Note that the
-#' column names of the tibble should be the `id` fields attached to `tune()`.
-#' For example, in the `Examples` section below, the model has `tune("K")`. In
-#' this case, the parameter tibble should be "K" and not "neighbors".
-#' @return An updated version of `x`.
-#' @export
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = tune())
-#'
-#' best_params <- data.frame(num_clusters = 5)
-#' best_params
-#'
-#' kmeans_spec
-#' finalize_model_tidyclust(kmeans_spec, best_params)
-finalize_model_tidyclust <- function(x, parameters) {
- if (!inherits(x, "cluster_spec")) {
- rlang::abort("`x` should be a tidyclust model specification.")
- }
- parsnip::check_final_param(parameters)
- pset <- hardhat::extract_parameter_set_dials(x)
- if (tibble::is_tibble(parameters)) {
- parameters <- as.list(parameters)
- }
-
- parameters <- parameters[names(parameters) %in% pset$id]
-
- discordant <- dplyr::filter(pset, id != name & id %in% names(parameters))
- if (nrow(discordant) > 0) {
- for (i in 1:nrow(discordant)) {
- names(parameters)[names(parameters) == discordant$id[i]] <-
- discordant$name[i]
- }
- }
- rlang::exec(stats::update, object = x, !!!parameters)
-}
-
-#' @rdname finalize_model_tidyclust
-#' @export
-finalize_workflow_tidyclust <- function(x, parameters) {
- if (!inherits(x, "workflow")) {
- rlang::abort("`x` should be a workflow")
- }
- parsnip::check_final_param(parameters)
-
- mod <- extract_spec_parsnip(x)
- mod <- finalize_model_tidyclust(mod, parameters)
- x <- set_workflow_spec(x, mod)
-
- if (has_preprocessor_recipe(x)) {
- rec <- extract_preprocessor(x)
- rec <- tune::finalize_recipe(rec, parameters)
- x <- set_workflow_recipe(x, rec)
- }
-
- x
-}
+#' Splice final parameters into objects
+#'
+#' The `finalize_*` functions take a list or tibble of tuning parameter values and
+#' update objects with those values.
+#'
+#' @param x A recipe, `parsnip` model specification, or workflow.
+#' @param parameters A list or 1-row tibble of parameter values. Note that the
+#' column names of the tibble should be the `id` fields attached to `tune()`.
+#' For example, in the `Examples` section below, the model has `tune("K")`. In
+#' this case, the parameter tibble should be "K" and not "neighbors".
+#' @return An updated version of `x`.
+#' @export
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = tune())
+#'
+#' best_params <- data.frame(num_clusters = 5)
+#' best_params
+#'
+#' kmeans_spec
+#' finalize_model_tidyclust(kmeans_spec, best_params)
+finalize_model_tidyclust <- function(x, parameters) {
+ if (!inherits(x, "cluster_spec")) {
+ rlang::abort("`x` should be a tidyclust model specification.")
+ }
+ parsnip::check_final_param(parameters)
+ pset <- hardhat::extract_parameter_set_dials(x)
+ if (tibble::is_tibble(parameters)) {
+ parameters <- as.list(parameters)
+ }
+
+ parameters <- parameters[names(parameters) %in% pset$id]
+
+ discordant <- dplyr::filter(pset, id != name & id %in% names(parameters))
+ if (nrow(discordant) > 0) {
+ for (i in 1:nrow(discordant)) {
+ names(parameters)[names(parameters) == discordant$id[i]] <-
+ discordant$name[i]
+ }
+ }
+ rlang::exec(stats::update, object = x, !!!parameters)
+}
+
+#' @rdname finalize_model_tidyclust
+#' @export
+finalize_workflow_tidyclust <- function(x, parameters) {
+ if (!inherits(x, "workflow")) {
+ rlang::abort("`x` should be a workflow")
+ }
+ parsnip::check_final_param(parameters)
+
+ mod <- extract_spec_parsnip(x)
+ mod <- finalize_model_tidyclust(mod, parameters)
+ x <- set_workflow_spec(x, mod)
+
+ if (has_preprocessor_recipe(x)) {
+ rec <- extract_preprocessor(x)
+ rec <- tune::finalize_recipe(rec, parameters)
+ x <- set_workflow_recipe(x, rec)
+ }
+
+ x
+}
diff --git a/R/fit.R b/R/fit.R
index f31f69f0..c7a71492 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -1,334 +1,334 @@
-#' Fit a Model Specification to a Data Set
-#'
-#' `fit()` and `fit_xy()` take a model specification, translate_tidyclust the
-#' required code by substituting arguments, and execute the model fit routine.
-#'
-#' @param object An object of class `cluster_spec` that has a chosen engine (via
-#' [set_engine()]).
-#' @param formula An object of class `formula` (or one that can be coerced to
-#' that class): a symbolic description of the model to be fitted.
-#' @param data Optional, depending on the interface (see Details below). A data
-#' frame containing all relevant variables (e.g. predictors, case weights,
-#' etc). Note: when needed, a \emph{named argument} should be used.
-#' @param control A named list with elements `verbosity` and `catch`. See
-#' [control_cluster()].
-#' @param ... Not currently used; values passed here will be ignored. Other
-#' options required to fit the model should be passed using
-#' `set_engine()`.
-#' @details `fit()` and `fit_xy()` substitute the current arguments in the
-#' model specification into the computational engine's code, check them for
-#' validity, then fit the model using the data and the engine-specific code.
-#' Different model functions have different interfaces (e.g. formula or
-#' `x`/`y`) and these functions translate_tidyclust between the interface used
-#' when `fit()` or `fit_xy()` was invoked and the one required by the
-#' underlying model.
-#'
-#' When possible, these functions attempt to avoid making copies of the data.
-#' For example, if the underlying model uses a formula and `fit()` is invoked,
-#' the original data are references when the model is fit. However, if the
-#' underlying model uses something else, such as `x`/`y`, the formula is
-#' evaluated and the data are converted to the required format. In this case,
-#' any calls in the resulting model objects reference the temporary objects
-#' used to fit the model.
-#'
-#' If the model engine has not been set, the model's default engine will be
-#' used (as discussed on each model page). If the `verbosity` option of
-#' [control_cluster()] is greater than zero, a warning will be produced.
-#'
-#' If you would like to use an alternative method for generating contrasts
-#' when supplying a formula to `fit()`, set the global option `contrasts` to
-#' your preferred method. For example, you might set it to: `options(contrasts
-#' = c(unordered = "contr.helmert", ordered = "contr.poly"))`. See the help
-#' page for [stats::contr.treatment()] for more possible contrast types.
-#' @examples
-#' library(dplyr)
-#'
-#' kmeans_mod <- k_means(num_clusters = 5)
-#'
-#' using_formula <-
-#' kmeans_mod %>%
-#' set_engine("stats") %>%
-#' fit(~., data = mtcars)
-#'
-#' using_x <-
-#' kmeans_mod %>%
-#' set_engine("stats") %>%
-#' fit_xy(x = mtcars)
-#'
-#' using_formula
-#' using_x
-#' @return A `cluster_fit` object that contains several elements:
-#' \itemize{
-#' \item \code{spec}: The model specification object (\code{object} in the
-#' call to \code{fit})
-#' \item \code{fit}: when the model is executed without error, this is the
-#' model object. Otherwise, it is a \code{try-error}
-#' object with the error message.
-#' \item \code{preproc}: any objects needed to convert between a formula and
-#' non-formula interface
-#' (such as the \code{terms} object)
-#' }
-#' The return value will also have a class related to the fitted model (e.g.
-#' `"_kmeans"`) before the base class of `"cluster_fit"`.
-#'
-#' @seealso [set_engine()], [control_cluster()], `cluster_spec`,
-#' `cluster_fit`
-#' @param x A matrix, sparse matrix, or data frame of predictors. Only some
-#' models have support for sparse matrix input. See
-#' `tidyclust::get_encoding_tidyclust()` for details. `x` should have column names.
-#' @param case_weights An optional classed vector of numeric case weights. This
-#' must return `TRUE` when [hardhat::is_case_weights()] is run on it. See
-#' [hardhat::frequency_weights()] and [hardhat::importance_weights()] for
-#' examples.
-#' @rdname fit
-#' @export
-#' @export fit.cluster_spec
-fit.cluster_spec <- function(object,
- formula,
- data,
- control = control_cluster(),
- ...) {
- if (object$mode == "unknown") {
- rlang::abort("Please set the mode in the model specification.")
- }
- # if (!inherits(control, "control_cluster")) {
- # rlang::abort("The 'control' argument should have class 'control_cluster'.")
- # }
- dots <- quos(...)
- if (is.null(object$engine)) {
- eng_vals <- possible_engines(object)
- object$engine <- eng_vals[1]
- if (control$verbosity > 0) {
- rlang::warn(glue::glue("Engine set to `{object$engine}`."))
- }
- }
-
- if (all(c("x", "y") %in% names(dots))) {
- rlang::abort("`fit.cluster_spec()` is for the formula methods. Use `fit_xy()` instead.")
- }
- cl <- match.call(expand.dots = TRUE)
- # Create an environment with the evaluated argument objects. This will be
- # used when a model call is made later.
- eval_env <- rlang::env()
-
- eval_env$data <- data
- eval_env$formula <- formula
- fit_interface <-
- check_interface(eval_env$formula, eval_env$data, cl, object)
-
- # populate `method` with the details for this model type
- object <- add_methods(object, engine = object$engine)
-
- check_installs(object)
-
- interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
-
- # Now call the wrappers that transition between the interface
- # called here ("fit" interface) that will direct traffic to
- # what the underlying model uses. For example, if a formula is
- # used here, `fit_interface_formula` will determine if a
- # translation has to be made if the model interface is x/y/
- res <-
- switch(interfaces,
- # homogeneous combinations:
- formula_formula =
- form_form(
- object = object,
- control = control,
- env = eval_env
- ),
-
- # heterogenous combinations
- formula_matrix =
- form_x(
- object = object,
- control = control,
- env = eval_env,
- target = object$method$fit$interface,
- ...
- ),
- formula_data.frame =
- form_x(
- object = object,
- control = control,
- env = eval_env,
- target = object$method$fit$interface,
- ...
- ),
- rlang::abort(glue::glue("{interfaces} is unknown."))
- )
- model_classes <- class(res$fit)
- class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
- res
-}
-
-check_interface <- function(formula, data, cl, model) {
- inher(formula, "formula", cl)
-
- # Determine the `fit()` interface
- form_interface <- !is.null(formula) & !is.null(data)
-
- if (form_interface) {
- return("formula")
- }
- rlang::abort("Error when checking the interface.")
-}
-
-inher <- function(x, cls, cl) {
- if (!is.null(x) && !inherits(x, cls)) {
- call <- match.call()
- obj <- deparse(call[["x"]])
- if (length(cls) > 1) {
- rlang::abort(
- glue::glue(
- "`{obj}` should be one of the following classes: ",
- glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ")
- )
- )
- } else {
- rlang::abort(
- glue::glue("`{obj}` should be a {cls} object")
- )
- }
- }
- invisible(x)
-}
-
-add_methods <- function(x, engine) {
- x$engine <- engine
- check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
- x$method <- get_cluster_spec(specific_model(x), x$mode, x$engine)
- x
-}
-
-# ------------------------------------------------------------------------------
-
-eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) {
- if (capture) {
- if (catch) {
- junk <- capture.output(res <- try(rlang::eval_tidy(e, ...), silent = TRUE))
- } else {
- junk <- capture.output(res <- rlang::eval_tidy(e, ...))
- }
- } else {
- if (catch) {
- res <- try(rlang::eval_tidy(e, ...), silent = TRUE)
- } else {
- res <- rlang::eval_tidy(e, ...)
- }
- }
- res
-}
-
-# ------------------------------------------------------------------------------
-
-#' @rdname fit
-#' @export
-#' @export fit_xy.cluster_spec
-fit_xy.cluster_spec <-
- function(object, x, case_weights = NULL, control = control_cluster(), ...) {
- # if (!inherits(control, "control_cluster")) {
- # rlang::abort("The 'control' argument should have class 'control_cluster'.")
- # }
- if (is.null(colnames(x))) {
- rlang::abort("'x' should have column names.")
- }
-
- dots <- quos(...)
- if (is.null(object$engine)) {
- eng_vals <- possible_engines(object)
- object$engine <- eng_vals[1]
- if (control$verbosity > 0) {
- rlang::warn(glue::glue("Engine set to `{object$engine}`."))
- }
- }
-
- cl <- match.call(expand.dots = TRUE)
- eval_env <- rlang::env()
- eval_env$x <- x
- fit_interface <- check_x_interface(eval_env$x, cl, object)
-
- # populate `method` with the details for this model type
- object <- add_methods(object, engine = object$engine)
-
- check_installs(object)
-
- interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
-
- # Now call the wrappers that transition between the interface
- # called here ("fit" interface) that will direct traffic to
- # what the underlying model uses. For example, if a formula is
- # used here, `fit_interface_formula` will determine if a
- # translation has to be made if the model interface is x/y/
- res <-
- switch(interfaces,
- # homogeneous combinations:
- matrix_matrix = ,
- data.frame_matrix =
- x_x(
- object = object,
- env = eval_env,
- control = control,
- target = "matrix",
- ...
- ),
- data.frame_data.frame = ,
- matrix_data.frame =
- x_x(
- object = object,
- env = eval_env,
- control = control,
- target = "data.frame",
- ...
- ),
-
- # heterogenous combinations
- matrix_formula = ,
- data.frame_formula =
- x_form(
- object = object,
- env = eval_env,
- control = control,
- ...
- ),
- rlang::abort(glue::glue("{interfaces} is unknown."))
- )
- model_classes <- class(res$fit)
- class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
- res
- }
-
-check_x_interface <- function(x, cl, model) {
- sparse_ok <- allow_sparse(model)
- sparse_x <- inherits(x, "dgCMatrix")
- if (!sparse_ok & sparse_x) {
- rlang::abort("Sparse matrices not supported by this model/engine combination.")
- }
-
- if (sparse_ok) {
- inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
- } else {
- inher(x, c("data.frame", "matrix"), cl)
- }
-
- if (sparse_ok) {
- matrix_interface <- !is.null(x) && (is.matrix(x) | sparse_x)
- } else {
- matrix_interface <- !is.null(x) && is.matrix(x)
- }
-
- df_interface <- !is.null(x) && is.data.frame(x)
-
- if (matrix_interface) {
- return("matrix")
- }
- if (df_interface) {
- return("data.frame")
- }
- rlang::abort("Error when checking the interface")
-}
-
-allow_sparse <- function(x) {
- res <- get_from_env_tidyclust(paste0(class(x)[1], "_encoding"))
- all(res$allow_sparse_x[res$engine == x$engine])
-}
+#' Fit a Model Specification to a Data Set
+#'
+#' `fit()` and `fit_xy()` take a model specification, translate_tidyclust the
+#' required code by substituting arguments, and execute the model fit routine.
+#'
+#' @param object An object of class `cluster_spec` that has a chosen engine (via
+#' [set_engine()]).
+#' @param formula An object of class `formula` (or one that can be coerced to
+#' that class): a symbolic description of the model to be fitted.
+#' @param data Optional, depending on the interface (see Details below). A data
+#' frame containing all relevant variables (e.g. predictors, case weights,
+#' etc). Note: when needed, a \emph{named argument} should be used.
+#' @param control A named list with elements `verbosity` and `catch`. See
+#' [control_cluster()].
+#' @param ... Not currently used; values passed here will be ignored. Other
+#' options required to fit the model should be passed using
+#' `set_engine()`.
+#' @details `fit()` and `fit_xy()` substitute the current arguments in the
+#' model specification into the computational engine's code, check them for
+#' validity, then fit the model using the data and the engine-specific code.
+#' Different model functions have different interfaces (e.g. formula or
+#' `x`/`y`) and these functions translate_tidyclust between the interface used
+#' when `fit()` or `fit_xy()` was invoked and the one required by the
+#' underlying model.
+#'
+#' When possible, these functions attempt to avoid making copies of the data.
+#' For example, if the underlying model uses a formula and `fit()` is invoked,
+#' the original data are references when the model is fit. However, if the
+#' underlying model uses something else, such as `x`/`y`, the formula is
+#' evaluated and the data are converted to the required format. In this case,
+#' any calls in the resulting model objects reference the temporary objects
+#' used to fit the model.
+#'
+#' If the model engine has not been set, the model's default engine will be
+#' used (as discussed on each model page). If the `verbosity` option of
+#' [control_cluster()] is greater than zero, a warning will be produced.
+#'
+#' If you would like to use an alternative method for generating contrasts
+#' when supplying a formula to `fit()`, set the global option `contrasts` to
+#' your preferred method. For example, you might set it to: `options(contrasts
+#' = c(unordered = "contr.helmert", ordered = "contr.poly"))`. See the help
+#' page for [stats::contr.treatment()] for more possible contrast types.
+#' @examples
+#' library(dplyr)
+#'
+#' kmeans_mod <- k_means(num_clusters = 5)
+#'
+#' using_formula <-
+#' kmeans_mod %>%
+#' set_engine("stats") %>%
+#' fit(~., data = mtcars)
+#'
+#' using_x <-
+#' kmeans_mod %>%
+#' set_engine("stats") %>%
+#' fit_xy(x = mtcars)
+#'
+#' using_formula
+#' using_x
+#' @return A `cluster_fit` object that contains several elements:
+#' \itemize{
+#' \item \code{spec}: The model specification object (\code{object} in the
+#' call to \code{fit})
+#' \item \code{fit}: when the model is executed without error, this is the
+#' model object. Otherwise, it is a \code{try-error}
+#' object with the error message.
+#' \item \code{preproc}: any objects needed to convert between a formula and
+#' non-formula interface
+#' (such as the \code{terms} object)
+#' }
+#' The return value will also have a class related to the fitted model (e.g.
+#' `"_kmeans"`) before the base class of `"cluster_fit"`.
+#'
+#' @seealso [set_engine()], [control_cluster()], `cluster_spec`,
+#' `cluster_fit`
+#' @param x A matrix, sparse matrix, or data frame of predictors. Only some
+#' models have support for sparse matrix input. See
+#' `tidyclust::get_encoding_tidyclust()` for details. `x` should have column names.
+#' @param case_weights An optional classed vector of numeric case weights. This
+#' must return `TRUE` when [hardhat::is_case_weights()] is run on it. See
+#' [hardhat::frequency_weights()] and [hardhat::importance_weights()] for
+#' examples.
+#' @rdname fit
+#' @export
+#' @export fit.cluster_spec
+fit.cluster_spec <- function(object,
+ formula,
+ data,
+ control = control_cluster(),
+ ...) {
+ if (object$mode == "unknown") {
+ rlang::abort("Please set the mode in the model specification.")
+ }
+ # if (!inherits(control, "control_cluster")) {
+ # rlang::abort("The 'control' argument should have class 'control_cluster'.")
+ # }
+ dots <- quos(...)
+ if (is.null(object$engine)) {
+ eng_vals <- possible_engines(object)
+ object$engine <- eng_vals[1]
+ if (control$verbosity > 0) {
+ rlang::warn(glue::glue("Engine set to `{object$engine}`."))
+ }
+ }
+
+ if (all(c("x", "y") %in% names(dots))) {
+ rlang::abort("`fit.cluster_spec()` is for the formula methods. Use `fit_xy()` instead.")
+ }
+ cl <- match.call(expand.dots = TRUE)
+ # Create an environment with the evaluated argument objects. This will be
+ # used when a model call is made later.
+ eval_env <- rlang::env()
+
+ eval_env$data <- data
+ eval_env$formula <- formula
+ fit_interface <-
+ check_interface(eval_env$formula, eval_env$data, cl, object)
+
+ # populate `method` with the details for this model type
+ object <- add_methods(object, engine = object$engine)
+
+ check_installs(object)
+
+ interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
+
+ # Now call the wrappers that transition between the interface
+ # called here ("fit" interface) that will direct traffic to
+ # what the underlying model uses. For example, if a formula is
+ # used here, `fit_interface_formula` will determine if a
+ # translation has to be made if the model interface is x/y/
+ res <-
+ switch(interfaces,
+ # homogeneous combinations:
+ formula_formula =
+ form_form(
+ object = object,
+ control = control,
+ env = eval_env
+ ),
+
+ # heterogenous combinations
+ formula_matrix =
+ form_x(
+ object = object,
+ control = control,
+ env = eval_env,
+ target = object$method$fit$interface,
+ ...
+ ),
+ formula_data.frame =
+ form_x(
+ object = object,
+ control = control,
+ env = eval_env,
+ target = object$method$fit$interface,
+ ...
+ ),
+ rlang::abort(glue::glue("{interfaces} is unknown."))
+ )
+ model_classes <- class(res$fit)
+ class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
+ res
+}
+
+check_interface <- function(formula, data, cl, model) {
+ inher(formula, "formula", cl)
+
+ # Determine the `fit()` interface
+ form_interface <- !is.null(formula) & !is.null(data)
+
+ if (form_interface) {
+ return("formula")
+ }
+ rlang::abort("Error when checking the interface.")
+}
+
+inher <- function(x, cls, cl) {
+ if (!is.null(x) && !inherits(x, cls)) {
+ call <- match.call()
+ obj <- deparse(call[["x"]])
+ if (length(cls) > 1) {
+ rlang::abort(
+ glue::glue(
+ "`{obj}` should be one of the following classes: ",
+ glue::glue_collapse(glue::glue("'{cls}'"), sep = ", ")
+ )
+ )
+ } else {
+ rlang::abort(
+ glue::glue("`{obj}` should be a {cls} object")
+ )
+ }
+ }
+ invisible(x)
+}
+
+add_methods <- function(x, engine) {
+ x$engine <- engine
+ check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
+ x$method <- get_cluster_spec(specific_model(x), x$mode, x$engine)
+ x
+}
+
+# ------------------------------------------------------------------------------
+
+eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) {
+ if (capture) {
+ if (catch) {
+ junk <- capture.output(res <- try(rlang::eval_tidy(e, ...), silent = TRUE))
+ } else {
+ junk <- capture.output(res <- rlang::eval_tidy(e, ...))
+ }
+ } else {
+ if (catch) {
+ res <- try(rlang::eval_tidy(e, ...), silent = TRUE)
+ } else {
+ res <- rlang::eval_tidy(e, ...)
+ }
+ }
+ res
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname fit
+#' @export
+#' @export fit_xy.cluster_spec
+fit_xy.cluster_spec <-
+ function(object, x, case_weights = NULL, control = control_cluster(), ...) {
+ # if (!inherits(control, "control_cluster")) {
+ # rlang::abort("The 'control' argument should have class 'control_cluster'.")
+ # }
+ if (is.null(colnames(x))) {
+ rlang::abort("'x' should have column names.")
+ }
+
+ dots <- quos(...)
+ if (is.null(object$engine)) {
+ eng_vals <- possible_engines(object)
+ object$engine <- eng_vals[1]
+ if (control$verbosity > 0) {
+ rlang::warn(glue::glue("Engine set to `{object$engine}`."))
+ }
+ }
+
+ cl <- match.call(expand.dots = TRUE)
+ eval_env <- rlang::env()
+ eval_env$x <- x
+ fit_interface <- check_x_interface(eval_env$x, cl, object)
+
+ # populate `method` with the details for this model type
+ object <- add_methods(object, engine = object$engine)
+
+ check_installs(object)
+
+ interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
+
+ # Now call the wrappers that transition between the interface
+ # called here ("fit" interface) that will direct traffic to
+ # what the underlying model uses. For example, if a formula is
+ # used here, `fit_interface_formula` will determine if a
+ # translation has to be made if the model interface is x/y/
+ res <-
+ switch(interfaces,
+ # homogeneous combinations:
+ matrix_matrix = ,
+ data.frame_matrix =
+ x_x(
+ object = object,
+ env = eval_env,
+ control = control,
+ target = "matrix",
+ ...
+ ),
+ data.frame_data.frame = ,
+ matrix_data.frame =
+ x_x(
+ object = object,
+ env = eval_env,
+ control = control,
+ target = "data.frame",
+ ...
+ ),
+
+ # heterogenous combinations
+ matrix_formula = ,
+ data.frame_formula =
+ x_form(
+ object = object,
+ env = eval_env,
+ control = control,
+ ...
+ ),
+ rlang::abort(glue::glue("{interfaces} is unknown."))
+ )
+ model_classes <- class(res$fit)
+ class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
+ res
+ }
+
+check_x_interface <- function(x, cl, model) {
+ sparse_ok <- allow_sparse(model)
+ sparse_x <- inherits(x, "dgCMatrix")
+ if (!sparse_ok & sparse_x) {
+ rlang::abort("Sparse matrices not supported by this model/engine combination.")
+ }
+
+ if (sparse_ok) {
+ inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
+ } else {
+ inher(x, c("data.frame", "matrix"), cl)
+ }
+
+ if (sparse_ok) {
+ matrix_interface <- !is.null(x) && (is.matrix(x) | sparse_x)
+ } else {
+ matrix_interface <- !is.null(x) && is.matrix(x)
+ }
+
+ df_interface <- !is.null(x) && is.data.frame(x)
+
+ if (matrix_interface) {
+ return("matrix")
+ }
+ if (df_interface) {
+ return("data.frame")
+ }
+ rlang::abort("Error when checking the interface")
+}
+
+allow_sparse <- function(x) {
+ res <- get_from_env_tidyclust(paste0(class(x)[1], "_encoding"))
+ all(res$allow_sparse_x[res$engine == x$engine])
+}
diff --git a/R/hier_clust.R b/R/hier_clust.R
index 36554e97..d50b7450 100644
--- a/R/hier_clust.R
+++ b/R/hier_clust.R
@@ -1,96 +1,96 @@
-#' Hierarchical (Agglomerative) Clustering
-#'
-#' @description
-#'
-#' `hier_clust()` defines a model that fits clusters based on a distance-based
-#' dendrogram
-#'
-#' @param mode A single character string for the type of model.
-#' The only possible value for this model is "partition".
-#' @param engine A single character string specifying what computational engine
-#' to use for fitting. Possible engines are listed below. The default for this
-#' model is `"stats"`.
-#' @param k Positive integer, number of clusters in model (optional).
-#' @param h Positive double, height at which to cut dendrogram to obtain cluster
-#' assignments (only used if `k` is `NULL`)
-#' @param linkage_method the agglomeration method to be used. This should be (an
-#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`,
-#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"`
-#' (= WPGMC) or `"centroid"` (= UPGMC).
-#' @param dist_fun A distance function to use
-#'
-#' @examples
-#' # show_engines("hier_clust")
-#'
-#' hier_clust()
-#' @export
-hier_clust <-
- function(mode = "partition",
- engine = "stats",
- k = NULL,
- h = NULL,
- linkage_method = "complete") {
- args <- list(
- k = enquo(k),
- h = enquo(h),
- linkage_method = enquo(linkage_method)
- )
-
- new_cluster_spec(
- "hier_clust",
- args = args,
- eng_args = NULL,
- mode = mode,
- method = NULL,
- engine = engine
- )
- }
-
-#' @export
-print.hier_clust <- function(x, ...) {
- cat("Hierarchical Clustering Specification (", x$mode, ")\n\n", sep = "")
- model_printer(x, ...)
-
- if (!is.null(x$method$fit$args)) {
- cat("Model fit template:\n")
- print(show_call(x))
- }
-
- invisible(x)
-}
-
-#' @export
-translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) {
- x <- translate_tidyclust.default(x, engine, ...)
- x
-}
-
-# ------------------------------------------------------------------------------
-
-#' Simple Wrapper around hclust function
-#'
-#' This wrapper prepares the data into a distance matrix to send to
-#' `stats::hclust` and retains the parameters `k` or `h` as an attribute.
-#'
-#' @param x matrix or data frame
-#' @param k the number of clusters
-#' @param h the height to cut the dendrogram
-#' @param linkage_method the agglomeration method to be used. This should be (an
-#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`,
-#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"`
-#' (= WPGMC) or `"centroid"` (= UPGMC).
-#' @param dist_fun A distance function to use
-#'
-#' @return A dendrogram
-#' @keywords internal
-#' @export
-hclust_fit <- function(x, k = NULL, cut_height = NULL,
- linkage_method = NULL,
- dist_fun = Rfast::Dist) {
- dmat <- dist_fun(x)
- res <- hclust(as.dist(dmat), method = linkage_method)
- attr(res, "k") <- k
- attr(res, "cut_height") <- cut_height
- attr(res, "training_data") <- x
- return(res)
-}
+#' Hierarchical (Agglomerative) Clustering
+#'
+#' @description
+#'
+#' `hier_clust()` defines a model that fits clusters based on a distance-based
+#' dendrogram
+#'
+#' @param mode A single character string for the type of model.
+#' The only possible value for this model is "partition".
+#' @param engine A single character string specifying what computational engine
+#' to use for fitting. Possible engines are listed below. The default for this
+#' model is `"stats"`.
+#' @param k Positive integer, number of clusters in model (optional).
+#' @param h Positive double, height at which to cut dendrogram to obtain cluster
+#' assignments (only used if `k` is `NULL`)
+#' @param linkage_method the agglomeration method to be used. This should be (an
+#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`,
+#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"`
+#' (= WPGMC) or `"centroid"` (= UPGMC).
+#' @param dist_fun A distance function to use
+#'
+#' @examples
+#' # show_engines("hier_clust")
+#'
+#' hier_clust()
+#' @export
+hier_clust <-
+ function(mode = "partition",
+ engine = "stats",
+ k = NULL,
+ h = NULL,
+ linkage_method = "complete") {
+ args <- list(
+ k = enquo(k),
+ h = enquo(h),
+ linkage_method = enquo(linkage_method)
+ )
+
+ new_cluster_spec(
+ "hier_clust",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = engine
+ )
+ }
+
+#' @export
+print.hier_clust <- function(x, ...) {
+ cat("Hierarchical Clustering Specification (", x$mode, ")\n\n", sep = "")
+ model_printer(x, ...)
+
+ if (!is.null(x$method$fit$args)) {
+ cat("Model fit template:\n")
+ print(show_call(x))
+ }
+
+ invisible(x)
+}
+
+#' @export
+translate_tidyclust.hier_clust <- function(x, engine = x$engine, ...) {
+ x <- translate_tidyclust.default(x, engine, ...)
+ x
+}
+
+# ------------------------------------------------------------------------------
+
+#' Simple Wrapper around hclust function
+#'
+#' This wrapper prepares the data into a distance matrix to send to
+#' `stats::hclust` and retains the parameters `k` or `h` as an attribute.
+#'
+#' @param x matrix or data frame
+#' @param k the number of clusters
+#' @param h the height to cut the dendrogram
+#' @param linkage_method the agglomeration method to be used. This should be (an
+#' unambiguous abbreviation of) one of `"ward.D"`, `"ward.D2"`, `"single"`,
+#' `"complete"`, `"average"` (= UPGMA), `"mcquitty"` (= WPGMA), `"median"`
+#' (= WPGMC) or `"centroid"` (= UPGMC).
+#' @param dist_fun A distance function to use
+#'
+#' @return A dendrogram
+#' @keywords internal
+#' @export
+hclust_fit <- function(x, k = NULL, cut_height = NULL,
+ linkage_method = NULL,
+ dist_fun = Rfast::Dist) {
+ dmat <- dist_fun(x)
+ res <- hclust(as.dist(dmat), method = linkage_method)
+ attr(res, "k") <- k
+ attr(res, "cut_height") <- cut_height
+ attr(res, "training_data") <- x
+ return(res)
+}
diff --git a/R/hier_clust_data.R b/R/hier_clust_data.R
index edc93559..b4cf14e1 100644
--- a/R/hier_clust_data.R
+++ b/R/hier_clust_data.R
@@ -1,132 +1,132 @@
-set_new_model_tidyclust("hier_clust")
-
-set_model_mode_tidyclust("hier_clust", "partition")
-
-# ------------------------------------------------------------------------------
-
-set_model_engine_tidyclust("hier_clust", "partition", "stats")
-set_dependency_tidyclust("hier_clust", "stats", "stats")
-
-set_fit_tidyclust(
- model = "hier_clust",
- eng = "stats",
- mode = "partition",
- value = list(
- interface = "matrix",
- protect = c("data"),
- func = c(pkg = "tidyclust", fun = "hclust_fit"),
- defaults = list()
- )
-)
-
-set_encoding_tidyclust(
- model = "hier_clust",
- eng = "stats",
- mode = "partition",
- options = list(
- predictor_indicators = "traditional",
- compute_intercept = TRUE,
- remove_intercept = TRUE,
- allow_sparse_x = FALSE
- )
-)
-
-set_model_arg_tidyclust(
- model = "hier_clust",
- eng = "stats",
- tidyclust = "k",
- original = "k",
- func = list(pkg = "tidyclust", fun = "k"),
- has_submodel = TRUE
-)
-
-set_model_arg_tidyclust(
- model = "hier_clust",
- eng = "stats",
- tidyclust = "linkage_method",
- original = "linkage_method",
- func = list(pkg = "tidyclust", fun = "linkage_method"),
- has_submodel = TRUE
-)
-
-set_model_arg_tidyclust(
- model = "hier_clust",
- eng = "stats",
- tidyclust = "cut_height",
- original = "cut_height",
- func = list(pkg = "tidyclust", fun = "cut_height"),
- has_submodel = TRUE
-)
-
-set_pred_tidyclust(
- model = "hier_clust",
- eng = "stats",
- mode = "partition",
- type = "cluster",
- value = list(
- pre = NULL,
- post = NULL,
- func = c(fun = "stats_hier_clust_predict"),
- args =
- list(
- object = rlang::expr(object$fit),
- new_data = rlang::expr(new_data)
- )
- )
-)
-
-# ------------------------------------------------------------------------------
-#
-# set_model_engine_tidyclust("k_means", "partition", "ClusterR")
-# set_dependency_tidyclust("k_means", "ClusterR", "ClusterR")
-#
-# set_fit_tidyclust(
-# model = "k_means",
-# eng = "ClusterR",
-# mode = "partition",
-# value = list(
-# interface = "matrix",
-# data = c(x = "data"),
-# protect = c("data", "clusters"),
-# func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"),
-# defaults = list()
-# )
-# )
-#
-# set_encoding_tidyclust(
-# model = "k_means",
-# eng = "ClusterR",
-# mode = "partition",
-# options = list(
-# predictor_indicators = "traditional",
-# compute_intercept = TRUE,
-# remove_intercept = TRUE,
-# allow_sparse_x = FALSE
-# )
-# )
-#
-# set_model_arg_tidyclust(
-# model = "k_means",
-# eng = "ClusterR",
-# tidyclust = "k",
-# original = "clusters",
-# func = list(pkg = "dials", fun = "k"),
-# has_submodel = TRUE
-# )
-#
-# set_pred_tidyclust(
-# model = "k_means",
-# eng = "ClusterR",
-# mode = "partition",
-# type = "cluster",
-# value = list(
-# pre = NULL,
-# post = NULL,
-# func = c(fun = "clusterR_kmeans_predict"),
-# args =
-# list(
-# object = rlang::expr(object$fit),
-# new_data = rlang::expr(new_data)
-# )
-# )
-# )
+set_new_model_tidyclust("hier_clust")
+
+set_model_mode_tidyclust("hier_clust", "partition")
+
+# ------------------------------------------------------------------------------
+
+set_model_engine_tidyclust("hier_clust", "partition", "stats")
+set_dependency_tidyclust("hier_clust", "stats", "stats")
+
+set_fit_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ mode = "partition",
+ value = list(
+ interface = "matrix",
+ protect = c("data"),
+ func = c(pkg = "tidyclust", fun = "hclust_fit"),
+ defaults = list()
+ )
+)
+
+set_encoding_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ mode = "partition",
+ options = list(
+ predictor_indicators = "traditional",
+ compute_intercept = TRUE,
+ remove_intercept = TRUE,
+ allow_sparse_x = FALSE
+ )
+)
+
+set_model_arg_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ tidyclust = "k",
+ original = "k",
+ func = list(pkg = "tidyclust", fun = "k"),
+ has_submodel = TRUE
+)
+
+set_model_arg_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ tidyclust = "linkage_method",
+ original = "linkage_method",
+ func = list(pkg = "tidyclust", fun = "linkage_method"),
+ has_submodel = TRUE
+)
+
+set_model_arg_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ tidyclust = "cut_height",
+ original = "cut_height",
+ func = list(pkg = "tidyclust", fun = "cut_height"),
+ has_submodel = TRUE
+)
+
+set_pred_tidyclust(
+ model = "hier_clust",
+ eng = "stats",
+ mode = "partition",
+ type = "cluster",
+ value = list(
+ pre = NULL,
+ post = NULL,
+ func = c(fun = "stats_hier_clust_predict"),
+ args =
+ list(
+ object = rlang::expr(object$fit),
+ new_data = rlang::expr(new_data)
+ )
+ )
+)
+
+# ------------------------------------------------------------------------------
+#
+# set_model_engine_tidyclust("k_means", "partition", "ClusterR")
+# set_dependency_tidyclust("k_means", "ClusterR", "ClusterR")
+#
+# set_fit_tidyclust(
+# model = "k_means",
+# eng = "ClusterR",
+# mode = "partition",
+# value = list(
+# interface = "matrix",
+# data = c(x = "data"),
+# protect = c("data", "clusters"),
+# func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"),
+# defaults = list()
+# )
+# )
+#
+# set_encoding_tidyclust(
+# model = "k_means",
+# eng = "ClusterR",
+# mode = "partition",
+# options = list(
+# predictor_indicators = "traditional",
+# compute_intercept = TRUE,
+# remove_intercept = TRUE,
+# allow_sparse_x = FALSE
+# )
+# )
+#
+# set_model_arg_tidyclust(
+# model = "k_means",
+# eng = "ClusterR",
+# tidyclust = "k",
+# original = "clusters",
+# func = list(pkg = "dials", fun = "k"),
+# has_submodel = TRUE
+# )
+#
+# set_pred_tidyclust(
+# model = "k_means",
+# eng = "ClusterR",
+# mode = "partition",
+# type = "cluster",
+# value = list(
+# pre = NULL,
+# post = NULL,
+# func = c(fun = "clusterR_kmeans_predict"),
+# args =
+# list(
+# object = rlang::expr(object$fit),
+# new_data = rlang::expr(new_data)
+# )
+# )
+# )
diff --git a/R/k_means.R b/R/k_means.R
index abdc2261..5b37ea22 100644
--- a/R/k_means.R
+++ b/R/k_means.R
@@ -1,153 +1,153 @@
-#' K-Means
-#'
-#' @description
-#'
-#' `k_means()` defines a model that fits clusters based on distances to a number
-#' of centers.
-#'
-#' @param mode A single character string for the type of model.
-#' The only possible value for this model is "partition".
-#' @param engine A single character string specifying what computational engine
-#' to use for fitting. Possible engines are listed below. The default for this
-#' model is `"stats"`.
-#' @param num_clusters Positive integer, number of clusters in model.
-#'
-#' @examples
-#' # show_engines("k_means")
-#'
-#' k_means()
-#' @export
-k_means <-
- function(mode = "partition",
- engine = "stats",
- num_clusters = NULL) {
- args <- list(
- num_clusters = enquo(num_clusters)
- )
-
- new_cluster_spec(
- "k_means",
- args = args,
- eng_args = NULL,
- mode = mode,
- method = NULL,
- engine = engine
- )
- }
-
-#' @export
-print.k_means <- function(x, ...) {
- cat("K Means Cluster Specification (", x$mode, ")\n\n", sep = "")
- model_printer(x, ...)
-
- if (!is.null(x$method$fit$args)) {
- cat("Model fit template:\n")
- print(show_call(x))
- }
-
- invisible(x)
-}
-
-#' @export
-translate_tidyclust.k_means <- function(x, engine = x$engine, ...) {
- x <- translate_tidyclust.default(x, engine, ...)
- x
-}
-
-# ------------------------------------------------------------------------------
-
-#' @method update k_means
-#' @rdname tidyclust_update
-#' @export
-update.k_means <- function(object,
- parameters = NULL,
- num_clusters = NULL,
- fresh = FALSE, ...) {
-
- eng_args <- parsnip::update_engine_parameters(object$eng_args, ...)
-
- if (!is.null(parameters)) {
- parameters <- parsnip::check_final_param(parameters)
- }
- args <- list(
- num_clusters = enquo(num_clusters)
- )
-
- args <- parsnip::update_main_parameters(args, parameters)
-
- if (fresh) {
- object$args <- args
- object$eng_args <- eng_args
- } else {
- null_args <- map_lgl(args, null_value)
- if (any(null_args))
- args <- args[!null_args]
- if (length(args) > 0)
- object$args[names(args)] <- args
- if (length(eng_args) > 0)
- object$eng_args[names(eng_args)] <- eng_args
- }
-
- new_cluster_spec(
- "k_means",
- args = object$args,
- eng_args = object$eng_args,
- mode = object$mode,
- method = NULL,
- engine = object$engine
- )
-}
-
-# # ------------------------------------------------------------------------------
-
-check_args.k_means <- function(object) {
-
- args <- lapply(object$args, rlang::eval_tidy)
-
- if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0))
- rlang::abort("The number of centers should be >= 0.")
-
- invisible(object)
-}
-
-# ------------------------------------------------------------------------------
-
-#' Simple Wrapper around ClusterR kmeans
-#'
-#' This wrapper runs `ClusterR::KMeans_rcpp` and adds column names to the
-#' `centroids` field.
-#'
-#' @param data matrix or data frame
-#' @param clusters the number of clusters
-#' @param num_init number of times the algorithm will be run with different
-#' centroid seeds
-#' @param max_iters the maximum number of clustering iterations
-#' @param initializer the method of initialization. One of, optimal_init,
-#' quantile_init, kmeans++ and random. See details for more information
-#' @param fuzzy either TRUE or FALSE. If TRUE, then prediction probabilities
-#' will be calculated using the distance between observations and centroids
-#' @param verbose either TRUE or FALSE, indicating whether progress is printed
-#' during clustering.
-#' @param CENTROIDS a matrix of initial cluster centroids. The rows of the
-#' CENTROIDS matrix should be equal to the number of clusters and the columns
-#' should be equal to the columns of the data.
-#' @param tol a float number. If, in case of an iteration (iteration > 1 and
-#' iteration < max_iters) 'tol' is greater than the squared norm of the
-#' centroids, then kmeans has converged
-#' @param tol_optimal_init tolerance value for the 'optimal_init' initializer.
-#' The higher this value is, the far appart from each other the centroids are.
-#' @param seed integer value for random number generator (RNG)
-#'
-#' @return a list with the following attributes: clusters, fuzzy_clusters (if
-#' fuzzy = TRUE), centroids, total_SSE, best_initialization, WCSS_per_cluster,
-#' obs_per_cluster, between.SS_DIV_total.SS
-#' @keywords internal
-#' @export
-ClusterR_kmeans_fit <- function(data, clusters, num_init = 1, max_iters = 100,
- initializer = "kmeans++", fuzzy = FALSE,
- verbose = FALSE, CENTROIDS = NULL, tol = 1e-04,
- tol_optimal_init = 0.3, seed = 1) {
- res <- ClusterR::KMeans_rcpp(data, clusters)
- colnames(res$centroids) <- colnames(data)
- res
-}
+#' K-Means
+#'
+#' @description
+#'
+#' `k_means()` defines a model that fits clusters based on distances to a number
+#' of centers.
+#'
+#' @param mode A single character string for the type of model.
+#' The only possible value for this model is "partition".
+#' @param engine A single character string specifying what computational engine
+#' to use for fitting. Possible engines are listed below. The default for this
+#' model is `"stats"`.
+#' @param num_clusters Positive integer, number of clusters in model.
+#'
+#' @examples
+#' # show_engines("k_means")
+#'
+#' k_means()
+#' @export
+k_means <-
+ function(mode = "partition",
+ engine = "stats",
+ num_clusters = NULL) {
+ args <- list(
+ num_clusters = enquo(num_clusters)
+ )
+
+ new_cluster_spec(
+ "k_means",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = engine
+ )
+ }
+
+#' @export
+print.k_means <- function(x, ...) {
+ cat("K Means Cluster Specification (", x$mode, ")\n\n", sep = "")
+ model_printer(x, ...)
+
+ if (!is.null(x$method$fit$args)) {
+ cat("Model fit template:\n")
+ print(show_call(x))
+ }
+
+ invisible(x)
+}
+
+#' @export
+translate_tidyclust.k_means <- function(x, engine = x$engine, ...) {
+ x <- translate_tidyclust.default(x, engine, ...)
+ x
+}
+
+# ------------------------------------------------------------------------------
+
+#' @method update k_means
+#' @rdname tidyclust_update
+#' @export
+update.k_means <- function(object,
+ parameters = NULL,
+ num_clusters = NULL,
+ fresh = FALSE, ...) {
+
+ eng_args <- parsnip::update_engine_parameters(object$eng_args, ...)
+
+ if (!is.null(parameters)) {
+ parameters <- parsnip::check_final_param(parameters)
+ }
+ args <- list(
+ num_clusters = enquo(num_clusters)
+ )
+
+ args <- parsnip::update_main_parameters(args, parameters)
+
+ if (fresh) {
+ object$args <- args
+ object$eng_args <- eng_args
+ } else {
+ null_args <- map_lgl(args, null_value)
+ if (any(null_args))
+ args <- args[!null_args]
+ if (length(args) > 0)
+ object$args[names(args)] <- args
+ if (length(eng_args) > 0)
+ object$eng_args[names(eng_args)] <- eng_args
+ }
+
+ new_cluster_spec(
+ "k_means",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
+}
+
+# # ------------------------------------------------------------------------------
+
+check_args.k_means <- function(object) {
+
+ args <- lapply(object$args, rlang::eval_tidy)
+
+ if (all(is.numeric(args$num_clusters)) && any(args$num_clusters < 0))
+ rlang::abort("The number of centers should be >= 0.")
+
+ invisible(object)
+}
+
+# ------------------------------------------------------------------------------
+
+#' Simple Wrapper around ClusterR kmeans
+#'
+#' This wrapper runs `ClusterR::KMeans_rcpp` and adds column names to the
+#' `centroids` field.
+#'
+#' @param data matrix or data frame
+#' @param clusters the number of clusters
+#' @param num_init number of times the algorithm will be run with different
+#' centroid seeds
+#' @param max_iters the maximum number of clustering iterations
+#' @param initializer the method of initialization. One of, optimal_init,
+#' quantile_init, kmeans++ and random. See details for more information
+#' @param fuzzy either TRUE or FALSE. If TRUE, then prediction probabilities
+#' will be calculated using the distance between observations and centroids
+#' @param verbose either TRUE or FALSE, indicating whether progress is printed
+#' during clustering.
+#' @param CENTROIDS a matrix of initial cluster centroids. The rows of the
+#' CENTROIDS matrix should be equal to the number of clusters and the columns
+#' should be equal to the columns of the data.
+#' @param tol a float number. If, in case of an iteration (iteration > 1 and
+#' iteration < max_iters) 'tol' is greater than the squared norm of the
+#' centroids, then kmeans has converged
+#' @param tol_optimal_init tolerance value for the 'optimal_init' initializer.
+#' The higher this value is, the far appart from each other the centroids are.
+#' @param seed integer value for random number generator (RNG)
+#'
+#' @return a list with the following attributes: clusters, fuzzy_clusters (if
+#' fuzzy = TRUE), centroids, total_SSE, best_initialization, WCSS_per_cluster,
+#' obs_per_cluster, between.SS_DIV_total.SS
+#' @keywords internal
+#' @export
+ClusterR_kmeans_fit <- function(data, clusters, num_init = 1, max_iters = 100,
+ initializer = "kmeans++", fuzzy = FALSE,
+ verbose = FALSE, CENTROIDS = NULL, tol = 1e-04,
+ tol_optimal_init = 0.3, seed = 1) {
+ res <- ClusterR::KMeans_rcpp(data, clusters)
+ colnames(res$centroids) <- colnames(data)
+ res
+}
diff --git a/R/k_means_data.R b/R/k_means_data.R
index 4b43d1c5..35b2238c 100644
--- a/R/k_means_data.R
+++ b/R/k_means_data.R
@@ -1,114 +1,114 @@
-set_new_model_tidyclust("k_means")
-
-set_model_mode_tidyclust("k_means", "partition")
-
-# ------------------------------------------------------------------------------
-
-set_model_engine_tidyclust("k_means", "partition", "stats")
-set_dependency_tidyclust("k_means", "stats", "stats")
-
-set_fit_tidyclust(
- model = "k_means",
- eng = "stats",
- mode = "partition",
- value = list(
- interface = "matrix",
- protect = c("x", "centers"),
- func = c(pkg = "stats", fun = "kmeans"),
- defaults = list()
- )
-)
-
-set_encoding_tidyclust(
- model = "k_means",
- eng = "stats",
- mode = "partition",
- options = list(
- predictor_indicators = "traditional",
- compute_intercept = TRUE,
- remove_intercept = TRUE,
- allow_sparse_x = FALSE
- )
-)
-
-set_model_arg_tidyclust(
- model = "k_means",
- eng = "stats",
- tidyclust = "num_clusters",
- original = "centers",
- func = list(pkg = "tidyclust", fun = "num_clusters"),
- has_submodel = TRUE
-)
-
-set_pred_tidyclust(
- model = "k_means",
- eng = "stats",
- mode = "partition",
- type = "cluster",
- value = list(
- pre = NULL,
- post = NULL,
- func = c(fun = "stats_kmeans_predict"),
- args =
- list(
- object = rlang::expr(object$fit),
- new_data = rlang::expr(new_data)
- )
- )
-)
-
-# ------------------------------------------------------------------------------
-
-set_model_engine_tidyclust("k_means", "partition", "ClusterR")
-set_dependency_tidyclust("k_means", "ClusterR", "ClusterR")
-
-set_fit_tidyclust(
- model = "k_means",
- eng = "ClusterR",
- mode = "partition",
- value = list(
- interface = "matrix",
- data = c(x = "data"),
- protect = c("data", "clusters"),
- func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"),
- defaults = list()
- )
-)
-
-set_encoding_tidyclust(
- model = "k_means",
- eng = "ClusterR",
- mode = "partition",
- options = list(
- predictor_indicators = "traditional",
- compute_intercept = TRUE,
- remove_intercept = TRUE,
- allow_sparse_x = FALSE
- )
-)
-
-set_model_arg_tidyclust(
- model = "k_means",
- eng = "ClusterR",
- tidyclust = "num_clusters",
- original = "clusters",
- func = list(pkg = "tidyclust", fun = "num_clusters"),
- has_submodel = TRUE
-)
-
-set_pred_tidyclust(
- model = "k_means",
- eng = "ClusterR",
- mode = "partition",
- type = "cluster",
- value = list(
- pre = NULL,
- post = NULL,
- func = c(fun = "clusterR_kmeans_predict"),
- args =
- list(
- object = rlang::expr(object$fit),
- new_data = rlang::expr(new_data)
- )
- )
-)
+set_new_model_tidyclust("k_means")
+
+set_model_mode_tidyclust("k_means", "partition")
+
+# ------------------------------------------------------------------------------
+
+set_model_engine_tidyclust("k_means", "partition", "stats")
+set_dependency_tidyclust("k_means", "stats", "stats")
+
+set_fit_tidyclust(
+ model = "k_means",
+ eng = "stats",
+ mode = "partition",
+ value = list(
+ interface = "matrix",
+ protect = c("x", "centers"),
+ func = c(pkg = "stats", fun = "kmeans"),
+ defaults = list()
+ )
+)
+
+set_encoding_tidyclust(
+ model = "k_means",
+ eng = "stats",
+ mode = "partition",
+ options = list(
+ predictor_indicators = "traditional",
+ compute_intercept = TRUE,
+ remove_intercept = TRUE,
+ allow_sparse_x = FALSE
+ )
+)
+
+set_model_arg_tidyclust(
+ model = "k_means",
+ eng = "stats",
+ tidyclust = "num_clusters",
+ original = "centers",
+ func = list(pkg = "tidyclust", fun = "num_clusters"),
+ has_submodel = TRUE
+)
+
+set_pred_tidyclust(
+ model = "k_means",
+ eng = "stats",
+ mode = "partition",
+ type = "cluster",
+ value = list(
+ pre = NULL,
+ post = NULL,
+ func = c(fun = "stats_kmeans_predict"),
+ args =
+ list(
+ object = rlang::expr(object$fit),
+ new_data = rlang::expr(new_data)
+ )
+ )
+)
+
+# ------------------------------------------------------------------------------
+
+set_model_engine_tidyclust("k_means", "partition", "ClusterR")
+set_dependency_tidyclust("k_means", "ClusterR", "ClusterR")
+
+set_fit_tidyclust(
+ model = "k_means",
+ eng = "ClusterR",
+ mode = "partition",
+ value = list(
+ interface = "matrix",
+ data = c(x = "data"),
+ protect = c("data", "clusters"),
+ func = c(pkg = "tidyclust", fun = "ClusterR_kmeans_fit"),
+ defaults = list()
+ )
+)
+
+set_encoding_tidyclust(
+ model = "k_means",
+ eng = "ClusterR",
+ mode = "partition",
+ options = list(
+ predictor_indicators = "traditional",
+ compute_intercept = TRUE,
+ remove_intercept = TRUE,
+ allow_sparse_x = FALSE
+ )
+)
+
+set_model_arg_tidyclust(
+ model = "k_means",
+ eng = "ClusterR",
+ tidyclust = "num_clusters",
+ original = "clusters",
+ func = list(pkg = "tidyclust", fun = "num_clusters"),
+ has_submodel = TRUE
+)
+
+set_pred_tidyclust(
+ model = "k_means",
+ eng = "ClusterR",
+ mode = "partition",
+ type = "cluster",
+ value = list(
+ pre = NULL,
+ post = NULL,
+ func = c(fun = "clusterR_kmeans_predict"),
+ args =
+ list(
+ object = rlang::expr(object$fit),
+ new_data = rlang::expr(new_data)
+ )
+ )
+)
diff --git a/R/metric-silhouette.R b/R/metric-silhouette.R
index 6a78c4dd..ef138210 100644
--- a/R/metric-silhouette.R
+++ b/R/metric-silhouette.R
@@ -1,107 +1,107 @@
-#' Measures silhouettes between clusters
-#'
-#' @param object A fitted tidyclust model
-#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering.
-#' @param dists A distance matrix. Used if `new_data` is `NULL`.
-#' @param dist_fun A function for calculating distances between observations.
-#' Defaults to Euclidean distance on processed data.
-#'
-#' @return A tibble giving the silhouettes for each observation.
-#'
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' dists <- mtcars %>%
-#' as.matrix() %>%
-#' dist()
-#'
-#' silhouettes(kmeans_fit, dists = dists)
-#' @export
-silhouettes <- function(object, new_data = NULL, dists = NULL,
- dist_fun = Rfast::Dist) {
- preproc <- prep_data_dist(object, new_data, dists, dist_fun)
-
- clust_int <- as.integer(gsub("Cluster_", "", preproc$clusters))
-
- sil <- cluster::silhouette(clust_int, preproc$dists)
-
- sil %>%
- unclass() %>%
- tibble::as_tibble() %>%
- dplyr::mutate(
- cluster = factor(paste0("Cluster_", cluster)),
- neighbor = factor(paste0("Cluster_", neighbor)),
- sil_width = as.numeric(sil_width)
- )
-}
-
-#' Measures average silhouette across all observations
-#'
-#' @param object A fitted kmeans tidyclust model
-#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering.
-#' @param dists A distance matrix. Used if `new_data` is `NULL`.
-#' @param dist_fun A function for calculating distances between observations.
-#' Defaults to Euclidean distance on processed data.
-#' @param ... Other arguments passed to methods.
-#'
-#' @return A double; the average silhouette.
-#'
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' dists <- mtcars %>%
-#' as.matrix() %>%
-#' dist()
-#'
-#' avg_silhouette(kmeans_fit, dists = dists)
-#'
-#' avg_silhouette_vec(kmeans_fit, dists = dists)
-#' @export
-avg_silhouette <- function(object, ...) {
- UseMethod("avg_silhouette")
-}
-
-avg_silhouette <- new_cluster_metric(
- avg_silhouette,
- direction = "zero"
-)
-
-#' @export
-#' @rdname avg_silhouette
-avg_silhouette.cluster_fit <- function(object, new_data = NULL, dists = NULL,
- dist_fun = NULL, ...) {
- if (is.null(dist_fun)) {
- dist_fun <- Rfast::Dist
- }
-
- res <- avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
-
- tibble::tibble(
- .metric = "avg_silhouette",
- .estimator = "standard",
- .estimate = res
- )
-}
-
-#' @export
-#' @rdname avg_silhouette
-avg_silhouette.workflow <- avg_silhouette.cluster_fit
-
-#' @export
-#' @rdname avg_silhouette
-avg_silhouette_vec <- function(object, new_data = NULL, dists = NULL,
- dist_fun = Rfast::Dist, ...) {
- avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
-
-}
-
-avg_silhouette_impl <- function(object, new_data = NULL, dists = NULL,
- dist_fun = Rfast::Dist, ...) {
- mean(silhouettes(object, new_data, dists, dist_fun, ...)$sil_width)
-}
+#' Measures silhouettes between clusters
+#'
+#' @param object A fitted tidyclust model
+#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering.
+#' @param dists A distance matrix. Used if `new_data` is `NULL`.
+#' @param dist_fun A function for calculating distances between observations.
+#' Defaults to Euclidean distance on processed data.
+#'
+#' @return A tibble giving the silhouettes for each observation.
+#'
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' dists <- mtcars %>%
+#' as.matrix() %>%
+#' dist()
+#'
+#' silhouettes(kmeans_fit, dists = dists)
+#' @export
+silhouettes <- function(object, new_data = NULL, dists = NULL,
+ dist_fun = Rfast::Dist) {
+ preproc <- prep_data_dist(object, new_data, dists, dist_fun)
+
+ clust_int <- as.integer(gsub("Cluster_", "", preproc$clusters))
+
+ sil <- cluster::silhouette(clust_int, preproc$dists)
+
+ sil %>%
+ unclass() %>%
+ tibble::as_tibble() %>%
+ dplyr::mutate(
+ cluster = factor(paste0("Cluster_", cluster)),
+ neighbor = factor(paste0("Cluster_", neighbor)),
+ sil_width = as.numeric(sil_width)
+ )
+}
+
+#' Measures average silhouette across all observations
+#'
+#' @param object A fitted kmeans tidyclust model
+#' @param new_data A dataset to predict on. If `NULL`, uses trained clustering.
+#' @param dists A distance matrix. Used if `new_data` is `NULL`.
+#' @param dist_fun A function for calculating distances between observations.
+#' Defaults to Euclidean distance on processed data.
+#' @param ... Other arguments passed to methods.
+#'
+#' @return A double; the average silhouette.
+#'
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' dists <- mtcars %>%
+#' as.matrix() %>%
+#' dist()
+#'
+#' avg_silhouette(kmeans_fit, dists = dists)
+#'
+#' avg_silhouette_vec(kmeans_fit, dists = dists)
+#' @export
+avg_silhouette <- function(object, ...) {
+ UseMethod("avg_silhouette")
+}
+
+avg_silhouette <- new_cluster_metric(
+ avg_silhouette,
+ direction = "zero"
+)
+
+#' @export
+#' @rdname avg_silhouette
+avg_silhouette.cluster_fit <- function(object, new_data = NULL, dists = NULL,
+ dist_fun = NULL, ...) {
+ if (is.null(dist_fun)) {
+ dist_fun <- Rfast::Dist
+ }
+
+ res <- avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
+
+ tibble::tibble(
+ .metric = "avg_silhouette",
+ .estimator = "standard",
+ .estimate = res
+ )
+}
+
+#' @export
+#' @rdname avg_silhouette
+avg_silhouette.workflow <- avg_silhouette.cluster_fit
+
+#' @export
+#' @rdname avg_silhouette
+avg_silhouette_vec <- function(object, new_data = NULL, dists = NULL,
+ dist_fun = Rfast::Dist, ...) {
+ avg_silhouette_impl(object, new_data, dists, dist_fun, ...)
+
+}
+
+avg_silhouette_impl <- function(object, new_data = NULL, dists = NULL,
+ dist_fun = Rfast::Dist, ...) {
+ mean(silhouettes(object, new_data, dists, dist_fun, ...)$sil_width)
+}
diff --git a/R/predict.R b/R/predict.R
index a8cf691d..c6c898ce 100644
--- a/R/predict.R
+++ b/R/predict.R
@@ -1,144 +1,144 @@
-#' Model predictions
-#'
-#' Apply a model to create different types of predictions.
-#' `predict()` can be used for all types of models and uses the
-#' "type" argument for more specificity.
-#'
-#' @param object An object of class `cluster_fit`
-#' @param new_data A rectangular data object, such as a data frame.
-#' @param type A single character value or `NULL`. Possible values
-#' are "cluster", or "raw". When `NULL`, `predict()` will choose an
-#' appropriate value based on the model's mode.
-#' @param opts A list of optional arguments to the underlying
-#' predict function that will be used when `type = "raw"`. The
-#' list should not include options for the model object or the
-#' new data being predicted.
-#' @param ... Arguments to the underlying model's prediction
-#' function cannot be passed here (see `opts`).
-#' @details If "type" is not supplied to `predict()`, then a choice
-#' is made:
-#'
-#' * `type = "cluster"` for clustering models
-#'
-#' `predict()` is designed to provide a tidy result (see "Value"
-#' section below) in a tibble output format.
-#'
-#' @return With the exception of `type = "raw"`, the results of
-#' `predict.cluster_fit()` will be a tibble as many rows in the output
-#' as there are rows in `new_data` and the column names will be
-#' predictable.
-#'
-#' For clustering results the tibble will have a `.pred_cluster` column.
-#'
-#' Using `type = "raw"` with `predict.cluster_fit()` will return
-#' the unadulterated results of the prediction function.
-#'
-#' When the model fit failed and the error was captured, the
-#' `predict()` function will return the same structure as above but
-#' filled with missing values. This does not currently work for
-#' multivariate models.
-#'
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5) %>%
-#' set_engine("stats")
-#'
-#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-#'
-#' kmeans_fit %>%
-#' predict(new_data = mtcars)
-#' @method predict cluster_fit
-#' @export predict.cluster_fit
-#' @export
-predict.cluster_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
- if (inherits(object$fit, "try-error")) {
- rlang::warn("Model fit failed; cannot make predictions.")
- return(NULL)
- }
-
- check_installs(object$spec)
- load_libs(object$spec, quiet = TRUE)
-
- type <- check_pred_type(object, type)
-
- res <- switch(type,
- cluster = predict_cluster(object = object, new_data = new_data, ...),
- raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
-
- rlang::abort(glue::glue("I don't know about type = '{type}'"))
- )
-
- res <- switch(type,
- cluster = format_cluster(res),
- res
- )
- res
-}
-
-
-check_pred_type <- function(object, type, ...) {
- if (is.null(type)) {
- type <-
- switch(object$spec$mode,
- partition = "cluster",
- rlang::abort("`type` should be 'cluster'.")
- )
- }
- if (!(type %in% pred_types)) {
- rlang::abort(
- glue::glue(
- "`type` should be one of: ",
- glue::glue_collapse(pred_types, sep = ", ", last = " and ")
- )
- )
- }
- type
-}
-
-format_cluster <- function(x) {
- tibble::tibble(.pred_cluster = unname(x))
-}
-
-#' Prepare data based on parsnip encoding information
-#' @param object A parsnip model object
-#' @param new_data A data frame
-#' @return A data frame or matrix
-#' @keywords internal
-#' @export
-prepare_data <- function(object, new_data) {
- fit_interface <- object$spec$method$fit$interface
-
- pp_names <- names(object$preproc)
- if (any(pp_names == "terms") | any(pp_names == "x_var")) {
- # Translation code
- if (fit_interface == "formula") {
- new_data <- .convert_x_to_form_new(object$preproc, new_data)
- } else {
- new_data <- .convert_form_to_x_new(object$preproc, new_data)$x
- }
- }
-
- remove_intercept <-
- get_encoding_tidyclust(class(object$spec)[1]) %>%
- dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>%
- dplyr::pull(remove_intercept)
- if (remove_intercept & any(grepl("Intercept", names(new_data)))) {
- new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)"))
- }
-
- switch(fit_interface,
- none = new_data,
- data.frame = as.data.frame(new_data),
- matrix = as.matrix(new_data),
- new_data
- )
-}
-
-make_pred_call <- function(x) {
- if ("pkg" %in% names(x$func)) {
- cl <- rlang::call2(x$func["fun"], !!!x$args, .ns = x$func["pkg"])
- } else {
- cl <- rlang::call2(x$func["fun"], !!!x$args)
- }
-
- cl
-}
+#' Model predictions
+#'
+#' Apply a model to create different types of predictions.
+#' `predict()` can be used for all types of models and uses the
+#' "type" argument for more specificity.
+#'
+#' @param object An object of class `cluster_fit`
+#' @param new_data A rectangular data object, such as a data frame.
+#' @param type A single character value or `NULL`. Possible values
+#' are "cluster", or "raw". When `NULL`, `predict()` will choose an
+#' appropriate value based on the model's mode.
+#' @param opts A list of optional arguments to the underlying
+#' predict function that will be used when `type = "raw"`. The
+#' list should not include options for the model object or the
+#' new data being predicted.
+#' @param ... Arguments to the underlying model's prediction
+#' function cannot be passed here (see `opts`).
+#' @details If "type" is not supplied to `predict()`, then a choice
+#' is made:
+#'
+#' * `type = "cluster"` for clustering models
+#'
+#' `predict()` is designed to provide a tidy result (see "Value"
+#' section below) in a tibble output format.
+#'
+#' @return With the exception of `type = "raw"`, the results of
+#' `predict.cluster_fit()` will be a tibble as many rows in the output
+#' as there are rows in `new_data` and the column names will be
+#' predictable.
+#'
+#' For clustering results the tibble will have a `.pred_cluster` column.
+#'
+#' Using `type = "raw"` with `predict.cluster_fit()` will return
+#' the unadulterated results of the prediction function.
+#'
+#' When the model fit failed and the error was captured, the
+#' `predict()` function will return the same structure as above but
+#' filled with missing values. This does not currently work for
+#' multivariate models.
+#'
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5) %>%
+#' set_engine("stats")
+#'
+#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+#'
+#' kmeans_fit %>%
+#' predict(new_data = mtcars)
+#' @method predict cluster_fit
+#' @export predict.cluster_fit
+#' @export
+predict.cluster_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
+ if (inherits(object$fit, "try-error")) {
+ rlang::warn("Model fit failed; cannot make predictions.")
+ return(NULL)
+ }
+
+ check_installs(object$spec)
+ load_libs(object$spec, quiet = TRUE)
+
+ type <- check_pred_type(object, type)
+
+ res <- switch(type,
+ cluster = predict_cluster(object = object, new_data = new_data, ...),
+ raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
+
+ rlang::abort(glue::glue("I don't know about type = '{type}'"))
+ )
+
+ res <- switch(type,
+ cluster = format_cluster(res),
+ res
+ )
+ res
+}
+
+
+check_pred_type <- function(object, type, ...) {
+ if (is.null(type)) {
+ type <-
+ switch(object$spec$mode,
+ partition = "cluster",
+ rlang::abort("`type` should be 'cluster'.")
+ )
+ }
+ if (!(type %in% pred_types)) {
+ rlang::abort(
+ glue::glue(
+ "`type` should be one of: ",
+ glue::glue_collapse(pred_types, sep = ", ", last = " and ")
+ )
+ )
+ }
+ type
+}
+
+format_cluster <- function(x) {
+ tibble::tibble(.pred_cluster = unname(x))
+}
+
+#' Prepare data based on parsnip encoding information
+#' @param object A parsnip model object
+#' @param new_data A data frame
+#' @return A data frame or matrix
+#' @keywords internal
+#' @export
+prepare_data <- function(object, new_data) {
+ fit_interface <- object$spec$method$fit$interface
+
+ pp_names <- names(object$preproc)
+ if (any(pp_names == "terms") | any(pp_names == "x_var")) {
+ # Translation code
+ if (fit_interface == "formula") {
+ new_data <- .convert_x_to_form_new(object$preproc, new_data)
+ } else {
+ new_data <- .convert_form_to_x_new(object$preproc, new_data)$x
+ }
+ }
+
+ remove_intercept <-
+ get_encoding_tidyclust(class(object$spec)[1]) %>%
+ dplyr::filter(mode == object$spec$mode, engine == object$spec$engine) %>%
+ dplyr::pull(remove_intercept)
+ if (remove_intercept & any(grepl("Intercept", names(new_data)))) {
+ new_data <- new_data %>% dplyr::select(-dplyr::one_of("(Intercept)"))
+ }
+
+ switch(fit_interface,
+ none = new_data,
+ data.frame = as.data.frame(new_data),
+ matrix = as.matrix(new_data),
+ new_data
+ )
+}
+
+make_pred_call <- function(x) {
+ if ("pkg" %in% names(x$func)) {
+ cl <- rlang::call2(x$func["fun"], !!!x$args, .ns = x$func["pkg"])
+ } else {
+ cl <- rlang::call2(x$func["fun"], !!!x$args)
+ }
+
+ cl
+}
diff --git a/R/predict_helpers.R b/R/predict_helpers.R
index 55a00f8b..3e667bb7 100644
--- a/R/predict_helpers.R
+++ b/R/predict_helpers.R
@@ -1,101 +1,101 @@
-stats_kmeans_predict <- function(object, new_data) {
- reorder_clusts <- unique(object$cluster)
- res <- apply(flexclust::dist2(object$centers[reorder_clusts, , drop = FALSE], new_data), 2, which.min)
- res <- paste0("Cluster_", res)
- factor(res)
-}
-
-clusterR_kmeans_predict <- function(object, new_data) {
- reorder_clusts <- unique(object$clusters)
- res <- apply(flexclust::dist2(object$centroids[reorder_clusts, , drop = FALSE], new_data), 2, which.min)
- res <- paste0("Cluster_", res)
- factor(res)
-}
-
-stats_hier_clust_predict <- function(object, new_data) {
-
- linkage_method <- object$method
-
- new_data <- as.matrix(new_data)
-
- training_data <- as.matrix(attr(object, "training_data"))
- clusters <- extract_cluster_assignment(object)
-
- if (linkage_method %in% c("single", "complete", "average", "median")) {
-
- ## complete, single, average, and median linkage_methods are basically the same idea,
- ## just different summary distance to cluster
-
- cluster_dist_fun <- switch(linkage_method,
- "single" = min,
- "complete" = max,
- "average" = mean,
- "median" = median
- )
-
- # need this to be obs on rows, dist to new data on cols
- dists_new <- Rfast::dista(new_data, training_data, trans = TRUE)
-
- cluster_dists <- bind_cols(data.frame(dists_new), clusters) %>%
- group_by(.cluster) %>%
- summarize_all(cluster_dist_fun)
-
- pred_clusts_num <- cluster_dists %>%
- select(-.cluster) %>%
- purrr::map_dbl(which.min)
-
- } else if (linkage_method == "centroid") {
-
- ## Centroid linkage_method, dist to center
-
- cluster_centers <- extract_centroids(object) %>% select(-.cluster)
- dists_means <- Rfast::dista(new_data, cluster_centers)
-
- pred_clusts_num <- apply(dists_means, 1, which.min)
-
- } else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) {
-
- ## Ward linkage_method: lowest change in ESS
- ## dendrograms created from already-squared distances
- ## use Ward.D2 on these plain distances for Ward.D
-
- cluster_centers <- extract_centroids(object)
- n_clust <- nrow(cluster_centers)
- cluster_names <- cluster_centers[[1]]
- cluster_centers <- as.matrix(cluster_centers[, -1])
-
- d_means <- purrr::map(1:n_clust,
- ~t(t(training_data[clusters$.cluster == cluster_names[.x],]) - cluster_centers[.x, ]))
-
- n <- nrow(training_data)
-
- d_new_list <- purrr::map(1:nrow(new_data),
- function(new_obs) {
- purrr::map(1:n_clust,
- ~ t(t(training_data[clusters$.cluster == cluster_names[.x],])
- - new_data[new_obs,])
- )
- }
- )
-
- change_in_ess <- purrr::map(d_new_list,
- function(v) {
- purrr::map2_dbl(d_means, v,
- ~ sum((n*.x + .y)^2/(n+1)^2 - .x^2)
- )}
- )
-
- pred_clusts_num <- purrr::map_dbl(change_in_ess, which.min)
-
- } else {
-
- stop(glue::glue("linkage_method {linkage_method} is not supported for prediction."))
-
- }
-
-
- pred_clusts <- unique(clusters$.cluster)[pred_clusts_num]
-
- return(factor(pred_clusts))
-
-}
+stats_kmeans_predict <- function(object, new_data) {
+ reorder_clusts <- unique(object$cluster)
+ res <- apply(flexclust::dist2(object$centers[reorder_clusts, , drop = FALSE], new_data), 2, which.min)
+ res <- paste0("Cluster_", res)
+ factor(res)
+}
+
+clusterR_kmeans_predict <- function(object, new_data) {
+ reorder_clusts <- unique(object$clusters)
+ res <- apply(flexclust::dist2(object$centroids[reorder_clusts, , drop = FALSE], new_data), 2, which.min)
+ res <- paste0("Cluster_", res)
+ factor(res)
+}
+
+stats_hier_clust_predict <- function(object, new_data) {
+
+ linkage_method <- object$method
+
+ new_data <- as.matrix(new_data)
+
+ training_data <- as.matrix(attr(object, "training_data"))
+ clusters <- extract_cluster_assignment(object)
+
+ if (linkage_method %in% c("single", "complete", "average", "median")) {
+
+ ## complete, single, average, and median linkage_methods are basically the same idea,
+ ## just different summary distance to cluster
+
+ cluster_dist_fun <- switch(linkage_method,
+ "single" = min,
+ "complete" = max,
+ "average" = mean,
+ "median" = median
+ )
+
+ # need this to be obs on rows, dist to new data on cols
+ dists_new <- Rfast::dista(new_data, training_data, trans = TRUE)
+
+ cluster_dists <- bind_cols(data.frame(dists_new), clusters) %>%
+ group_by(.cluster) %>%
+ summarize_all(cluster_dist_fun)
+
+ pred_clusts_num <- cluster_dists %>%
+ select(-.cluster) %>%
+ purrr::map_dbl(which.min)
+
+ } else if (linkage_method == "centroid") {
+
+ ## Centroid linkage_method, dist to center
+
+ cluster_centers <- extract_centroids(object) %>% select(-.cluster)
+ dists_means <- Rfast::dista(new_data, cluster_centers)
+
+ pred_clusts_num <- apply(dists_means, 1, which.min)
+
+ } else if (linkage_method %in% c("ward.D", "ward", "ward.D2")) {
+
+ ## Ward linkage_method: lowest change in ESS
+ ## dendrograms created from already-squared distances
+ ## use Ward.D2 on these plain distances for Ward.D
+
+ cluster_centers <- extract_centroids(object)
+ n_clust <- nrow(cluster_centers)
+ cluster_names <- cluster_centers[[1]]
+ cluster_centers <- as.matrix(cluster_centers[, -1])
+
+ d_means <- purrr::map(1:n_clust,
+ ~t(t(training_data[clusters$.cluster == cluster_names[.x],]) - cluster_centers[.x, ]))
+
+ n <- nrow(training_data)
+
+ d_new_list <- purrr::map(1:nrow(new_data),
+ function(new_obs) {
+ purrr::map(1:n_clust,
+ ~ t(t(training_data[clusters$.cluster == cluster_names[.x],])
+ - new_data[new_obs,])
+ )
+ }
+ )
+
+ change_in_ess <- purrr::map(d_new_list,
+ function(v) {
+ purrr::map2_dbl(d_means, v,
+ ~ sum((n*.x + .y)^2/(n+1)^2 - .x^2)
+ )}
+ )
+
+ pred_clusts_num <- purrr::map_dbl(change_in_ess, which.min)
+
+ } else {
+
+ stop(glue::glue("linkage_method {linkage_method} is not supported for prediction."))
+
+ }
+
+
+ pred_clusts <- unique(clusters$.cluster)[pred_clusts_num]
+
+ return(factor(pred_clusts))
+
+}
diff --git a/R/reexports.R b/R/reexports.R
index 6360ac59..356c98ff 100644
--- a/R/reexports.R
+++ b/R/reexports.R
@@ -1,72 +1,72 @@
-#' @importFrom magrittr %>%
-#' @export
-magrittr::`%>%`
-
-#' @importFrom generics fit
-#' @export
-generics::fit
-
-#' @importFrom generics tidy
-#' @export
-generics::tidy
-
-#' @importFrom generics glance
-#' @export
-generics::glance
-
-#' @importFrom generics augment
-#' @export
-generics::augment
-
-#' @importFrom generics fit_xy
-#' @export
-generics::fit_xy
-
-#' @importFrom hardhat extract_parameter_set_dials
-#' @export
-hardhat::extract_parameter_set_dials
-
-#' @importFrom hardhat tune
-#' @export
-hardhat::tune
-
-#' @importFrom hardhat extract_spec_parsnip
-#' @export
-hardhat::extract_spec_parsnip
-
-#' @importFrom generics min_grid
-#' @export
-generics::min_grid
-
-#' @importFrom hardhat extract_preprocessor
-#' @export
-hardhat::extract_preprocessor
-
-#' @importFrom hardhat extract_fit_parsnip
-#' @export
-hardhat::extract_fit_parsnip
-
-#' @importFrom tune load_pkgs
-#' @export
-tune::load_pkgs
-
-#' @importFrom generics required_pkgs
-#' @export
-generics::required_pkgs
-
-#' @importFrom parsnip predict_raw
-#' @export
-parsnip::predict_raw
-
-#' @importFrom parsnip set_args
-#' @export
-parsnip::set_args
-
-#' @importFrom parsnip set_engine
-#' @export
-parsnip::set_engine
-
-#' @importFrom parsnip set_mode
-#' @export
-parsnip::set_mode
-
+#' @importFrom magrittr %>%
+#' @export
+magrittr::`%>%`
+
+#' @importFrom generics fit
+#' @export
+generics::fit
+
+#' @importFrom generics tidy
+#' @export
+generics::tidy
+
+#' @importFrom generics glance
+#' @export
+generics::glance
+
+#' @importFrom generics augment
+#' @export
+generics::augment
+
+#' @importFrom generics fit_xy
+#' @export
+generics::fit_xy
+
+#' @importFrom hardhat extract_parameter_set_dials
+#' @export
+hardhat::extract_parameter_set_dials
+
+#' @importFrom hardhat tune
+#' @export
+hardhat::tune
+
+#' @importFrom hardhat extract_spec_parsnip
+#' @export
+hardhat::extract_spec_parsnip
+
+#' @importFrom generics min_grid
+#' @export
+generics::min_grid
+
+#' @importFrom hardhat extract_preprocessor
+#' @export
+hardhat::extract_preprocessor
+
+#' @importFrom hardhat extract_fit_parsnip
+#' @export
+hardhat::extract_fit_parsnip
+
+#' @importFrom tune load_pkgs
+#' @export
+tune::load_pkgs
+
+#' @importFrom generics required_pkgs
+#' @export
+generics::required_pkgs
+
+#' @importFrom parsnip predict_raw
+#' @export
+parsnip::predict_raw
+
+#' @importFrom parsnip set_args
+#' @export
+parsnip::set_args
+
+#' @importFrom parsnip set_engine
+#' @export
+parsnip::set_engine
+
+#' @importFrom parsnip set_mode
+#' @export
+parsnip::set_mode
+
diff --git a/R/translate.R b/R/translate.R
index 6a534a64..2c4d6598 100644
--- a/R/translate.R
+++ b/R/translate.R
@@ -1,151 +1,151 @@
-#' Resolve a Model Specification for a Computational Engine
-#'
-#' `translate_tidyclust()` will translate_tidyclust a model specification into a code
-#' object that is specific to a particular engine (e.g. R package).
-#' It translate_tidyclusts generic parameters to their counterparts.
-#'
-#' @param x A model specification.
-#' @param engine The computational engine for the model (see `?set_engine`).
-#' @param ... Not currently used.
-#' @details
-#' `translate_tidyclust()` produces a _template_ call that lacks the specific
-#' argument values (such as `data`, etc). These are filled in once
-#' `fit()` is called with the specifics of the data for the model.
-#' The call may also include `tune()` arguments if these are in
-#' the specification. To handle the `tune()` arguments, you need to use the
-#' [tune package](https://tune.tidymodels.org/). For more information
-#' see
-#'
-#' It does contain the resolved argument names that are specific to
-#' the model fitting function/engine.
-#'
-#' This function can be useful when you need to understand how
-#' `tidyclust` goes from a generic model specific to a model fitting
-#' function.
-#'
-#' **Note**: this function is used internally and users should only use it
-#' to understand what the underlying syntax would be. It should not be used
-#' to modify the cluster specification.
-#'
-#' @export
-translate_tidyclust <- function(x, ...) {
- UseMethod("translate_tidyclust")
-}
-
-#' @rdname translate_tidyclust
-#' @export
-#' @export translate_tidyclust.default
-translate_tidyclust.default <- function(x, engine = x$engine, ...) {
- check_empty_ellipse_tidyclust(...)
- if (is.null(engine)) {
- rlang::abort("Please set an engine.")
- }
-
- mod_name <- specific_model(x)
-
- x$engine <- engine
- if (x$mode == "unknown") {
- rlang::abort("Model code depends on the mode; please specify one.")
- }
-
- check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
-
- if (is.null(x$method)) {
- x$method <- get_cluster_spec(mod_name, x$mode, engine)
- }
-
- arg_key <- get_args(mod_name, engine)
-
- # deharmonize primary arguments
- actual_args <- deharmonize(x$args, arg_key)
-
- # check secondary arguments to see if they are in the final
- # expression unless there are dots, warn if protected args are
- # being altered
- x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original)
-
- # keep only modified args
- modifed_args <- !map_lgl(actual_args, null_value)
- actual_args <- actual_args[modifed_args]
-
- # look for defaults if not modified in other
- if (length(x$method$fit$defaults) > 0) {
- in_other <- names(x$method$fit$defaults) %in% names(x$eng_args)
- x$defaults <- x$method$fit$defaults[!in_other]
- }
-
- # combine primary, eng_args, and defaults
- protected <- lapply(x$method$fit$protect, function(x) rlang::expr(missing_arg()))
- names(protected) <- x$method$fit$protect
-
- x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults)
-
- x
-}
-
-# ------------------------------------------------------------------------------
-# new code for revised model data structures
-
-get_cluster_spec <- function(model, mode, engine) {
- m_env <- get_model_env_tidyclust()
- env_obj <- rlang::env_names(m_env)
- env_obj <- grep(model, env_obj, value = TRUE)
-
- res <- list()
- res$libs <-
- rlang::env_get(m_env, paste0(model, "_pkgs")) %>%
- dplyr::filter(engine == !!engine) %>%
- .[["pkg"]] %>%
- .[[1]]
-
- res$fit <-
- rlang::env_get(m_env, paste0(model, "_fit")) %>%
- dplyr::filter(mode == !!mode & engine == !!engine) %>%
- dplyr::pull(value) %>%
- .[[1]]
-
- pred_code <-
- rlang::env_get(m_env, paste0(model, "_predict")) %>%
- dplyr::filter(mode == !!mode & engine == !!engine) %>%
- dplyr::select(-engine, -mode)
-
- res$pred <- pred_code[["value"]]
- names(res$pred) <- pred_code$type
-
- res
-}
-
-get_args <- function(model, engine) {
- m_env <- get_model_env_tidyclust()
- rlang::env_get(m_env, paste0(model, "_args")) %>%
- dplyr::filter(engine == !!engine) %>%
- dplyr::select(-engine)
-}
-
-# to replace harmonize
-deharmonize <- function(args, key) {
- if (length(args) == 0) {
- return(args)
- }
- parsn <- tibble::tibble(tidyclust = names(args), order = seq_along(args))
- merged <-
- dplyr::left_join(parsn, key, by = "tidyclust") %>%
- dplyr::arrange(order)
- # TODO correct for bad merge?
-
- names(args) <- merged$original
- args[!is.na(merged$original)]
-}
-
-#' Check to ensure that ellipses are empty
-#' @param ... Extra arguments.
-#' @return If an error is not thrown (from non-empty ellipses), a NULL list.
-#' @keywords internal
-#' @export
-check_empty_ellipse_tidyclust <- function(...) {
- terms <- quos(...)
- if (!rlang::is_empty(terms)) {
- rlang::abort("Please pass other arguments to the model function via `set_engine()`.")
- }
- terms
-}
+#' Resolve a Model Specification for a Computational Engine
+#'
+#' `translate_tidyclust()` will translate_tidyclust a model specification into a code
+#' object that is specific to a particular engine (e.g. R package).
+#' It translate_tidyclusts generic parameters to their counterparts.
+#'
+#' @param x A model specification.
+#' @param engine The computational engine for the model (see `?set_engine`).
+#' @param ... Not currently used.
+#' @details
+#' `translate_tidyclust()` produces a _template_ call that lacks the specific
+#' argument values (such as `data`, etc). These are filled in once
+#' `fit()` is called with the specifics of the data for the model.
+#' The call may also include `tune()` arguments if these are in
+#' the specification. To handle the `tune()` arguments, you need to use the
+#' [tune package](https://tune.tidymodels.org/). For more information
+#' see
+#'
+#' It does contain the resolved argument names that are specific to
+#' the model fitting function/engine.
+#'
+#' This function can be useful when you need to understand how
+#' `tidyclust` goes from a generic model specific to a model fitting
+#' function.
+#'
+#' **Note**: this function is used internally and users should only use it
+#' to understand what the underlying syntax would be. It should not be used
+#' to modify the cluster specification.
+#'
+#' @export
+translate_tidyclust <- function(x, ...) {
+ UseMethod("translate_tidyclust")
+}
+
+#' @rdname translate_tidyclust
+#' @export
+#' @export translate_tidyclust.default
+translate_tidyclust.default <- function(x, engine = x$engine, ...) {
+ check_empty_ellipse_tidyclust(...)
+ if (is.null(engine)) {
+ rlang::abort("Please set an engine.")
+ }
+
+ mod_name <- specific_model(x)
+
+ x$engine <- engine
+ if (x$mode == "unknown") {
+ rlang::abort("Model code depends on the mode; please specify one.")
+ }
+
+ check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
+
+ if (is.null(x$method)) {
+ x$method <- get_cluster_spec(mod_name, x$mode, engine)
+ }
+
+ arg_key <- get_args(mod_name, engine)
+
+ # deharmonize primary arguments
+ actual_args <- deharmonize(x$args, arg_key)
+
+ # check secondary arguments to see if they are in the final
+ # expression unless there are dots, warn if protected args are
+ # being altered
+ x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original)
+
+ # keep only modified args
+ modifed_args <- !map_lgl(actual_args, null_value)
+ actual_args <- actual_args[modifed_args]
+
+ # look for defaults if not modified in other
+ if (length(x$method$fit$defaults) > 0) {
+ in_other <- names(x$method$fit$defaults) %in% names(x$eng_args)
+ x$defaults <- x$method$fit$defaults[!in_other]
+ }
+
+ # combine primary, eng_args, and defaults
+ protected <- lapply(x$method$fit$protect, function(x) rlang::expr(missing_arg()))
+ names(protected) <- x$method$fit$protect
+
+ x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults)
+
+ x
+}
+
+# ------------------------------------------------------------------------------
+# new code for revised model data structures
+
+get_cluster_spec <- function(model, mode, engine) {
+ m_env <- get_model_env_tidyclust()
+ env_obj <- rlang::env_names(m_env)
+ env_obj <- grep(model, env_obj, value = TRUE)
+
+ res <- list()
+ res$libs <-
+ rlang::env_get(m_env, paste0(model, "_pkgs")) %>%
+ dplyr::filter(engine == !!engine) %>%
+ .[["pkg"]] %>%
+ .[[1]]
+
+ res$fit <-
+ rlang::env_get(m_env, paste0(model, "_fit")) %>%
+ dplyr::filter(mode == !!mode & engine == !!engine) %>%
+ dplyr::pull(value) %>%
+ .[[1]]
+
+ pred_code <-
+ rlang::env_get(m_env, paste0(model, "_predict")) %>%
+ dplyr::filter(mode == !!mode & engine == !!engine) %>%
+ dplyr::select(-engine, -mode)
+
+ res$pred <- pred_code[["value"]]
+ names(res$pred) <- pred_code$type
+
+ res
+}
+
+get_args <- function(model, engine) {
+ m_env <- get_model_env_tidyclust()
+ rlang::env_get(m_env, paste0(model, "_args")) %>%
+ dplyr::filter(engine == !!engine) %>%
+ dplyr::select(-engine)
+}
+
+# to replace harmonize
+deharmonize <- function(args, key) {
+ if (length(args) == 0) {
+ return(args)
+ }
+ parsn <- tibble::tibble(tidyclust = names(args), order = seq_along(args))
+ merged <-
+ dplyr::left_join(parsn, key, by = "tidyclust") %>%
+ dplyr::arrange(order)
+ # TODO correct for bad merge?
+
+ names(args) <- merged$original
+ args[!is.na(merged$original)]
+}
+
+#' Check to ensure that ellipses are empty
+#' @param ... Extra arguments.
+#' @return If an error is not thrown (from non-empty ellipses), a NULL list.
+#' @keywords internal
+#' @export
+check_empty_ellipse_tidyclust <- function(...) {
+ terms <- quos(...)
+ if (!rlang::is_empty(terms)) {
+ rlang::abort("Please pass other arguments to the model function via `set_engine()`.")
+ }
+ terms
+}
diff --git a/R/tunable.R b/R/tunable.R
index 6b7a9335..41a4bc1f 100644
--- a/R/tunable.R
+++ b/R/tunable.R
@@ -1,81 +1,81 @@
-# Lazily registered in .onLoad()
-# Unit tests are in extratests
-# nocov start
-tunable_cluster_spec <- function(x, ...) {
- mod_env <- rlang::ns_env("tidyclust")$tidyclust
-
- if (is.null(x$engine)) {
- abort("Please declare an engine first using `set_engine()`.", call. = FALSE)
- }
-
- arg_name <- paste0(mod_type(x), "_args")
- if (!(any(arg_name == names(mod_env)))) {
- abort(
- paste(
- "The `tidyclust` model database doesn't know about the arguments for ",
- "model `", mod_type(x), "`. Was it registered?",
- sep = ""
- ),
- call. = FALSE
- )
- }
-
- arg_vals <-
- mod_env[[arg_name]] %>%
- dplyr::filter(engine == x$engine) %>%
- dplyr::select(name = tidyclust, call_info = func) %>%
- dplyr::full_join(
- tibble::tibble(name = c(names(x$args), names(x$eng_args))),
- by = "name"
- ) %>%
- dplyr::mutate(
- source = "cluster_spec",
- component = mod_type(x),
- component_id = dplyr::if_else(name %in% names(x$args), "main", "engine")
- )
-
- if (nrow(arg_vals) > 0) {
- has_info <- map_lgl(arg_vals$call_info, is.null)
- rm_list <- !(has_info & (arg_vals$component_id == "main"))
-
- arg_vals <- arg_vals[rm_list, ]
- }
- arg_vals %>% dplyr::select(name, call_info, source, component, component_id)
-}
-
-mod_type <- function(.mod) class(.mod)[class(.mod) != "cluster_spec"][1]
-
-add_engine_parameters <- function(pset, engines) {
- is_engine_param <- pset$name %in% engines$name
- if (any(is_engine_param)) {
- engine_names <- pset$name[is_engine_param]
- pset <- pset[!is_engine_param, ]
- pset <-
- dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name))
- }
- pset
-}
-
-# Lazily registered in .onLoad()
-tunable_k_means <- function(x, ...) {
- res <- NextMethod()
- if (x$engine == "stats") {
- res <- add_engine_parameters(res, stats_k_means_engine_args)
- }
- res
-}
-
-stats_k_means_engine_args <-
- tibble::tibble(
- name = c(
- "centers"
- ),
- call_info = list(
- list(pkg = "tidyclust", fun = "num_clusters")
- ),
- source = "cluster_spec",
- component = "k_means",
- component_id = "engine"
- )
-
-# nocov end
+# Lazily registered in .onLoad()
+# Unit tests are in extratests
+# nocov start
+tunable_cluster_spec <- function(x, ...) {
+ mod_env <- rlang::ns_env("tidyclust")$tidyclust
+
+ if (is.null(x$engine)) {
+ abort("Please declare an engine first using `set_engine()`.", call. = FALSE)
+ }
+
+ arg_name <- paste0(mod_type(x), "_args")
+ if (!(any(arg_name == names(mod_env)))) {
+ abort(
+ paste(
+ "The `tidyclust` model database doesn't know about the arguments for ",
+ "model `", mod_type(x), "`. Was it registered?",
+ sep = ""
+ ),
+ call. = FALSE
+ )
+ }
+
+ arg_vals <-
+ mod_env[[arg_name]] %>%
+ dplyr::filter(engine == x$engine) %>%
+ dplyr::select(name = tidyclust, call_info = func) %>%
+ dplyr::full_join(
+ tibble::tibble(name = c(names(x$args), names(x$eng_args))),
+ by = "name"
+ ) %>%
+ dplyr::mutate(
+ source = "cluster_spec",
+ component = mod_type(x),
+ component_id = dplyr::if_else(name %in% names(x$args), "main", "engine")
+ )
+
+ if (nrow(arg_vals) > 0) {
+ has_info <- map_lgl(arg_vals$call_info, is.null)
+ rm_list <- !(has_info & (arg_vals$component_id == "main"))
+
+ arg_vals <- arg_vals[rm_list, ]
+ }
+ arg_vals %>% dplyr::select(name, call_info, source, component, component_id)
+}
+
+mod_type <- function(.mod) class(.mod)[class(.mod) != "cluster_spec"][1]
+
+add_engine_parameters <- function(pset, engines) {
+ is_engine_param <- pset$name %in% engines$name
+ if (any(is_engine_param)) {
+ engine_names <- pset$name[is_engine_param]
+ pset <- pset[!is_engine_param, ]
+ pset <-
+ dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name))
+ }
+ pset
+}
+
+# Lazily registered in .onLoad()
+tunable_k_means <- function(x, ...) {
+ res <- NextMethod()
+ if (x$engine == "stats") {
+ res <- add_engine_parameters(res, stats_k_means_engine_args)
+ }
+ res
+}
+
+stats_k_means_engine_args <-
+ tibble::tibble(
+ name = c(
+ "centers"
+ ),
+ call_info = list(
+ list(pkg = "tidyclust", fun = "num_clusters")
+ ),
+ source = "cluster_spec",
+ component = "k_means",
+ component_id = "engine"
+ )
+
+# nocov end
diff --git a/R/update.R b/R/update.R
index f343035e..0e4d3d5f 100644
--- a/R/update.R
+++ b/R/update.R
@@ -1,28 +1,28 @@
-#' Update a cluster specification
-#'
-#' @description
-#' If parameters of a cluster specification need to be modified, `update()` can
-#' be used in lieu of recreating the object from scratch.
-#'
-#' @inheritParams k_means
-#' @param object A cluster specification.
-#' @param parameters A 1-row tibble or named list with _main_
-#' parameters to update. Use **either** `parameters` **or** the main arguments
-#' directly when updating. If the main arguments are used,
-#' these will supersede the values in `parameters`. Also, using
-#' engine arguments in this object will result in an error.
-#' @param ... Not used for `update()`.
-#' @param fresh A logical for whether the arguments should be
-#' modified in-place or replaced wholesale.
-#' @return An updated cluster specification.
-#' @name tidyclust_update
-#' @examples
-#' kmeans_spec <- k_means(num_clusters = 5)
-#' kmeans_spec
-#' update(kmeans_spec, num_clusters = 1)
-#' update(kmeans_spec, num_clusters = 1, fresh = TRUE)
-#'
-#' param_values <- tibble::tibble(num_clusters = 10)
-#'
-#' kmeans_spec %>% update(param_values)
-NULL
+#' Update a cluster specification
+#'
+#' @description
+#' If parameters of a cluster specification need to be modified, `update()` can
+#' be used in lieu of recreating the object from scratch.
+#'
+#' @inheritParams k_means
+#' @param object A cluster specification.
+#' @param parameters A 1-row tibble or named list with _main_
+#' parameters to update. Use **either** `parameters` **or** the main arguments
+#' directly when updating. If the main arguments are used,
+#' these will supersede the values in `parameters`. Also, using
+#' engine arguments in this object will result in an error.
+#' @param ... Not used for `update()`.
+#' @param fresh A logical for whether the arguments should be
+#' modified in-place or replaced wholesale.
+#' @return An updated cluster specification.
+#' @name tidyclust_update
+#' @examples
+#' kmeans_spec <- k_means(num_clusters = 5)
+#' kmeans_spec
+#' update(kmeans_spec, num_clusters = 1)
+#' update(kmeans_spec, num_clusters = 1, fresh = TRUE)
+#'
+#' param_values <- tibble::tibble(num_clusters = 10)
+#'
+#' kmeans_spec %>% update(param_values)
+NULL
diff --git a/README.Rmd b/README.Rmd
index 5ec8f3f7..8761424c 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -1,74 +1,74 @@
----
-output: github_document
----
-
-
-
-```{r, include = FALSE}
-knitr::opts_chunk$set(
- collapse = TRUE,
- comment = "#>",
- fig.path = "man/figures/README-",
- out.width = "100%"
-)
-```
-
-# tidyclust
-
-
-[data:image/s3,"s3://crabby-images/170c1/170c1719d9a3dd608e89fef473f0bbd0622f8ba0" alt="Codecov test coverage"](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main)
-[data:image/s3,"s3://crabby-images/8067c/8067c6e10b1c91450b8b1e7ef289401c35a91ea5" alt="R-CMD-check"](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml)
-
-
-The goal of tidyclust is to provide a tidy, unified interface to clustering models. The packages is closely modeled after the [parsnip](https://parsnip.tidymodels.org/) package.
-
-## Installation
-
-You can install the development version of tidyclust from [GitHub](https://github.com/) with:
-
-``` r
-# install.packages("devtools")
-devtools::install_github("EmilHvitfeldt/tidyclust")
-```
-
-Please note that this package currently requires a [branch of the workflows](https://github.com/tidymodels/workflows/tree/tidyclust) package to work. Use with caution.
-
-## Example
-
-The first thing you do is to create a `cluster specification`. For this example we are creating a K-means model, using the `stats` engine.
-
-```{r}
-library(tidyclust)
-
-kmeans_spec <- k_means(num_clusters = 3) %>%
- set_engine("stats")
-
-kmeans_spec
-```
-
-This specification can then be fit using data.
-
-```{r}
-kmeans_spec_fit <- kmeans_spec %>%
- fit(~., data = mtcars)
-kmeans_spec_fit
-```
-
-Once you have a fitted tidyclust object, you can do a number of things. `predict()` returns the cluster a new observation belongs to
-
-```{r}
-predict(kmeans_spec_fit, mtcars[1:4, ])
-```
-
-`extract_cluster_assignment()` returns the cluster assignments of the training observations
-
-```{r}
-extract_cluster_assignment(kmeans_spec_fit)
-```
-
-and `extract_clusters()` returns the locations of the clusters
-
-```{r}
-extract_centroids(kmeans_spec_fit)
-```
-
+---
+output: github_document
+---
+
+
+
+```{r, include = FALSE}
+knitr::opts_chunk$set(
+ collapse = TRUE,
+ comment = "#>",
+ fig.path = "man/figures/README-",
+ out.width = "100%"
+)
+```
+
+# tidyclust
+
+
+[data:image/s3,"s3://crabby-images/170c1/170c1719d9a3dd608e89fef473f0bbd0622f8ba0" alt="Codecov test coverage"](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main)
+[data:image/s3,"s3://crabby-images/8067c/8067c6e10b1c91450b8b1e7ef289401c35a91ea5" alt="R-CMD-check"](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml)
+
+
+The goal of tidyclust is to provide a tidy, unified interface to clustering models. The packages is closely modeled after the [parsnip](https://parsnip.tidymodels.org/) package.
+
+## Installation
+
+You can install the development version of tidyclust from [GitHub](https://github.com/) with:
+
+``` r
+# install.packages("devtools")
+devtools::install_github("EmilHvitfeldt/tidyclust")
+```
+
+Please note that this package currently requires a [branch of the workflows](https://github.com/tidymodels/workflows/tree/tidyclust) package to work. Use with caution.
+
+## Example
+
+The first thing you do is to create a `cluster specification`. For this example we are creating a K-means model, using the `stats` engine.
+
+```{r}
+library(tidyclust)
+
+kmeans_spec <- k_means(num_clusters = 3) %>%
+ set_engine("stats")
+
+kmeans_spec
+```
+
+This specification can then be fit using data.
+
+```{r}
+kmeans_spec_fit <- kmeans_spec %>%
+ fit(~., data = mtcars)
+kmeans_spec_fit
+```
+
+Once you have a fitted tidyclust object, you can do a number of things. `predict()` returns the cluster a new observation belongs to
+
+```{r}
+predict(kmeans_spec_fit, mtcars[1:4, ])
+```
+
+`extract_cluster_assignment()` returns the cluster assignments of the training observations
+
+```{r}
+extract_cluster_assignment(kmeans_spec_fit)
+```
+
+and `extract_clusters()` returns the locations of the clusters
+
+```{r}
+extract_centroids(kmeans_spec_fit)
+```
+
diff --git a/README.md b/README.md
index 50ce6989..98ce0a0d 100644
--- a/README.md
+++ b/README.md
@@ -1,145 +1,145 @@
-
-
-
-# tidyclust
-
-
-
-[data:image/s3,"s3://crabby-images/170c1/170c1719d9a3dd608e89fef473f0bbd0622f8ba0" alt="Codecov test
-coverage"](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main)
-[data:image/s3,"s3://crabby-images/8067c/8067c6e10b1c91450b8b1e7ef289401c35a91ea5" alt="R-CMD-check"](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml)
-
-
-The goal of tidyclust is to provide a tidy, unified interface to
-clustering models. The packages is closely modeled after the
-[parsnip](https://parsnip.tidymodels.org/) package.
-
-## Installation
-
-You can install the development version of tidyclust from
-[GitHub](https://github.com/) with:
-
-``` r
-# install.packages("devtools")
-devtools::install_github("EmilHvitfeldt/tidyclust")
-```
-
-Please note that this package currently requires a [branch of the
-workflows](https://github.com/tidymodels/workflows/tree/tidyclust)
-package to work. Use with caution.
-
-## Example
-
-The first thing you do is to create a `cluster specification`. For this
-example we are creating a K-means model, using the `stats` engine.
-
-``` r
-library(tidyclust)
-
-kmeans_spec <- k_means(num_clusters = 3) %>%
- set_engine("stats")
-
-kmeans_spec
-#> K Means Cluster Specification (partition)
-#>
-#> Main Arguments:
-#> num_clusters = 3
-#>
-#> Computational engine: stats
-```
-
-This specification can then be fit using data.
-
-``` r
-kmeans_spec_fit <- kmeans_spec %>%
- fit(~., data = mtcars)
-kmeans_spec_fit
-#> tidyclust cluster object
-#>
-#> K-means clustering with 3 clusters of sizes 9, 16, 7
-#>
-#> Cluster means:
-#> mpg cyl disp hp drat wt qsec vs
-#> 1 14.64444 8.000000 388.2222 232.1111 3.343333 4.161556 16.40444 0.0000000
-#> 2 24.50000 4.625000 122.2937 96.8750 4.002500 2.518000 18.54312 0.7500000
-#> 3 17.01429 7.428571 276.0571 150.7143 2.994286 3.601429 18.11857 0.2857143
-#> am gear carb
-#> 1 0.2222222 3.444444 4.000000
-#> 2 0.6875000 4.125000 2.437500
-#> 3 0.0000000 3.000000 2.142857
-#>
-#> Clustering vector:
-#> Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive
-#> 2 2 2 3
-#> Hornet Sportabout Valiant Duster 360 Merc 240D
-#> 1 3 1 2
-#> Merc 230 Merc 280 Merc 280C Merc 450SE
-#> 2 2 2 3
-#> Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental
-#> 3 3 1 1
-#> Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla
-#> 1 2 2 2
-#> Toyota Corona Dodge Challenger AMC Javelin Camaro Z28
-#> 2 3 3 1
-#> Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa
-#> 1 2 2 2
-#> Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E
-#> 1 2 1 2
-#>
-#> Within cluster sum of squares by cluster:
-#> [1] 46659.32 32838.00 11846.09
-#> (between_SS / total_SS = 85.3 %)
-#>
-#> Available components:
-#>
-#> [1] "cluster" "centers" "totss" "withinss" "tot.withinss"
-#> [6] "betweenss" "size" "iter" "ifault"
-```
-
-Once you have a fitted tidyclust object, you can do a number of things.
-`predict()` returns the cluster a new observation belongs to
-
-``` r
-predict(kmeans_spec_fit, mtcars[1:4, ])
-#> # A tibble: 4 × 1
-#> .pred_cluster
-#>
-#> 1 Cluster_1
-#> 2 Cluster_1
-#> 3 Cluster_1
-#> 4 Cluster_2
-```
-
-`extract_cluster_assignment()` returns the cluster assignments of the
-training observations
-
-``` r
-extract_cluster_assignment(kmeans_spec_fit)
-#> # A tibble: 32 × 1
-#> .cluster
-#>
-#> 1 Cluster_1
-#> 2 Cluster_1
-#> 3 Cluster_1
-#> 4 Cluster_2
-#> 5 Cluster_3
-#> 6 Cluster_2
-#> 7 Cluster_3
-#> 8 Cluster_1
-#> 9 Cluster_1
-#> 10 Cluster_1
-#> # … with 22 more rows
-#> # ℹ Use `print(n = ...)` to see more rows
-```
-
-and `extract_clusters()` returns the locations of the clusters
-
-``` r
-extract_centroids(kmeans_spec_fit)
-#> # A tibble: 3 × 12
-#> .cluster mpg cyl disp hp drat wt qsec vs am gear carb
-#>
-#> 1 Cluster_1 17.0 7.43 276. 151. 2.99 3.60 18.1 0.286 0 3 2.14
-#> 2 Cluster_2 14.6 8 388. 232. 3.34 4.16 16.4 0 0.222 3.44 4
-#> 3 Cluster_3 24.5 4.62 122. 96.9 4.00 2.52 18.5 0.75 0.688 4.12 2.44
-```
+
+
+
+# tidyclust
+
+
+
+[data:image/s3,"s3://crabby-images/170c1/170c1719d9a3dd608e89fef473f0bbd0622f8ba0" alt="Codecov test
+coverage"](https://app.codecov.io/gh/EmilHvitfeldt/tidyclust?branch=main)
+[data:image/s3,"s3://crabby-images/8067c/8067c6e10b1c91450b8b1e7ef289401c35a91ea5" alt="R-CMD-check"](https://github.com/EmilHvitfeldt/tidyclust/actions/workflows/R-CMD-check.yaml)
+
+
+The goal of tidyclust is to provide a tidy, unified interface to
+clustering models. The packages is closely modeled after the
+[parsnip](https://parsnip.tidymodels.org/) package.
+
+## Installation
+
+You can install the development version of tidyclust from
+[GitHub](https://github.com/) with:
+
+``` r
+# install.packages("devtools")
+devtools::install_github("EmilHvitfeldt/tidyclust")
+```
+
+Please note that this package currently requires a [branch of the
+workflows](https://github.com/tidymodels/workflows/tree/tidyclust)
+package to work. Use with caution.
+
+## Example
+
+The first thing you do is to create a `cluster specification`. For this
+example we are creating a K-means model, using the `stats` engine.
+
+``` r
+library(tidyclust)
+
+kmeans_spec <- k_means(num_clusters = 3) %>%
+ set_engine("stats")
+
+kmeans_spec
+#> K Means Cluster Specification (partition)
+#>
+#> Main Arguments:
+#> num_clusters = 3
+#>
+#> Computational engine: stats
+```
+
+This specification can then be fit using data.
+
+``` r
+kmeans_spec_fit <- kmeans_spec %>%
+ fit(~., data = mtcars)
+kmeans_spec_fit
+#> tidyclust cluster object
+#>
+#> K-means clustering with 3 clusters of sizes 9, 16, 7
+#>
+#> Cluster means:
+#> mpg cyl disp hp drat wt qsec vs
+#> 1 14.64444 8.000000 388.2222 232.1111 3.343333 4.161556 16.40444 0.0000000
+#> 2 24.50000 4.625000 122.2937 96.8750 4.002500 2.518000 18.54312 0.7500000
+#> 3 17.01429 7.428571 276.0571 150.7143 2.994286 3.601429 18.11857 0.2857143
+#> am gear carb
+#> 1 0.2222222 3.444444 4.000000
+#> 2 0.6875000 4.125000 2.437500
+#> 3 0.0000000 3.000000 2.142857
+#>
+#> Clustering vector:
+#> Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive
+#> 2 2 2 3
+#> Hornet Sportabout Valiant Duster 360 Merc 240D
+#> 1 3 1 2
+#> Merc 230 Merc 280 Merc 280C Merc 450SE
+#> 2 2 2 3
+#> Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental
+#> 3 3 1 1
+#> Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla
+#> 1 2 2 2
+#> Toyota Corona Dodge Challenger AMC Javelin Camaro Z28
+#> 2 3 3 1
+#> Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa
+#> 1 2 2 2
+#> Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E
+#> 1 2 1 2
+#>
+#> Within cluster sum of squares by cluster:
+#> [1] 46659.32 32838.00 11846.09
+#> (between_SS / total_SS = 85.3 %)
+#>
+#> Available components:
+#>
+#> [1] "cluster" "centers" "totss" "withinss" "tot.withinss"
+#> [6] "betweenss" "size" "iter" "ifault"
+```
+
+Once you have a fitted tidyclust object, you can do a number of things.
+`predict()` returns the cluster a new observation belongs to
+
+``` r
+predict(kmeans_spec_fit, mtcars[1:4, ])
+#> # A tibble: 4 × 1
+#> .pred_cluster
+#>
+#> 1 Cluster_1
+#> 2 Cluster_1
+#> 3 Cluster_1
+#> 4 Cluster_2
+```
+
+`extract_cluster_assignment()` returns the cluster assignments of the
+training observations
+
+``` r
+extract_cluster_assignment(kmeans_spec_fit)
+#> # A tibble: 32 × 1
+#> .cluster
+#>
+#> 1 Cluster_1
+#> 2 Cluster_1
+#> 3 Cluster_1
+#> 4 Cluster_2
+#> 5 Cluster_3
+#> 6 Cluster_2
+#> 7 Cluster_3
+#> 8 Cluster_1
+#> 9 Cluster_1
+#> 10 Cluster_1
+#> # … with 22 more rows
+#> # ℹ Use `print(n = ...)` to see more rows
+```
+
+and `extract_clusters()` returns the locations of the clusters
+
+``` r
+extract_centroids(kmeans_spec_fit)
+#> # A tibble: 3 × 12
+#> .cluster mpg cyl disp hp drat wt qsec vs am gear carb
+#>
+#> 1 Cluster_1 17.0 7.43 276. 151. 2.99 3.60 18.1 0.286 0 3 2.14
+#> 2 Cluster_2 14.6 8 388. 232. 3.34 4.16 16.4 0 0.222 3.44 4
+#> 3 Cluster_3 24.5 4.62 122. 96.9 4.00 2.52 18.5 0.75 0.688 4.12 2.44
+```
diff --git a/_pkgdown.yml b/_pkgdown.yml
index 09616d33..2d3474c3 100644
--- a/_pkgdown.yml
+++ b/_pkgdown.yml
@@ -1,67 +1,67 @@
-url: ~
-template:
- bootstrap: 5
-
-reference:
-- title: Specifications
- desc: >
- These cluster specification fucntion are used to specify the type of model
- you want to do. These functions work in a similar fashion to the [model
- specification function from parsnip](https://parsnip.tidymodels.org/reference/index.html#models).
- contents:
- - k_means
-- title: Fit and Inspect
- desc: >
- These Functions are the generics are are supported for specifications
- created with tidyclust.
- contents:
- - fit.cluster_spec
- - augment.cluster_fit
- - glance.cluster_fit
- - tidy.cluster_fit
-- title: Prediction
- desc: >
- Once the cluster specification have been fit, you are likely to want to look
- at where the clusters are and which observations are associated with which
- cluster.
- contents:
- - predict.cluster_fit
- - extract_cluster_assignment
- - extract_centroids
-- title: Parameter Objects
- desc: >
- Parameter objects for tuning. Similar to
- [parameter objects from dials package](https://dials.tidymodels.org/reference/index.html#parameter-objects)
- contents:
- - num_clusters
-- title: Model based performance metrics
- desc: >
- These metrics use the fitted clustering model to extract values denoting how
- well the model works.
- contents:
- - cluster_metric_set
- - avg_silhouette
- - tot_sse
- - sse_ratio
- - tot_wss
-- title: Tuning
- desc: >
- Functions to allow multiple cluster specifications to be fit at once.
- contents:
- - control_cluster
- - tidyclust_update
- - finalize_model_tidyclust
- - tune_cluster
-- title: Developer tools
- contents:
- - extract_fit_summary
- - get_centroid_dists
- - new_cluster_metric
- - prep_data_dist
- - reconcile_clusterings
- - translate_tidyclust
- - within_cluster_sse
-- title: Other
- contents:
- - enrichment
- - silhouettes
+url: ~
+template:
+ bootstrap: 5
+
+reference:
+- title: Specifications
+ desc: >
+ These cluster specification fucntion are used to specify the type of model
+ you want to do. These functions work in a similar fashion to the [model
+ specification function from parsnip](https://parsnip.tidymodels.org/reference/index.html#models).
+ contents:
+ - k_means
+- title: Fit and Inspect
+ desc: >
+ These Functions are the generics are are supported for specifications
+ created with tidyclust.
+ contents:
+ - fit.cluster_spec
+ - augment.cluster_fit
+ - glance.cluster_fit
+ - tidy.cluster_fit
+- title: Prediction
+ desc: >
+ Once the cluster specification have been fit, you are likely to want to look
+ at where the clusters are and which observations are associated with which
+ cluster.
+ contents:
+ - predict.cluster_fit
+ - extract_cluster_assignment
+ - extract_centroids
+- title: Parameter Objects
+ desc: >
+ Parameter objects for tuning. Similar to
+ [parameter objects from dials package](https://dials.tidymodels.org/reference/index.html#parameter-objects)
+ contents:
+ - num_clusters
+- title: Model based performance metrics
+ desc: >
+ These metrics use the fitted clustering model to extract values denoting how
+ well the model works.
+ contents:
+ - cluster_metric_set
+ - avg_silhouette
+ - tot_sse
+ - sse_ratio
+ - tot_wss
+- title: Tuning
+ desc: >
+ Functions to allow multiple cluster specifications to be fit at once.
+ contents:
+ - control_cluster
+ - tidyclust_update
+ - finalize_model_tidyclust
+ - tune_cluster
+- title: Developer tools
+ contents:
+ - extract_fit_summary
+ - get_centroid_dists
+ - new_cluster_metric
+ - prep_data_dist
+ - reconcile_clusterings
+ - translate_tidyclust
+ - within_cluster_sse
+- title: Other
+ contents:
+ - enrichment
+ - silhouettes
diff --git a/dev/cross_val_kmeans.R b/dev/cross_val_kmeans.R
index 30880662..538985db 100644
--- a/dev/cross_val_kmeans.R
+++ b/dev/cross_val_kmeans.R
@@ -1,108 +1,108 @@
-library(tidymodels)
-library(tidyverse)
-library(tidyclust)
-
-## "Cross-validation" for kmeans
-
-ir <- iris %>% select(-Species)
-
-cvs <- vfold_cv(ir, v = 5)
-
-res <- data.frame(
- k = NA,
- i = NA,
- wss = NA,
- sil = NA,
- wss_2 = NA
-)
-
-for (k in 2:10) {
-
- km <- k_means(k = k) %>%
- set_engine("stats")
-
-
- for (i in 1:5) {
-
- tmp_train <- training(cvs$splits[[i]])
- tmp_test <- testing(cvs$splits[[i]])
-
- km_fit <- km %>% fit(~., data = tmp_train)
-
- wss <- km_fit %>%
- tot_wss(tmp_test)
-
- wss_2 <- km_fit$fit$tot.withinss
-
- sil <- km_fit %>%
- avg_silhouette(tmp_test)
-
- res <- rbind(res,
- c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2))
-
- }
-
-}
-
-res %>%
- drop_na() %>%
- ggplot(aes(x = factor(k), y = sil)) +
- geom_point()
-
-
-### Second idea
-## What if we cluster the whole data, then see how well subsamples are reclassified?
-## This needs "predict"
-## Doesn't really make sense yet
-
-cvs <- vfold_cv(ir, v = 10)
-
-res <- data.frame(
- k = NA,
- i = NA,
- acc = NA,
- f1 = NA
-)
-
-for (k in 2:10) {
-
- km <- k_means(k = k) %>%
- set_engine("stats")
-
- full_fit <- km %>% fit(~., data = ir)
-
-
- for (i in 1:10) {
-
- tmp_train <- training(cvs$splits[[i]])
- tmp_test <- testing(cvs$splits[[i]])
-
- km_fit <- km %>% fit(~., data = tmp_train)
-
- dat <- tmp_test %>%
- mutate(
- truth = predict(full_fit, tmp_test)$.pred_cluster,
- estimate = predict(km_fit, tmp_test)$.pred_cluster
- )
-
- thing <- reconcile_clusterings(dat$truth, dat$estimate)
-
- acc <- accuracy(thing, clusters_1, clusters_2)
- f1 <- f_meas(thing, clusters_1, clusters_2)
-
- res <- rbind(res,
- c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate))
-
- }
-
-}
-
-
-res %>%
- ggplot(aes(x = factor(k), y = f1)) +
- geom_point()
-
-
-### use orders from reconciling to order centers and check center similarity?
-### or to get "raw probabilities" - what does that mean though?
-### to do predict = raw
+library(tidymodels)
+library(tidyverse)
+library(tidyclust)
+
+## "Cross-validation" for kmeans
+
+ir <- iris %>% select(-Species)
+
+cvs <- vfold_cv(ir, v = 5)
+
+res <- data.frame(
+ k = NA,
+ i = NA,
+ wss = NA,
+ sil = NA,
+ wss_2 = NA
+)
+
+for (k in 2:10) {
+
+ km <- k_means(k = k) %>%
+ set_engine("stats")
+
+
+ for (i in 1:5) {
+
+ tmp_train <- training(cvs$splits[[i]])
+ tmp_test <- testing(cvs$splits[[i]])
+
+ km_fit <- km %>% fit(~., data = tmp_train)
+
+ wss <- km_fit %>%
+ tot_wss(tmp_test)
+
+ wss_2 <- km_fit$fit$tot.withinss
+
+ sil <- km_fit %>%
+ avg_silhouette(tmp_test)
+
+ res <- rbind(res,
+ c(k = k, i = i, wss = wss, sil = sil, wss_2 = wss_2))
+
+ }
+
+}
+
+res %>%
+ drop_na() %>%
+ ggplot(aes(x = factor(k), y = sil)) +
+ geom_point()
+
+
+### Second idea
+## What if we cluster the whole data, then see how well subsamples are reclassified?
+## This needs "predict"
+## Doesn't really make sense yet
+
+cvs <- vfold_cv(ir, v = 10)
+
+res <- data.frame(
+ k = NA,
+ i = NA,
+ acc = NA,
+ f1 = NA
+)
+
+for (k in 2:10) {
+
+ km <- k_means(k = k) %>%
+ set_engine("stats")
+
+ full_fit <- km %>% fit(~., data = ir)
+
+
+ for (i in 1:10) {
+
+ tmp_train <- training(cvs$splits[[i]])
+ tmp_test <- testing(cvs$splits[[i]])
+
+ km_fit <- km %>% fit(~., data = tmp_train)
+
+ dat <- tmp_test %>%
+ mutate(
+ truth = predict(full_fit, tmp_test)$.pred_cluster,
+ estimate = predict(km_fit, tmp_test)$.pred_cluster
+ )
+
+ thing <- reconcile_clusterings(dat$truth, dat$estimate)
+
+ acc <- accuracy(thing, clusters_1, clusters_2)
+ f1 <- f_meas(thing, clusters_1, clusters_2)
+
+ res <- rbind(res,
+ c(k = k, i = i, acc = acc$.estimate[1], f1 = f1$.estimate))
+
+ }
+
+}
+
+
+res %>%
+ ggplot(aes(x = factor(k), y = f1)) +
+ geom_point()
+
+
+### use orders from reconciling to order centers and check center similarity?
+### or to get "raw probabilities" - what does that mean though?
+### to do predict = raw
diff --git a/dev/kmeans.Rmd b/dev/kmeans.Rmd
index 0fa5245c..ef08b912 100644
--- a/dev/kmeans.Rmd
+++ b/dev/kmeans.Rmd
@@ -1,132 +1,132 @@
----
-title: "k-means Clustering"
-output: rmarkdown::html_vignette
-vignette: >
- %\VignetteIndexEntry{k-means Clustering}
- %\VignetteEngine{knitr::rmarkdown}
- %\VignetteEncoding{UTF-8}
----
-
-```{r, include = FALSE}
-knitr::opts_chunk$set(
- collapse = TRUE,
- comment = "#>"
-)
-```
-
-```{r setup}
-library(tidyclust)
-library(palmerpenguins)
-library(tidymodels)
-```
-
-## Fit
-
-```{r}
-kmeans_spec <- k_means(k = 5) %>%
- set_engine("stats")
-
-penguins_rec_1 <- recipe(~ ., data = penguins) %>%
- update_role(species, island, new_role = "demographic") %>%
- step_dummy(sex)
-
-
-penguins_rec_2 <- recipe(species ~ ., data = penguins) %>%
- step_dummy(sex, island)
-
-wflow_1 <- workflow() %>%
- add_model(kmeans_spec) %>%
- add_recipe(penguins_rec_1)
-
-
-wflow_2 <- workflow() %>%
- add_model(kmeans_spec) %>%
- add_recipe(penguins_rec_2)
-```
-
-We need workflows!
-
-```{r}
-# dropping NA first so rows match up later, this is clunky
-pen_sub <- penguins %>%
- drop_na() %>%
- select(-species, -island, -sex)
-
-kmeans_fit <- kmeans_spec %>% fit( ~., pen_sub)
-
-kmeans_fit %>%
- predict(new_data = pen_sub)
-
-### try my new version
-kmeans_fit %>%
- extract_cluster_assignment()
-```
-
-* Needed flexclust install; probably not necessary, we could implement for kmeans with just dists
-
-* Missing values should probably return an NA prediction. Or for k-means, imputation isn't crazy...
-
-* We want a consistent return, and k-means is randomized. We should agree to a consistent default ordering throughout tidyclust - maybe by size (number of members) or something?
-
-
-```{r}
-penguins %>%
- drop_na() %>%
- mutate(
- preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster
- ) %>%
- count(preds, sex, species)
-```
-
-## Diagnostics
-
-* Measure cluster enrichment with demographic variables (an official recipe designation?). Include Chi-Square or ANOVA tests???
-
-* Exploit confusion matrix from `yardstick`
-
-* Characterization of clusters, presumably by centers?
-
-* Automatic plot?!
-
-
-## talk to Emil
-
-* Ordering of clusters is really bothering me
-something with indices
-
-* Cluster density followup: is it kind of a model?
-
-```{r}
-recipe( ~ demo1 + predictor1) %>%
- step_tidyclust(kmeans_fit) # doesn't quite make sense
-
-```
-
-* Some of these followups feel like they need a recipe. Which vars are we using for within SS? Which vars for enrichments? Do we PCA first? etc.
-
-```{r}
-get_SS(Cluster ~ v1 + v2)
-recipe(Cluster ~ v1 + v2) %>%
- step_pca() %>%
- get_ss()
-```
-
-... or maybe fit `enrichment()` on a recipe and it automatically uses the variables with the "enrich" role?
-how would PCA fit in on this?
-
-How does `fit()` access the right variables?
-
-... is this even worth it given we can attach cluster assignments? I say yes, because what if we are trying to "cross-validate".
-
-
-```{r}
-penguins_2 <- penguins %>%
- drop_na() %>%
- mutate(
- preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster
- )
-
-debugonce(enrichment)
-penguins_2 %>%
- enrichment(preds, species)
-```
+---
+title: "k-means Clustering"
+output: rmarkdown::html_vignette
+vignette: >
+ %\VignetteIndexEntry{k-means Clustering}
+ %\VignetteEngine{knitr::rmarkdown}
+ %\VignetteEncoding{UTF-8}
+---
+
+```{r, include = FALSE}
+knitr::opts_chunk$set(
+ collapse = TRUE,
+ comment = "#>"
+)
+```
+
+```{r setup}
+library(tidyclust)
+library(palmerpenguins)
+library(tidymodels)
+```
+
+## Fit
+
+```{r}
+kmeans_spec <- k_means(k = 5) %>%
+ set_engine("stats")
+
+penguins_rec_1 <- recipe(~ ., data = penguins) %>%
+ update_role(species, island, new_role = "demographic") %>%
+ step_dummy(sex)
+
+
+penguins_rec_2 <- recipe(species ~ ., data = penguins) %>%
+ step_dummy(sex, island)
+
+wflow_1 <- workflow() %>%
+ add_model(kmeans_spec) %>%
+ add_recipe(penguins_rec_1)
+
+
+wflow_2 <- workflow() %>%
+ add_model(kmeans_spec) %>%
+ add_recipe(penguins_rec_2)
+```
+
+We need workflows!
+
+```{r}
+# dropping NA first so rows match up later, this is clunky
+pen_sub <- penguins %>%
+ drop_na() %>%
+ select(-species, -island, -sex)
+
+kmeans_fit <- kmeans_spec %>% fit( ~., pen_sub)
+
+kmeans_fit %>%
+ predict(new_data = pen_sub)
+
+### try my new version
+kmeans_fit %>%
+ extract_cluster_assignment()
+```
+
+* Needed flexclust install; probably not necessary, we could implement for kmeans with just dists
+
+* Missing values should probably return an NA prediction. Or for k-means, imputation isn't crazy...
+
+* We want a consistent return, and k-means is randomized. We should agree to a consistent default ordering throughout tidyclust - maybe by size (number of members) or something?
+
+
+```{r}
+penguins %>%
+ drop_na() %>%
+ mutate(
+ preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster
+ ) %>%
+ count(preds, sex, species)
+```
+
+## Diagnostics
+
+* Measure cluster enrichment with demographic variables (an official recipe designation?). Include Chi-Square or ANOVA tests???
+
+* Exploit confusion matrix from `yardstick`
+
+* Characterization of clusters, presumably by centers?
+
+* Automatic plot?!
+
+
+## talk to Emil
+
+* Ordering of clusters is really bothering me
+something with indices
+
+* Cluster density followup: is it kind of a model?
+
+```{r}
+recipe( ~ demo1 + predictor1) %>%
+ step_tidyclust(kmeans_fit) # doesn't quite make sense
+
+```
+
+* Some of these followups feel like they need a recipe. Which vars are we using for within SS? Which vars for enrichments? Do we PCA first? etc.
+
+```{r}
+get_SS(Cluster ~ v1 + v2)
+recipe(Cluster ~ v1 + v2) %>%
+ step_pca() %>%
+ get_ss()
+```
+
+... or maybe fit `enrichment()` on a recipe and it automatically uses the variables with the "enrich" role?
+how would PCA fit in on this?
+
+How does `fit()` access the right variables?
+
+... is this even worth it given we can attach cluster assignments? I say yes, because what if we are trying to "cross-validate".
+
+
+```{r}
+penguins_2 <- penguins %>%
+ drop_na() %>%
+ mutate(
+ preds = predict(kmeans_fit, new_data = pen_sub)$.pred_cluster
+ )
+
+debugonce(enrichment)
+penguins_2 %>%
+ enrichment(preds, species)
+```
diff --git a/dev/test_hc.R b/dev/test_hc.R
index d3ebe427..b6939152 100644
--- a/dev/test_hc.R
+++ b/dev/test_hc.R
@@ -1,33 +1,33 @@
-library(tidyverse)
-library(celery)
-
-ir <- iris[,-5]
-
-hclust(dist(ir))
-
-bob <- hclust_fit(ir)
-
-hc <- hier_clust(k = 3) %>%
- fit(~ ., data = ir)
-
-
-km <- k_means(k = 3) %>%
- fit(~., data = ir)
-
-
-thing <- tibble(
- km_c = extract_cluster_assignment(km)$.cluster,
- hc_c = extract_cluster_assignment(hc)$.cluster,
- truth = iris$Species
-)
-
-thing %>%
- count(hc_c,truth)
-
-cutree(hc$fit, k = 3)
-
-# hc %>%
-# extract_fit_engine() %>%
-# cutree(k = 3)
-
-## reconcile?
+library(tidyverse)
+library(celery)
+
+ir <- iris[,-5]
+
+hclust(dist(ir))
+
+bob <- hclust_fit(ir)
+
+hc <- hier_clust(k = 3) %>%
+ fit(~ ., data = ir)
+
+
+km <- k_means(k = 3) %>%
+ fit(~., data = ir)
+
+
+thing <- tibble(
+ km_c = extract_cluster_assignment(km)$.cluster,
+ hc_c = extract_cluster_assignment(hc)$.cluster,
+ truth = iris$Species
+)
+
+thing %>%
+ count(hc_c,truth)
+
+cutree(hc$fit, k = 3)
+
+# hc %>%
+# extract_fit_engine() %>%
+# cutree(k = 3)
+
+## reconcile?
diff --git a/dev/test_hclust_predict.R b/dev/test_hclust_predict.R
index a6caee32..700d1b0a 100644
--- a/dev/test_hclust_predict.R
+++ b/dev/test_hclust_predict.R
@@ -1,39 +1,39 @@
-library(tidyverse)
-library(tidyclust)
-#
-# my_mod <- hier_clust(k = 3) %>% fit(~., mtcars)
-#
-# #debugonce(tidyclust:::stats_hier_clust_predict)
-# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
-#
-# my_mod <- hier_clust(k = 3, linkage_method = "single") %>% fit(~., mtcars)
-# my_mod$fit$method
-# translate_tidyclust(hier_clust(k = 3, linkage_method = "single"))
-#
-# #debugonce(tidyclust:::stats_hier_clust_predict)
-# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
-#
-# my_mod <- hier_clust(k = 3, linkage_method = "average") %>% fit(~., mtcars)
-#
-# #debugonce(tidyclust:::stats_hier_clust_predict)
-# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
-#
-# my_mod <- hier_clust(k = 3, linkage_method = "median") %>% fit(~., mtcars)
-#
-# #debugonce(tidyclust:::stats_hier_clust_predict)
-# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
-
-my_mod <- hier_clust(k = 3, linkage_method = "centroid") %>% fit(~., mtcars)
-
-#debugonce(tidyclust:::stats_hier_clust_predict)
-tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
-
-my_mod <- hier_clust(k = 3, linkage_method = "ward.D") %>% fit(~., mtcars)
-# debugonce(extract_fit_summary.hclust)
-# extract_fit_summary.hclust(my_mod)
-
-#debugonce(tidyclust:::stats_hier_clust_predict)
-tidyclust:::stats_hier_clust_predict(my_mod$fit, mtcars)
-predict(my_mod, mtcars)
-
-avg_silhouette(my_mod)
+library(tidyverse)
+library(tidyclust)
+#
+# my_mod <- hier_clust(k = 3) %>% fit(~., mtcars)
+#
+# #debugonce(tidyclust:::stats_hier_clust_predict)
+# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
+#
+# my_mod <- hier_clust(k = 3, linkage_method = "single") %>% fit(~., mtcars)
+# my_mod$fit$method
+# translate_tidyclust(hier_clust(k = 3, linkage_method = "single"))
+#
+# #debugonce(tidyclust:::stats_hier_clust_predict)
+# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
+#
+# my_mod <- hier_clust(k = 3, linkage_method = "average") %>% fit(~., mtcars)
+#
+# #debugonce(tidyclust:::stats_hier_clust_predict)
+# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
+#
+# my_mod <- hier_clust(k = 3, linkage_method = "median") %>% fit(~., mtcars)
+#
+# #debugonce(tidyclust:::stats_hier_clust_predict)
+# tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
+
+my_mod <- hier_clust(k = 3, linkage_method = "centroid") %>% fit(~., mtcars)
+
+#debugonce(tidyclust:::stats_hier_clust_predict)
+tidyclust:::stats_hier_clust_predict(my_mod, mtcars)
+
+my_mod <- hier_clust(k = 3, linkage_method = "ward.D") %>% fit(~., mtcars)
+# debugonce(extract_fit_summary.hclust)
+# extract_fit_summary.hclust(my_mod)
+
+#debugonce(tidyclust:::stats_hier_clust_predict)
+tidyclust:::stats_hier_clust_predict(my_mod$fit, mtcars)
+predict(my_mod, mtcars)
+
+avg_silhouette(my_mod)
diff --git a/dev/to do b/dev/to do
index 2fe0da67..91e9e24c 100644
--- a/dev/to do
+++ b/dev/to do
@@ -1,21 +1,21 @@
-* similarity metrics
-- for comparison to external variable
-- for stability between runs
-- for reconciliation
-
-
-metrics for hclust
-parameter object ofr hclust
-
-+ sticker
-
-kmeans tutorial - autoplot!
-hclust tutorial
-
-tuning
-
-predicting and reconciling
-
-
-later:
-* augment with response variable, then compute metrics
+* similarity metrics
+- for comparison to external variable
+- for stability between runs
+- for reconciliation
+
+
+metrics for hclust
+parameter object ofr hclust
+
++ sticker
+
+kmeans tutorial - autoplot!
+hclust tutorial
+
+tuning
+
+predicting and reconciling
+
+
+later:
+* augment with response variable, then compute metrics
diff --git a/man/augment.Rd b/man/augment.Rd
index 4c27a1d6..8d197956 100644
--- a/man/augment.Rd
+++ b/man/augment.Rd
@@ -1,31 +1,31 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/augment.R
-\name{augment.cluster_fit}
-\alias{augment.cluster_fit}
-\title{Augment data with predictions}
-\usage{
-\method{augment}{cluster_fit}(x, new_data, ...)
-}
-\arguments{
-\item{x}{A \code{cluster_fit} object produced by \code{\link[=fit.cluster_spec]{fit.cluster_spec()}} or
-\code{\link[=fit_xy.cluster_spec]{fit_xy.cluster_spec()}} .}
-
-\item{new_data}{A data frame or matrix.}
-
-\item{...}{Not currently used.}
-}
-\description{
-\code{augment()} will add column(s) for predictions to the given data.
-}
-\details{
-For partition models, a \code{.pred_cluster} column is added.
-}
-\examples{
-kmeans_spec <- k_means(num_clusters = 5) \%>\%
- set_engine("stats")
-
-kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-
-kmeans_fit \%>\%
- augment(new_data = mtcars)
-}
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/augment.R
+\name{augment.cluster_fit}
+\alias{augment.cluster_fit}
+\title{Augment data with predictions}
+\usage{
+\method{augment}{cluster_fit}(x, new_data, ...)
+}
+\arguments{
+\item{x}{A \code{cluster_fit} object produced by \code{\link[=fit.cluster_spec]{fit.cluster_spec()}} or
+\code{\link[=fit_xy.cluster_spec]{fit_xy.cluster_spec()}} .}
+
+\item{new_data}{A data frame or matrix.}
+
+\item{...}{Not currently used.}
+}
+\description{
+\code{augment()} will add column(s) for predictions to the given data.
+}
+\details{
+For partition models, a \code{.pred_cluster} column is added.
+}
+\examples{
+kmeans_spec <- k_means(num_clusters = 5) \%>\%
+ set_engine("stats")
+
+kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+kmeans_fit \%>\%
+ augment(new_data = mtcars)
+}
diff --git a/man/avg_silhouette.Rd b/man/avg_silhouette.Rd
index f044428e..604a2d33 100644
--- a/man/avg_silhouette.Rd
+++ b/man/avg_silhouette.Rd
@@ -1,55 +1,55 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/metric-silhouette.R
-\name{avg_silhouette}
-\alias{avg_silhouette}
-\alias{avg_silhouette.cluster_fit}
-\alias{avg_silhouette.workflow}
-\alias{avg_silhouette_vec}
-\title{Measures average silhouette across all observations}
-\usage{
-avg_silhouette(object, ...)
-
-\method{avg_silhouette}{cluster_fit}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...)
-
-\method{avg_silhouette}{workflow}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...)
-
-avg_silhouette_vec(
- object,
- new_data = NULL,
- dists = NULL,
- dist_fun = Rfast::Dist,
- ...
-)
-}
-\arguments{
-\item{object}{A fitted kmeans tidyclust model}
-
-\item{...}{Other arguments passed to methods.}
-
-\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.}
-
-\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.}
-
-\item{dist_fun}{A function for calculating distances between observations.
-Defaults to Euclidean distance on processed data.}
-}
-\value{
-A double; the average silhouette.
-}
-\description{
-Measures average silhouette across all observations
-}
-\examples{
-kmeans_spec <- k_means(num_clusters = 5) \%>\%
- set_engine("stats")
-
-kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-
-dists <- mtcars \%>\%
- as.matrix() \%>\%
- dist()
-
-avg_silhouette(kmeans_fit, dists = dists)
-
-avg_silhouette_vec(kmeans_fit, dists = dists)
-}
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/metric-silhouette.R
+\name{avg_silhouette}
+\alias{avg_silhouette}
+\alias{avg_silhouette.cluster_fit}
+\alias{avg_silhouette.workflow}
+\alias{avg_silhouette_vec}
+\title{Measures average silhouette across all observations}
+\usage{
+avg_silhouette(object, ...)
+
+\method{avg_silhouette}{cluster_fit}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...)
+
+\method{avg_silhouette}{workflow}(object, new_data = NULL, dists = NULL, dist_fun = NULL, ...)
+
+avg_silhouette_vec(
+ object,
+ new_data = NULL,
+ dists = NULL,
+ dist_fun = Rfast::Dist,
+ ...
+)
+}
+\arguments{
+\item{object}{A fitted kmeans tidyclust model}
+
+\item{...}{Other arguments passed to methods.}
+
+\item{new_data}{A dataset to predict on. If \code{NULL}, uses trained clustering.}
+
+\item{dists}{A distance matrix. Used if \code{new_data} is \code{NULL}.}
+
+\item{dist_fun}{A function for calculating distances between observations.
+Defaults to Euclidean distance on processed data.}
+}
+\value{
+A double; the average silhouette.
+}
+\description{
+Measures average silhouette across all observations
+}
+\examples{
+kmeans_spec <- k_means(num_clusters = 5) \%>\%
+ set_engine("stats")
+
+kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+dists <- mtcars \%>\%
+ as.matrix() \%>\%
+ dist()
+
+avg_silhouette(kmeans_fit, dists = dists)
+
+avg_silhouette_vec(kmeans_fit, dists = dists)
+}
diff --git a/man/extract_centroids.Rd b/man/extract_centroids.Rd
index 330f0909..8a1d03c7 100644
--- a/man/extract_centroids.Rd
+++ b/man/extract_centroids.Rd
@@ -1,26 +1,26 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/extract_characterization.R
-\name{extract_centroids}
-\alias{extract_centroids}
-\title{Extract clusters from model}
-\usage{
-extract_centroids(object, ...)
-}
-\arguments{
-\item{object}{An cluster_spec object.}
-
-\item{...}{Other arguments passed to methods.}
-}
-\description{
-Extract clusters from model
-}
-\examples{
-set.seed(1234)
-kmeans_spec <- k_means(num_clusters = 5) \%>\%
- set_engine("stats")
-
-kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-
-kmeans_fit \%>\%
- extract_centroids()
-}
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/extract_characterization.R
+\name{extract_centroids}
+\alias{extract_centroids}
+\title{Extract clusters from model}
+\usage{
+extract_centroids(object, ...)
+}
+\arguments{
+\item{object}{An cluster_spec object.}
+
+\item{...}{Other arguments passed to methods.}
+}
+\description{
+Extract clusters from model
+}
+\examples{
+set.seed(1234)
+kmeans_spec <- k_means(num_clusters = 5) \%>\%
+ set_engine("stats")
+
+kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+kmeans_fit \%>\%
+ extract_centroids()
+}
diff --git a/man/extract_cluster_assignment.Rd b/man/extract_cluster_assignment.Rd
index 1d8eccbd..29fae8b8 100644
--- a/man/extract_cluster_assignment.Rd
+++ b/man/extract_cluster_assignment.Rd
@@ -1,25 +1,25 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/extract_assignment.R
-\name{extract_cluster_assignment}
-\alias{extract_cluster_assignment}
-\title{Extract cluster assignments from model}
-\usage{
-extract_cluster_assignment(object, ...)
-}
-\arguments{
-\item{object}{An cluster_spec object.}
-
-\item{...}{Other arguments passed to methods.}
-}
-\description{
-Extract cluster assignments from model
-}
-\examples{
-kmeans_spec <- k_means(num_clusters = 5) \%>\%
- set_engine("stats")
-
-kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-
-kmeans_fit \%>\%
- extract_cluster_assignment()
-}
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/extract_assignment.R
+\name{extract_cluster_assignment}
+\alias{extract_cluster_assignment}
+\title{Extract cluster assignments from model}
+\usage{
+extract_cluster_assignment(object, ...)
+}
+\arguments{
+\item{object}{An cluster_spec object.}
+
+\item{...}{Other arguments passed to methods.}
+}
+\description{
+Extract cluster assignments from model
+}
+\examples{
+kmeans_spec <- k_means(num_clusters = 5) \%>\%
+ set_engine("stats")
+
+kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+kmeans_fit \%>\%
+ extract_cluster_assignment()
+}
diff --git a/man/extract_fit_summary.Rd b/man/extract_fit_summary.Rd
index 56788562..9df27b2e 100644
--- a/man/extract_fit_summary.Rd
+++ b/man/extract_fit_summary.Rd
@@ -1,28 +1,28 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/extract_summary.R
-\name{extract_fit_summary}
-\alias{extract_fit_summary}
-\title{S3 method to get fitted model summary info depending on engine}
-\usage{
-extract_fit_summary(object, ...)
-}
-\arguments{
-\item{object}{a fitted cluster_spec object}
-
-\item{...}{other arguments passed to methods}
-}
-\value{
-A list with various summary elements
-}
-\description{
-S3 method to get fitted model summary info depending on engine
-}
-\examples{
-kmeans_spec <- k_means(num_clusters = 5) \%>\%
- set_engine("stats")
-
-kmeans_fit <- fit(kmeans_spec, ~., mtcars)
-
-kmeans_fit \%>\%
- extract_fit_summary()
-}
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/extract_summary.R
+\name{extract_fit_summary}
+\alias{extract_fit_summary}
+\title{S3 method to get fitted model summary info depending on engine}
+\usage{
+extract_fit_summary(object, ...)
+}
+\arguments{
+\item{object}{a fitted cluster_spec object}
+
+\item{...}{other arguments passed to methods}
+}
+\value{
+A list with various summary elements
+}
+\description{
+S3 method to get fitted model summary info depending on engine
+}
+\examples{
+kmeans_spec <- k_means(num_clusters = 5) \%>\%
+ set_engine("stats")
+
+kmeans_fit <- fit(kmeans_spec, ~., mtcars)
+
+kmeans_fit \%>\%
+ extract_fit_summary()
+}
diff --git a/man/figures/logo.svg b/man/figures/logo.svg
index 045b23b0..6129b854 100644
--- a/man/figures/logo.svg
+++ b/man/figures/logo.svg
@@ -1,646 +1,646 @@
-
-
-