Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: offset column role in Task #1225

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ Authors@R:
comment = c(ORCID = "0000-0002-8115-0400")),
person("Sebastian", "Fischer", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0002-9609-3197")),
person("Lona", "Koers", , "[email protected]", role = "ctb")
person("Lona", "Koers", , "[email protected]", role = "ctb"),
person("John", "Zobolas", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0002-3609-8674"))
)
Description: Efficient, object-oriented programming on the
building blocks of machine learning. Provides 'R6' objects for tasks,
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# mlr3 (development version)

* feat: add new column role `"offset"` (in `$col_roles`) to `Task`.
* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
Expand Down
28 changes: 24 additions & 4 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ Task = R6Class("Task",
#' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (role `"group"`).
#' * `"weights"`: The task comes with observation weights (role `"weight"`).
#' * `"offset"`: The task includes an offset column specifying fixed adjustments for model training (role `"offset"`).
#' * `"ordered"`: The task has columns which define the row order (role `"order"`).
#'
#' Note that above listed properties are calculated from the `$col_roles` and may not be set explicitly.
Expand All @@ -907,6 +908,7 @@ Task = R6Class("Task",
if (length(col_roles$group)) "groups" else NULL,
if (length(col_roles$stratum)) "strata" else NULL,
if (length(col_roles$weight)) "weights" else NULL,
if (length(col_roles$offset)) "offset" else NULL,
if (length(col_roles$order)) "ordered" else NULL
)
} else {
Expand Down Expand Up @@ -951,6 +953,10 @@ Task = R6Class("Task",
#' Not more than a single column can be associated with this role.
#' * `"stratum"`: Stratification variables. Multiple discrete columns may have this role.
#' * `"weight"`: Observation weights. Not more than one numeric column may have this role.
#' * `"offset"`: Offset values specifying fixed adjustments for model training.
#' These values can be used to provide baseline predictions from an existing model for updating another model.
#' Some learners require an offset for each target class in a multiclass setting.
#' In this case, each offset column must be named `"offset_"` followed by the name of a target class (e.g. `"offset_Adelie"`).
#'
#' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
Expand Down Expand Up @@ -1250,6 +1256,11 @@ task_check_col_roles.Task = function(task, new_roles, ...) {
}
}

# check offset
if (length(new_roles[["offset"]]) && any(fget(task$col_info, new_roles[["offset"]], "type", key = "id") %nin% c("numeric", "integer"))) {
stopf("Offset column(s) %s must be a numeric or integer column", paste0("'", new_roles[["offset"]], "'", collapse = ","))
}

return(new_roles)
}

Expand All @@ -1266,16 +1277,25 @@ task_check_col_roles.TaskClassif = function(task, new_roles, ...) {
stopf("Target column(s) %s must be a factor or ordered factor", paste0("'", new_roles[["target"]], "'", collapse = ","))
}

if (length(new_roles[["offset"]]) > 1L && length(task$class_names) == 2L) {
stop("There may only be up to one column with role 'offset' for binary classification tasks")
}

if (length(new_roles[["offset"]]) > 1L) {
expected_names = paste0("offset_", task$class_names)
expect_subset(new_roles[["offset"]], expected_names)
}

NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskRegr = function(task, new_roles, ...) {

# check target
if (length(new_roles[["target"]]) > 1L) {
stopf("There may only be up to one column with role 'target'")
for (role in c("target", "offset")) {
if (length(new_roles[[role]]) > 1L) {
stopf("There may only be up to one column with role '%s'", role)
}
}

if (length(new_roles[["target"]]) && any(fget(task$col_info, new_roles[["target"]], "type", key = "id") %nin% c("numeric", "integer"))) {
Expand Down
6 changes: 3 additions & 3 deletions R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ local({
"use"
)

tmp = c("feature", "target", "name", "order", "stratum", "group", "weight")
tmp = c("feature", "target", "name", "order", "stratum", "group", "weight", "offset")
mlr_reflections$task_col_roles = list(
regr = tmp,
classif = tmp,
unsupervised = c("feature", "name", "order")
)

tmp = c("strata", "groups", "weights")
tmp = c("strata", "groups", "weights", "offset")
mlr_reflections$task_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp,
Expand All @@ -114,7 +114,7 @@ local({

mlr_reflections$task_print_col_roles = list(
before = character(),
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight")
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight", "Offset" = "offset")
)

### Learner
Expand Down
5 changes: 5 additions & 0 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr3-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,18 @@ test_that("stratify works", {
})

test_that("groups/weights work", {
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20),
o = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
task = TaskRegr$new("test", b, target = "y")
task$set_row_roles(16:20, character())

expect_false("groups" %chin% task$properties)
expect_false("weights" %chin% task$properties)
expect_false("offset" %chin% task$properties)
expect_null(task$groups)
expect_null(task$weights)

# weight
task$col_roles$weight = "w"
expect_subset("weights", task$properties)
expect_data_table(task$weights, ncols = 2, nrows = 15)
Expand All @@ -265,6 +268,7 @@ test_that("groups/weights work", {
task$col_roles$weight = character()
expect_true("weights" %nin% task$properties)

# group
task$col_roles$group = "g"
expect_subset("groups", task$properties)
expect_data_table(task$groups, ncols = 2, nrows = 15)
Expand Down Expand Up @@ -726,3 +730,4 @@ test_that("warn when internal valid task has 0 obs", {
task = tsk("iris")
expect_warning({task$internal_valid_task = 151}, "has 0 observations")
})

30 changes: 30 additions & 0 deletions tests/testthat/test_TaskClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,33 @@ test_that("target is encoded as factor (#629)", {
dt$target = ordered(dt$target)
TaskClassif$new(id = "XX", backend = dt, target = "target")
})

test_that("offset column role works with binary tasks", {
  # Assigning one column the offset role should surface the derived
  # "offset" task property.
  pima = tsk("pima")
  pima$set_col_roles("glucose", "offset")
  expect_subset("offset", pima$properties)

  # More than one offset column must be rejected for this task.
  expect_error(
    {
      pima$col_roles$offset = c("glucose", "diabetes")
    },
    "There may only be up to one column with role"
  )
})

test_that("offset column role works with multiclass tasks", {
  # A single offset column switches on the "offset" property.
  penguins = tsk("penguins")
  penguins$set_col_roles("body_mass", "offset")
  expect_subset("offset", penguins$properties)

  # Several offset columns whose names do not follow the
  # "offset_" + class name pattern are rejected.
  expect_error(
    {
      penguins$col_roles$offset = c("body_mass", "flipper_length")
    },
    "Must be a subset of"
  )

  # Columns named after target classes may take the offset role together.
  dt = tsk("penguins")$data()
  offset_cols = c("offset_Adelie", "offset_Chinstrap")
  for (col in offset_cols) {
    set(dt, j = col, value = runif(nrow(dt)))
  }
  penguins = as_task_classif(dt, target = "species")
  penguins$set_col_roles(offset_cols, "offset")

  expect_subset("offset", penguins$properties)
})
15 changes: 15 additions & 0 deletions tests/testthat/test_TaskRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,18 @@ test_that("$add_strata", {
task$add_strata(task$target_names, bins = 2)
expect_identical(task$strata$N, c(50L, 10L))
})

test_that("offset column role works", {
  # Setting an offset column exposes the "offset" property.
  mtcars_task = tsk("mtcars")
  mtcars_task$set_col_roles("am", "offset")
  expect_subset("offset", mtcars_task$properties)

  # Regression tasks accept at most one offset column.
  expect_error(
    {
      mtcars_task$col_roles$offset = c("am", "gear")
    },
    "up to one"
  )

  # Clearing the role removes the derived property again.
  mtcars_task$col_roles$offset = character()
  expect_true("offset" %nin% mtcars_task$properties)
})
Loading