Skip to content

Commit 8b387ea

Browse files
committed
test that tune_cluster works with recipes with id variables
1 parent 9ab8532 commit 8b387ea

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/testthat/test-tune_cluster.R

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,42 @@ test_that("select_best() and show_best() works", {
471471
dplyr::select(num_clusters, .config)
472472
)
473473
})
474+
475+
test_that("doesn't error if recipes uses id variables", {
476+
helper_objects <- helper_objects_tidyclust()
477+
478+
mtcars_id <- mtcars %>%
479+
tibble::rownames_to_column(var = "model")
480+
481+
rec_id <- recipes::recipe(~., data = mtcars_id) %>%
482+
recipes::update_role(model, new_role = "id variable") %>%
483+
recipes::step_normalize(recipes::all_numeric_predictors())
484+
485+
set.seed(4400)
486+
wflow <- workflows::workflow() %>%
487+
workflows::add_recipe(rec_id) %>%
488+
workflows::add_model(helper_objects$kmeans_mod)
489+
pset <- hardhat::extract_parameter_set_dials(wflow) %>%
490+
update(num_clusters = dials::num_clusters(c(1, 3)))
491+
grid <- dials::grid_regular(pset, levels = 3)
492+
folds <- rsample::vfold_cv(mtcars_id, v = 2)
493+
control <- tune::control_grid(extract = identity)
494+
metrics <- cluster_metric_set(sse_within_total, sse_total)
495+
496+
res <- tune_cluster(
497+
wflow,
498+
resamples = folds,
499+
grid = grid,
500+
control = control,
501+
metrics = metrics
502+
)
503+
res_est <- tune::collect_metrics(res)
504+
res_workflow <- res$.extracts[[1]]$.extracts[[1]]
505+
506+
expect_equal(res$id, folds$id)
507+
expect_equal(nrow(res_est), nrow(grid) * 2)
508+
expect_equal(sum(res_est$.metric == "sse_total"), nrow(grid))
509+
expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid))
510+
expect_equal(res_est$n, rep(2, nrow(grid) * 2))
511+
expect_true(res_workflow$trained)
512+
})

0 commit comments

Comments
 (0)