diff --git a/NAMESPACE b/NAMESPACE index 2f4f72e5..8b0dc284 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -43,6 +43,9 @@ S3method(tidy,cluster_fit) S3method(translate_tidyclust,default) S3method(translate_tidyclust,hier_clust) S3method(translate_tidyclust,k_means) +S3method(tunable,cluster_spec) +S3method(tunable,k_means) +S3method(tune_args,cluster_spec) S3method(tune_cluster,cluster_spec) S3method(tune_cluster,default) S3method(tune_cluster,workflow) @@ -111,6 +114,8 @@ importFrom(generics,glance) importFrom(generics,min_grid) importFrom(generics,required_pkgs) importFrom(generics,tidy) +importFrom(generics,tunable) +importFrom(generics,tune_args) importFrom(hardhat,extract_fit_engine) importFrom(hardhat,extract_fit_parsnip) importFrom(hardhat,extract_parameter_set_dials) diff --git a/NEWS.md b/NEWS.md index 83e1d09d..3828867a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # tidyclust (development version) +* The cluster specification methods for `generics::tune_args()` and `generics::tunable()` are now registered unconditionally (#115). + # tidyclust 0.1.1 * Fixed bug where `extract_cluster_assignment()` and `predict()` sometimes didn't have agreement of clusters. (#94) diff --git a/R/tidyclust-package.R b/R/tidyclust-package.R index 45f9d307..523796ad 100644 --- a/R/tidyclust-package.R +++ b/R/tidyclust-package.R @@ -3,6 +3,7 @@ ## usethis namespace: start #' @importFrom dplyr bind_cols +#' @importFrom generics tunable tune_args #' @importFrom parsnip make_call #' @importFrom parsnip maybe_data_frame #' @importFrom parsnip maybe_matrix diff --git a/R/tunable.R b/R/tunable.R index 5b6d7b33..acfab565 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -1,7 +1,5 @@ -# Lazily registered in .onLoad() -# Unit tests are in extratests -# nocov start -tunable_cluster_spec <- function(x, ...) { +#' @export +tunable.cluster_spec <- function(x, ...) { mod_env <- rlang::ns_env("modelenv")$modelenv if (is.null(x$engine)) { @@ -57,8 +55,8 @@ add_engine_parameters <- function(pset, engines) { pset } -# Lazily registered in .onLoad() -tunable_k_means <- function(x, ...) { +#' @export +tunable.k_means <- function(x, ...) { res <- NextMethod() if (x$engine == "stats") { res <- add_engine_parameters(res, stats_k_means_engine_args) @@ -78,5 +76,3 @@ stats_k_means_engine_args <- component = "k_means", component_id = "engine" ) - -# nocov end diff --git a/R/tune_args.R b/R/tune_args.R index 610abd31..1ff24ad1 100644 --- a/R/tune_args.R +++ b/R/tune_args.R @@ -1,5 +1,5 @@ -# Lazily registered in .onLoad() -tune_args_cluster_spec <- function(object, full = FALSE, ...) { +#' @export +tune_args.cluster_spec <- function(object, full = FALSE, ...) { # use the cluster_spec top level class as the id cluster_type <- class(object)[1] diff --git a/R/zzz.R b/R/zzz.R index 5d1b9f44..44938d0c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -7,31 +7,6 @@ s3_register("generics::required_pkgs", "cluster_fit") s3_register("generics::required_pkgs", "cluster_spec") - # - If tune isn't installed, register the method (`packageVersion()` will error here) - # - If tune >= 0.1.6.9001 is installed, register the method - should_register_tune_args_method <- tryCatch( - expr = utils::packageVersion("tune") >= "0.1.6.9001", - error = function(cnd) TRUE - ) - - if (should_register_tune_args_method) { - # `tune_args.cluster_spec()` moved from tune to parsnip - vctrs::s3_register("generics::tune_args", "cluster_spec", tune_args_cluster_spec) - } - - # - If tune isn't installed, register the method (`packageVersion()` will error here) - # - If tune >= 0.1.6.9002 is installed, register the method - should_register_tunable_method <- tryCatch( - expr = utils::packageVersion("tune") >= "0.1.6.9002", - error = function(cnd) TRUE - ) - - if (should_register_tunable_method) { - # `tunable.cluster_spec()` and friends moved from tune to parsnip - vctrs::s3_register("generics::tunable", "cluster_spec", tunable_cluster_spec) - vctrs::s3_register("generics::tunable", "k_means", tunable_k_means) - } - ns <- rlang::ns_env("tidyclust") makeActiveBinding( "tidyclust_color", @@ -83,7 +58,6 @@ ) } - # vctrs:::s3_register() s3_register <- function(generic, class, method = NULL) { stopifnot(is.character(generic), length(generic) == 1)