Merge pull request #2794 from mlr-org/fix2771__factor_na
pat-s authored Oct 28, 2020
2 parents 4cb6ab9 + bfc2cf9 commit 7d2138f
Showing 6 changed files with 45 additions and 7 deletions.
4 changes: 3 additions & 1 deletion NEWS.md
@@ -1,6 +1,8 @@
# mlr 2.18.0.9000

- Internal changes only.
- Warning if `fix.factors.prediction = TRUE` causes the generation of NAs for new factor levels in prediction.
- Clear error message if the prediction of a wrapped learner does not have the same length as `newdata`.
- Internal changes.


# mlr 2.18.0
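For context, a minimal sketch of how the new warning from `fix.factors.prediction = TRUE` surfaces for a user (illustrative only, not part of this PR; assumes `classif.rpart` is available):

```r
library(mlr)

# train on data where feature "a" only has the levels "a" and "b"
d = data.frame(
  a = factor(c("a", "b", "a", "b")),
  trg = factor(c("x", "y", "x", "y"))
)
task = makeClassifTask(data = d, target = "trg")
lrn = makeLearner("classif.rpart", fix.factors.prediction = TRUE)
model = train(lrn, task)

# newdata contains the unseen level "c"; re-leveling to the training levels
# turns it into NA, which now triggers the warning added in this PR
predict(model, newdata = data.frame(a = factor(c("a", "c"))))
```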
6 changes: 6 additions & 0 deletions R/predict.R
@@ -144,6 +144,12 @@ predict.WrappedModel = function(object, task, newdata, subset = NULL, ...) {
dump = addClasses(get("last.dump", envir = .GlobalEnv), "mlr.dump")
}
}
# did the prediction fail otherwise?
np = nrow(p)
if (is.null(np)) np = length(p)
if (np != nrow(newdata)) {
stopf("predictLearner for %s has returned %i predictions instead of %i!", learner$id, np, nrow(newdata))
}
}
if (missing(task)) {
ids = NULL
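The `nrow()`/`length()` fallback in this check exists because `predictLearner()` can return either a matrix/data.frame (e.g. posterior probabilities) or a plain vector/factor of labels; a small base-R illustration of the distinction (not part of the diff):

```r
# matrix or data.frame predictions (e.g. class probabilities): nrow() works
p_prob = matrix(runif(6), nrow = 3)
nrow(p_prob)     # 3

# plain label predictions: nrow() is NULL, so fall back to length()
p_label = factor(c("a", "b", "a"))
nrow(p_label)    # NULL
length(p_label)  # 3

np = nrow(p_label)
if (is.null(np)) np = length(p_label)
np               # 3 -> compared against nrow(newdata)
```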
9 changes: 7 additions & 2 deletions R/predictLearner.R
@@ -55,8 +55,13 @@ predictLearner2 = function(.learner, .model, .newdata, ...) {
ns = intersect(colnames(.newdata), ns)
fls = fls[ns]
if (length(ns) > 0L) {
.newdata[ns] = mapply(factor, x = .newdata[ns],
levels = fls, SIMPLIFY = FALSE)
safe_factor = function(x, levels) {
if (length(setdiff(levels(x), levels)) > 0) {
warning("fix.factors.prediction = TRUE produced NAs because of new factor levels in prediction data.")
}
factor(x, levels)
}
.newdata[ns] = mapply(safe_factor, x = .newdata[ns], levels = fls, SIMPLIFY = FALSE)
}
}
p = predictLearner(.learner, .model, .newdata, ...)
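The new `safe_factor()` helper warns before `factor(x, levels)` silently turns unseen levels into `NA`; a standalone base-R illustration of the behavior it guards against (not part of the diff):

```r
train_levels = c("a", "b")
x_new = factor(c("a", "b", "c"))      # "c" never occurred during training

# re-leveling to the training levels silently maps "c" to NA
factor(x_new, levels = train_levels)
#> [1] a    b    <NA>

# this is the condition safe_factor() detects and warns about
setdiff(levels(x_new), train_levels)
#> [1] "c"
```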
2 changes: 1 addition & 1 deletion man/getFeatureImportance.Rd

Some generated files are not rendered by default.

13 changes: 10 additions & 3 deletions tests/testthat/test_base_resample.R
@@ -186,7 +186,8 @@ test_that("resample printer respects show.info", {
})

test_that("resample drops unseen factors in predict data set", {
data = data.frame(a = c("a", "b", "a", "b", "a", "c"),
data = data.frame(
a = c("a", "b", "a", "b", "a", "c"),
b = c(1, 1, 2, 2, 2, 1),
trg = c("a", "b", "a", "b", "a", "b"),
stringsAsFactors = TRUE)
@@ -202,6 +203,12 @@ test_that("resample drops unseen factors in predict data set", {

lrn = makeLearner("classif.logreg", fix.factors.prediction = TRUE)
model = train(lrn, subsetTask(task, 1:4))
predict(model, subsetTask(task, 5:6))
resample(lrn, task, resinst)
expect_warning(predict(model, subsetTask(task, 5:6)), "produced NAs because of new factor levels")
expect_warning(resample(lrn, task, resinst), "produced NAs because of new factor levels")

# do it manually
train_task = makeClassifTask("unseen.factors", data[1:4, ], "trg", fixup = "quiet") # quiet because
# we get a dropped-factors warning (which we want here)
model = train(lrn, train_task)
expect_warning(predict(model, newdata = data[5:6,]), "produced NAs because of new factor levels")
})
18 changes: 18 additions & 0 deletions tests/testthat/test_classif_ksvm.R
@@ -48,3 +48,21 @@ test_that("classif_ksvm", {
testCV("classif.ksvm", multiclass.df, multiclass.target, tune.train = tt, tune.predict = tp,
parset = list(kernel = "polydot", degree = 3, offset = 2, scale = 1.5))
})

test_that("classif_ksvm produces error for new factor levels on predict", {
# https://github.com/mlr-org/mlr/issues/2771
train_data = data.frame(
A = sample(c("A","B"), 10, TRUE),
B = factor(sample(c("A", "B"), 10, replace = T))
)
test_data = data.frame(
A = sample(c("A","B"), 10, TRUE),
B = factor(sample(c("A", "B","C"), 10, replace = T))
)
lrn = makeLearner("classif.ksvm", fix.factors.prediction = TRUE)
train_task = makeClassifTask(data = train_data, target = "A")
model = train(lrn, train_task)
expect_warning({
expect_error(predict(model, newdata = test_data), "has returned .+ instead of 10")
}, "produced NAs because of new factor levels")
})
