Skip to content

Commit 7d7a826

Browse files
committed
WIP split PipeOpEncodePL into two PipeOps, one for each method
1 parent 8a5e162 commit 7d7a826

File tree

2 files changed

+83
-56
lines changed

2 files changed

+83
-56
lines changed

R/PipeOpEncodePL.R

+82-56
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
#' @title Factor Encoding
1+
#' @title Piecewise Linear Encoding Base Class
22
#'
33
#' @usage NULL
4-
#' @name mlr_pipeops_encode
5-
#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
4+
#' @name mlr_pipeops_encodepl
5+
#' @format Abstract [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
66
#'
77
#' @description
8+
#' Abstract base class for piecewise linear encoding.
9+
#'
810
#' Encodes columns of type `numeric` and `integer`.
911
#'
1012
#'
@@ -37,78 +39,39 @@
3739
#' Initialized to `""`. One of:
3840
#'
3941
#' @section Methods:
40-
#' Only methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
42+
#' Methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`], as well as
43+
#' * `.get_bins(task, cols)`\cr
44+
#' ([`Task`][mlr3::Task], `character`) -> `list` \cr
45+
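#'   Abstract private method that inheriting classes must implement; its return value is stored as the `bins` element of the `$state`.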
4146
#'
4247
#' @references
4348
#' `r format_bib("gorishniy_2022")`
4449
#'
4550
#' @family PipeOps
51+
#' @family PipeOpsPLE
4652
#' @template seealso_pipeopslist
4753
#' @include PipeOpTaskPreproc.R
4854
#' @export
49-
#' @examples
50-
#' library("mlr3")
51-
#'
5255
PipeOpEncodePL = R6Class("PipeOpEncodePL",
5356
inherit = PipeOpTaskPreprocSimple,
5457
public = list(
55-
initialize = function(task_type, id = "encodepl", param_vals = list()) {
56-
# NOTE: Might use different name, change assert, and conditions
57-
assert_choice(task_type, mlr_reflections$task_types$task)
58-
if (task_type == "TaskRegr") {
59-
private$.tree_learner = LearnerRegrRpart$new()
60-
} else if (task_type == "TaskClassif") {
61-
private$.tree_learner = LearnerClassifRpart$new()
62-
} else {
63-
stopf("Task type %s not supported", task_type)
64-
}
65-
66-
private$.encodepl_param_set = ps(
67-
method = p_fct(levels = c("quantiles", "tree"), tags = c("train", "predict", "required")),
68-
quantiles_numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict"), depends = quote(method == "quantiles"))
69-
)
70-
private$.encodepl_param_set$values = list(method = "quantiles")
71-
72-
super$initialize(id, param_set = alist(encodepl = private$.encodepl_param_set, private$.tree_learner$param_set),
73-
param_vals = param_vals, packages = c("stats", private$.tree_learner$packages),
58+
# Constructor of the abstract base class; forwards everything to the initializer of
# PipeOpTaskPreprocSimple with tags = "encode" and feature types numeric / integer.
# * id: identifier of the PipeOp
# * param_set: parameter set supplied by the inheriting class
# * param_vals: list of hyperparameter values to set on construction
# * packages: packages required by the inheriting class (none by default)
# * task_type: task class the PipeOp operates on (any "Task" by default)
# `packages` and `task_type` are arguments with defaults so that inheriting classes can pass
# them through super$initialize().
initialize = function(id = "encodepl", param_set = ps(), param_vals = list(), packages = character(0), task_type = "Task") {
59+
super$initialize(id, param_set = param_set, param_vals = param_vals, packages = packages,
7460
task_type = task_type, tags = "encode", feature_types = c("numeric", "integer"))
7561
}
7662
),
7763
private = list(
7864

79-
.tree_learner = NULL,
80-
.encodepl_param_set = NULL,
65+
# Abstract hook: this base implementation always signals an error. The inheriting classes in
# this file override it to return the bin boundaries for the columns named in `cols`.
# * task: the task whose feature columns are binned
# * cols: names of the columns to compute bins for
.get_bins = function(task, cols) {
66+
stop("Abstract.")
67+
},
8168

8269
# Computes the state for this PipeOp: a list with a single element `bins`.
# * task: the task to derive the bins from
# If no columns are selected, the state holds an empty numeric vector. Otherwise the bin
# computation is delegated to the private `.get_bins()` method, which has to be reached
# through `private$` because R6 private members are not visible as bare names.
.get_state = function(task) {
8370
cols = private$.select_cols(task)
8471
if (!length(cols)) {
85-
return(task) # early exit
72+
return(list(bins = numeric(0))) # early exit
8673
}
87-
88-
pv = private$.encodepl_param_set$values
89-
numsplits = pv$quantiles_numsplits %??% 2
90-
91-
if (pv$method == "quantiles") {
92-
# TODO: check that min / max is correct here (according to paper / implementation)
93-
bins = lapply(task$data(cols = cols), function(d) {
94-
unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d)))
95-
})
96-
} else {
97-
learner = private$.tree_learner
98-
99-
bins = list()
100-
for (col in cols) {
101-
t = task$clone(deep = TRUE)$select(col)
102-
splits = learner$train(t)$model$splits
103-
# Get column "index" in model splits
104-
boundaries = unname(sort(splits[, "index"]))
105-
106-
d = task$data(cols = col)
107-
bins[[col]] = c(min(d), boundaries, max(d))
108-
}
109-
}
110-
111-
list(bins = bins)
74+
list(bins = private$.get_bins(task, cols))
11275
},
11376

11477
.transform = function(task) {
@@ -126,8 +89,6 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL",
12689
)
12790
)
12891

129-
mlr_pipeops$add("encodepl", PipeOpEncodePL, list(task_type = "TaskRegr"))
130-
13192
# Helper function to implement piecewise linear encoding.
13293
# * column: numeric vector
13394
# * colname: name of `column`
@@ -149,3 +110,68 @@ encode_piecewise_linear = function(column, colname, bins) {
149110

150111
dt
151112
}
113+
114+
#' PipeOpEncodePLQuantiles
#'
#' Piecewise linear encoding whose bin boundaries are derived from quantiles of each
#' selected feature column. Inherits from `PipeOpEncodePL` and implements `.get_bins()`.
115+
PipeOpEncodePLQuantiles = R6Class("PipeOpEncodePLQuantiles",
116+
inherit = PipeOpEncodePL,
117+
public = list(
118+
# * id: identifier of the PipeOp
# * param_vals: list of hyperparameter values to set on construction
initialize = function(id = "encodeplquantiles", param_vals = list()) {
119+
# Single hyperparameter `numsplits` (integer, at least 2).
ps = ps(
120+
numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict", "required"))
121+
)
122+
# NOTE(review): `packages` must be accepted by PipeOpEncodePL$initialize() for this call to work.
super$initialize(id, param_set = ps, param_vals = param_vals, packages = "stats")
123+
}
124+
),
125+
private = list(
126+
127+
# Returns one vector of bin boundaries per column in `cols`.
.get_bins = function(task, cols) {
128+
# `%??%` supplies 2 as fallback (presumably when the value is NULL -- confirm operator semantics).
numsplits = self$param_set$values$numsplits %??% 2
129+
lapply(task$data(cols = cols), function(d) {
130+
# Boundaries are: column minimum, the quantiles at i / numsplits for i = 1, ..., numsplits - 1,
# column maximum; duplicated values are dropped by unique().
# NOTE(review): quantile() is called with na.rm = TRUE but min() / max() are not, so a column
# containing NA would yield NA boundaries -- confirm the intended NA handling.
unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d)))
131+
})
132+
}
133+
)
134+
)
135+
136+
mlr_pipeops$add("encodeplquantiles", PipeOpEncodePLQuantiles)
137+
138+
#' PipeOpEncodePLTree
#'
#' Piecewise linear encoding whose bin boundaries are the split points of a tree learner
#' (`LearnerRegrRpart` or `LearnerClassifRpart`) trained on each selected feature column
#' separately. Inherits from `PipeOpEncodePL` and implements `.get_bins()`.
139+
PipeOpEncodePLTree = R6Class("PipeOpEncodePLTree",
140+
inherit = PipeOpEncodePL,
141+
public = list(
142+
# * task_type: "TaskRegr" or "TaskClassif"; selects the tree learner, other values raise an error
# * id: identifier of the PipeOp
# * param_vals: list of hyperparameter values to set on construction
initialize = function(task_type, id = "encodepltree", param_vals = list()) {
143+
assert_choice(task_type, mlr_reflections$task_types$task)
144+
if (task_type == "TaskRegr") {
145+
private$.tree_learner = LearnerRegrRpart$new()
146+
} else if (task_type == "TaskClassif") {
147+
private$.tree_learner = LearnerClassifRpart$new()
148+
} else {
149+
stopf("Task type %s not supported.", task_type)
150+
}
151+
152+
# The learner's own parameter set and packages are exposed as those of this PipeOp.
# NOTE(review): `packages` and `task_type` must be accepted by PipeOpEncodePL$initialize()
# for this call to work.
super$initialize(id, param_set = alist(private$.tree_learner$param_set), param_vals = param_vals,
153+
packages = private$.tree_learner$packages, task_type = task_type)
154+
}
155+
),
156+
private = list(
157+
158+
# Tree learner chosen in initialize(); trained once per column in .get_bins().
.tree_learner = NULL,
159+
160+
# Returns a list with one vector of bin boundaries per column in `cols`, keyed by column name.
.get_bins = function(task, cols) {
161+
learner = private$.tree_learner
162+
163+
bins = list()
164+
for (col in cols) {
165+
# Work on a deep clone restricted to the current feature so the input task is not modified.
t = task$clone(deep = TRUE)$select(col)
166+
# Get column "index" in model splits
167+
boundaries = unname(sort(learner$train(t)$model$splits[, "index"]))
168+
d = task$data(cols = col)
169+
# Sorted split values are framed by the column minimum and maximum.
bins[[col]] = c(min(d), boundaries, max(d))
170+
}
171+
bins
172+
}
173+
)
174+
)
175+
176+
# Registering with "TaskRegr", however both "TaskRegr" and "TaskClassif" are acceptable, see issue ...
177+
mlr_pipeops$add("encodepltree", PipeOpEncodePLTree, list(task_type = "TaskRegr"))

tests/testthat/test_pipeop_encodepl.R

+1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ test_that("PipeOpEncodePL - basic properties", {
1313
# - different methods
1414
# - with params (not all for regtree, hopefully)
1515
# - test on tasks with simple data that behaviour is as expected (compare dts)
16+
# - for different task types
1617
# - TODO: decide how to handle NAs in feature columns and test that

0 commit comments

Comments
 (0)