Skip to content

Commit 339ad8f

Browse files
Merge pull request #166 from tidymodels/one-row-prediction-fix
2 parents ca02908 + 43a052c commit 339ad8f

File tree

7 files changed

+85
-1
lines changed

7 files changed

+85
-1
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
* Engine-specific documentation has been added for all models and engines. (#159)
2424

25+
* Fixed a bug where the levels of the predicted clusters didn't match the number of clusters when predicting on only a few observations, such as a single row. (#158)
26+
2527
# tidyclust 0.1.2
2628

2729
* The cluster specification methods for `generics::tune_args()` and `generics::tunable()` are now registered unconditionally (#115).

R/predict_helpers.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,5 @@ make_predictions <- function(x, prefix, n_clusters) {
150150
}
151151
pred_clusts <- unique(clusters$.cluster)[pred_clusts_num]
152152

153-
return(factor(pred_clusts))
153+
pred_clusts
154154
}

tests/testthat/test-hier_clust-stats.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ test_that("predicting", {
2727
)
2828
})
2929

30+
test_that("all levels are preserved with 1 row predictions", {
31+
set.seed(1234)
32+
spec <- hier_clust(num_clusters = 3) %>%
33+
set_engine("stats")
34+
35+
res <- fit(spec, ~., mtcars)
36+
37+
preds <- predict(res, mtcars[1, ])
38+
39+
expect_identical(
40+
levels(preds$.pred_cluster),
41+
paste0("Cluster_", 1:3)
42+
)
43+
})
44+
3045
test_that("extract_centroids() works", {
3146
set.seed(1234)
3247
spec <- hier_clust(num_clusters = 3) %>%

tests/testthat/test-k_means-clustMixType.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ test_that("predicting", {
4040
)
4141
})
4242

43+
test_that("all levels are preserved with 1 row predictions", {
44+
set.seed(1234)
45+
spec <- k_means(num_clusters = 3) %>%
46+
set_engine("clustMixType")
47+
48+
res <- fit(spec, ~., iris)
49+
50+
preds <- predict(res, iris[1, ])
51+
52+
expect_identical(
53+
levels(preds$.pred_cluster),
54+
paste0("Cluster_", 1:3)
55+
)
56+
})
57+
4358
test_that("extract_centroids() works", {
4459
skip_if_not_installed("clustMixType")
4560

tests/testthat/test-k_means-clusterR.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ test_that("predicting", {
3232
)
3333
})
3434

35+
test_that("all levels are preserved with 1 row predictions", {
36+
set.seed(1234)
37+
spec <- k_means(num_clusters = 3) %>%
38+
set_engine("ClusterR")
39+
40+
res <- fit(spec, ~., mtcars)
41+
42+
preds <- predict(res, mtcars[1, ])
43+
44+
expect_identical(
45+
levels(preds$.pred_cluster),
46+
paste0("Cluster_", 1:3)
47+
)
48+
})
49+
3550
test_that("extract_centroids() works", {
3651
skip_if_not_installed("ClusterR")
3752

tests/testthat/test-k_means-klaR.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,28 @@ test_that("predicting", {
5151
)
5252
})
5353

54+
test_that("all levels are preserved with 1 row predictions", {
55+
skip_if_not_installed("klaR")
56+
skip_if_not_installed("modeldata")
57+
58+
data("ames", package = "modeldata")
59+
60+
ames_cat <- dplyr::select(ames, dplyr::where(is.factor))
61+
62+
set.seed(1234)
63+
spec <- k_means(num_clusters = 3) %>%
64+
set_engine("klaR")
65+
66+
res <- fit(spec, ~., ames_cat)
67+
68+
preds <- predict(res, ames_cat[1, ])
69+
70+
expect_identical(
71+
levels(preds$.pred_cluster),
72+
paste0("Cluster_", 1:3)
73+
)
74+
})
75+
5476
test_that("predicting ties argument works", {
5577
skip_if_not_installed("klaR")
5678

tests/testthat/test-k_means-stats.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ test_that("predicting", {
2727
)
2828
})
2929

30+
test_that("all levels are preserved with 1 row predictions", {
31+
set.seed(1234)
32+
spec <- k_means(num_clusters = 3) %>%
33+
set_engine("stats")
34+
35+
res <- fit(spec, ~., mtcars)
36+
37+
preds <- predict(res, mtcars[1, ])
38+
39+
expect_identical(
40+
levels(preds$.pred_cluster),
41+
paste0("Cluster_", 1:3)
42+
)
43+
})
44+
3045
test_that("extract_centroids() works", {
3146
set.seed(1234)
3247
spec <- k_means(num_clusters = 3) %>%

0 commit comments

Comments
 (0)