Skip to content

Commit 21019d9

Browse files
Merge pull request #168 from tidymodels/fix124
2 parents 339ad8f + a79e0d6 commit 21019d9

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
* Fixed bug where levels didn't match the number of clusters when predicting on a smaller number of observations. (#158)
2626

27+
* Fixed bug where `tune_cluster()` would error if used with a recipe that contained non-predictor variables such as id variables. (#124)
28+
2729
# tidyclust 0.1.2
2830

2931
* The cluster specification methods for `generics::tune_args()` and `generics::tunable()` are now registered unconditionally (#115).

R/metric-aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ extract_post_preprocessor <- function(object, new_data) {
234234
} else if (inherits(preprocessor, "recipe")) {
235235
new_data <- object %>%
236236
hardhat::extract_recipe() %>%
237-
recipes::bake(new_data)
237+
recipes::bake(new_data, recipes::all_predictors())
238238
}
239239
new_data
240240
}

tests/testthat/test-tune_cluster.R

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,42 @@ test_that("select_best() and show_best() works", {
471471
dplyr::select(num_clusters, .config)
472472
)
473473
})
474+
475+
test_that("doesn't error if recipes uses id variables", {
476+
helper_objects <- helper_objects_tidyclust()
477+
478+
mtcars_id <- mtcars %>%
479+
tibble::rownames_to_column(var = "model")
480+
481+
rec_id <- recipes::recipe(~., data = mtcars_id) %>%
482+
recipes::update_role(model, new_role = "id variable") %>%
483+
recipes::step_normalize(recipes::all_numeric_predictors())
484+
485+
set.seed(4400)
486+
wflow <- workflows::workflow() %>%
487+
workflows::add_recipe(rec_id) %>%
488+
workflows::add_model(helper_objects$kmeans_mod)
489+
pset <- hardhat::extract_parameter_set_dials(wflow) %>%
490+
update(num_clusters = dials::num_clusters(c(1, 3)))
491+
grid <- dials::grid_regular(pset, levels = 3)
492+
folds <- rsample::vfold_cv(mtcars_id, v = 2)
493+
control <- tune::control_grid(extract = identity)
494+
metrics <- cluster_metric_set(sse_within_total, sse_total)
495+
496+
res <- tune_cluster(
497+
wflow,
498+
resamples = folds,
499+
grid = grid,
500+
control = control,
501+
metrics = metrics
502+
)
503+
res_est <- tune::collect_metrics(res)
504+
res_workflow <- res$.extracts[[1]]$.extracts[[1]]
505+
506+
expect_equal(res$id, folds$id)
507+
expect_equal(nrow(res_est), nrow(grid) * 2)
508+
expect_equal(sum(res_est$.metric == "sse_total"), nrow(grid))
509+
expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid))
510+
expect_equal(res_est$n, rep(2, nrow(grid) * 2))
511+
expect_true(res_workflow$trained)
512+
})

0 commit comments

Comments
 (0)