Commit a80278c

Feat: validation and internal tuning (#229)
* support internal tuning and validation
* add eval_freq parameter to torch learner
* add stage on_valid_end and support termination of training through callbacks
* some other fixes
1 parent d41d116 commit a80278c

43 files changed (+480, -575 lines)

DESCRIPTION (+5 -4)

@@ -40,7 +40,6 @@ Depends:
 Imports:
     backports,
     checkmate (>= 2.2.0),
-    coro,
     data.table,
     lgr,
     methods,
@@ -56,6 +55,7 @@ Suggests:
     jsonlite,
     knitr,
     magick,
+    mlr3tuning,
     progress,
     rmarkdown,
     rpart,
@@ -64,8 +64,9 @@ Suggests:
     torchvision,
     waldo
 Remotes:
-    mlr-org/mlr3,
-    mlr-org/mlr3pipelines,
+    mlr-org/mlr3@feat/inner_valid,
+    mlr-org/mlr3pipelines@feat/inner_valid,
+    mlr-org/mlr3tuning@internal_tuning,
     mlverse/torchvision
 Config/testthat/edition: 3
 NeedsCompilation: no
@@ -79,6 +80,7 @@ Collate:
     'zzz.R'
     'TorchCallback.R'
     'CallbackSetCheckpoint.R'
+    'CallbackSetEarlyStopping.R'
     'CallbackSetHistory.R'
     'CallbackSetProgress.R'
     'ContextTorch.R'
@@ -116,7 +118,6 @@ Collate:
     'PipeOpTorchOptimizer.R'
     'PipeOpTorchReshape.R'
     'PipeOpTorchSoftmax.R'
-    'ResamplingRowRoles.R'
    'TaskClassif_lazy_iris.R'
    'TaskClassif_mnist.R'
    'TaskClassif_tiny_imagenet.R'
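The Remotes entries above now point at feature branches. A minimal installation sketch for development, assuming the remotes package (the branch refs are copied verbatim from the diff):

# Hedged sketch, not part of the commit: install the branch dependencies listed under Remotes.
install.packages("remotes")
remotes::install_github("mlr-org/mlr3@feat/inner_valid")
remotes::install_github("mlr-org/mlr3pipelines@feat/inner_valid")
remotes::install_github("mlr-org/mlr3tuning@internal_tuning")
remotes::install_github("mlverse/torchvision")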

NAMESPACE (-1)

@@ -131,7 +131,6 @@ export(PipeOpTorchTanh)
 export(PipeOpTorchTanhShrink)
 export(PipeOpTorchThreshold)
 export(PipeOpTorchUnsqueeze)
-export(ResamplingRowRoles)
 export(TorchCallback)
 export(TorchDescriptor)
 export(TorchIngressToken)

R/CallbackSet.R (+10 -2)

@@ -35,15 +35,21 @@
 #' @section Stages:
 #' * `begin` :: Run before the training loop begins.
 #' * `epoch_begin` :: Run he beginning of each epoch.
-#' * `before_validation` :: Run before each validation loop.
 #' * `batch_begin` :: Run before the forward call.
 #' * `after_backward` :: Run after the backward call.
 #' * `batch_end` :: Run after the optimizer step.
 #' * `batch_valid_begin` :: Run before the forward call in the validation loop.
 #' * `batch_valid_end` :: Run after the forward call in the validation loop.
+#' * `valid_end` :: Run at the end of validation.
 #' * `epoch_end` :: Run at the end of each epoch.
 #' * `end` :: Run after last epoch.
 #' * `exit` :: Run at last, using `on.exit()`.
+#'
+#' @section Terminate Training:
+#' If training is to be stopped, it is possible to set the field `$terminate` of [`ContextTorch`].
+#' At the end of every epoch this field is checked and if it is `TRUE`, training stops.
+#' This can for example be used to implement custom early stopping.
+#'
 #' @family Callback
 #' @export
 CallbackSet = R6Class("CallbackSet",
@@ -119,7 +125,7 @@ CallbackSet = R6Class("CallbackSet",
 #'
 #' @param classname (`character(1)`)\cr
 #'   The class name.
-#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end,on_exit (`function`)\cr
+#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end,on_valid_end,on_exit (`function`)\cr
 #'   Function to execute at the given stage, see section *Stages*.
 #' @param initialize (`function()`)\cr
 #'   The initialization method of the callback.
@@ -159,6 +165,7 @@ callback_set = function(
   # validation
   on_batch_valid_begin = NULL,
   on_batch_valid_end = NULL,
+  on_valid_end = NULL,
   # other methods
   state_dict = NULL,
   load_state_dict = NULL,
@@ -181,6 +188,7 @@ callback_set = function(
   on_after_backward = assert_function(on_after_backward, nargs = 0, null.ok = TRUE),
   on_batch_valid_begin = assert_function(on_batch_valid_begin, nargs = 0, null.ok = TRUE),
   on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE),
+  on_valid_end = assert_function(on_valid_end, nargs = 0, null.ok = TRUE),
   on_exit = assert_function(on_exit, nargs = 0, null.ok = TRUE)
 )
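To illustrate the new `valid_end` stage and the Terminate Training mechanism documented above, here is a minimal sketch of a custom callback built with `callback_set()`. The class name and the 0.95 threshold are invented for the example; the stage function reads the training state through `self$ctx` as described in the Stages section.

library(mlr3torch)

# Sketch only: stop training once the first validation score exceeds a hypothetical threshold.
CallbackSetStopAtThreshold = callback_set("CallbackSetStopAtThreshold",
  on_valid_end = function() {
    scores = self$ctx$last_scores_valid
    if (length(scores) && !is.na(scores[[1L]]) && scores[[1L]] > 0.95) {
      # the terminate flag is checked at the end of the epoch and ends the training loop
      self$ctx$terminate = TRUE
    }
  }
)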

R/CallbackSetEarlyStopping.R (+38, new file)

@@ -0,0 +1,38 @@
+CallbackSetEarlyStopping = R6Class("CallbackSetEarlyStopping",
+  inherit = CallbackSet,
+  lock_objects = FALSE,
+  public = list(
+    initialize = function(patience, min_delta) {
+      self$patience = assert_int(patience, lower = 1L)
+      self$min_delta = assert_double(min_delta, lower = 0, len = 1L, any.missing = FALSE)
+      self$stagnation = 0L
+    },
+    on_valid_end = function() {
+      if (is.null(self$prev_valid_scores)) {
+        self$prev_valid_scores = self$ctx$last_scores_valid
+        return(NULL)
+      }
+      if (is.null(self$ctx$last_scores_valid)) {
+        return(NULL)
+      }
+      multiplier = if (self$ctx$measures_valid[[1L]]$minimize) -1 else 1
+      improvement = multiplier * (self$ctx$last_scores_valid[[1L]] - self$prev_valid_scores[[1L]])
+
+      if (is.na(improvement)) {
+        lg$warn("Learner %s in epoch %s: Difference between subsequent validation performances is NA",
+          self$ctx$learner$id, self$ctx$epoch)
+        return(NULL)
+      }
+
+      if (improvement < self$min_delta) {
+        self$stagnation = self$stagnation + 1L
+        if (self$stagnation == self$patience) {
+          self$ctx$terminate = TRUE
+        }
+      } else {
+        self$stagnation = 0
+      }
+      self$prev_valid_scores = self$ctx$last_scores_valid
+    }
+  )
+)
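The core of the callback above is a stagnation counter driven by the sign-adjusted difference of consecutive validation scores. The helper below restates that arithmetic outside the R6 class purely for illustration; the function name and signature are not package API.

# Illustrative restatement of the early-stopping arithmetic (not package API).
# `minimize` flips the sign so that `improvement` always means "better by this much".
update_stagnation = function(prev, current, minimize, min_delta, stagnation) {
  multiplier = if (minimize) -1 else 1
  improvement = multiplier * (current - prev)
  if (is.na(improvement)) {
    return(stagnation)  # NA difference: leave the counter unchanged, as the callback does
  }
  if (improvement < min_delta) stagnation + 1L else 0L
}

# Example with a measure that is minimized (e.g. classification error):
# going from 0.30 to 0.29 is an improvement of 0.01 >= min_delta, so the counter resets.
update_stagnation(prev = 0.30, current = 0.29, minimize = TRUE, min_delta = 0.001, stagnation = 2L)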

R/CallbackSetHistory.R (+5 -3)

@@ -1,4 +1,5 @@
 #' @title History Callback
+
 #'
 #' @name mlr_callback_set.history
 #'
@@ -22,17 +23,18 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
     #' @description
     #' Converts the lists to data.tables.
     state_dict = function() {
-      structure(list(
+      list(
         train = rbindlist(self$train, fill = TRUE),
         valid = rbindlist(self$valid, fill = TRUE)
-      ), class = "callback_state_history")
+      )
     },
     #' @description
     #' Sets the field `$train` and `$valid` to those contained in the state dict.
     #' @param state_dict (`callback_state_history`)\cr
     #'   The state dict as retrieved via `$state_dict()`.
     load_state_dict = function(state_dict) {
-      assert_class(state_dict, "callback_state_history")
+      assert_list(state_dict, "data.table")
+      assert_permutation(names(state_dict), c("train", "valid"))
       self$train = state_dict$train
       self$valid = state_dict$valid
     },
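After this change the history state is a plain named list of two data.tables rather than a classed object, which is what `load_state_dict()` now asserts. A small sketch of the expected shape; the epoch and classif.acc columns are invented for the example.

library(data.table)
library(checkmate)

# Hedged sketch of the new state-dict shape.
state_dict = list(
  train = data.table(epoch = 1:2, classif.acc = c(0.71, 0.78)),
  valid = data.table(epoch = 1:2, classif.acc = c(0.69, 0.74))
)
assert_list(state_dict, types = "data.table")             # every element is a data.table
assert_set_equal(names(state_dict), c("train", "valid"))  # same check in spirit as assert_permutation() above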

R/CallbackSetProgress.R (-6)

@@ -62,12 +62,6 @@ CallbackSetProgress = R6Class("CallbackSetProgress",
        cat(paste(output, collapse = ""))
      }
    }
-    },
-    #' @description
-    #' Deletes the progess bar objects.
-    on_end = function() {
-      self$pb_train = NULL
-      self$pb_valid = NULL
    }
  )
 )

R/ContextTorch.R (+19 -6)

@@ -38,8 +38,11 @@ ContextTorch = R6Class("ContextTorch",
     #'   The total number of epochs the learner is trained for.
     #' @param prediction_encoder (`function()`)\cr
     #'   The learner's prediction encoder.
+    #' @param eval_freq (`integer(1)`)\cr
+    #'   The evaluation frequency.
     initialize = function(learner, task_train, task_valid = NULL, loader_train, loader_valid = NULL,
-      measures_train = NULL, measures_valid = NULL, network, optimizer, loss_fn, total_epochs, prediction_encoder) {
+      measures_train = NULL, measures_valid = NULL, network, optimizer, loss_fn, total_epochs, prediction_encoder,
+      eval_freq = 1L) {
       self$learner = assert_r6(learner, "Learner")
       self$task_train = assert_r6(task_train, "Task")
       self$task_valid = assert_r6(task_valid, "Task", null.ok = TRUE)
@@ -56,8 +59,8 @@ ContextTorch = R6Class("ContextTorch",
       self$last_scores_train = structure(list(), names = character(0))
       self$last_scores_valid = structure(list(), names = character(0))
       self$prediction_encoder = assert_function(prediction_encoder, args = c("predict_tensor", "task"))
-      self$epoch = 0
-      self$batch = 0
+      self$eval_freq = assert_int(eval_freq, lower = 1L)
+      self$terminate = FALSE
     },
     #' @field learner ([`Learner`])\cr
     #'   The torch learner.
@@ -92,11 +95,15 @@ ContextTorch = R6Class("ContextTorch",
     #' @field total_epochs (`integer(1)`)\cr
     #'   The total number of epochs the learner is trained for.
     total_epochs = NULL,
-    #' @field last_scores_train (named `list()`)\cr
-    #'   The scores from the last training batch. Names are the ids of the training measures.
+    #' @field last_scores_train (named `list()` or `NULL`)\cr
+    #'   The scores from the last training batch. Names are the ids of the training measures.
+    #'   If [`LearnerTorch`] sets `eval_freq` different from `1`, this is `NULL` in all epochs
+    #'   that don't evaluate the model.
     last_scores_train = NULL,
     #' @field last_scores_valid (`list()`)\cr
     #'   The scores from the last validation batch. Names are the ids of the validation measures.
+    #'   If [`LearnerTorch`] sets `eval_freq` different from `1`, this is `NULL` in all epochs
+    #'   that don't evaluate the model.
     last_scores_valid = NULL,
     #' @field epoch (`integer(1)`)\cr
     #'   The current epoch.
@@ -106,6 +113,12 @@ ContextTorch = R6Class("ContextTorch",
     step = NULL,
     #' @field prediction_encoder (`function()`)\cr
     #'   The learner's prediction encoder.
-    prediction_encoder = NULL
+    prediction_encoder = NULL,
+    #' @field batch (named `list()` of `torch_tensor`s)\cr
+    #'   The current batch.
+    batch = NULL,
+    #' @field terminate (`logical(1)`)\cr
+    #'   If this field is set to `TRUE` at the end of an epoch, training stops.
+    terminate = NULL
   )
 )
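A toy sketch, not the package's actual training loop, of how the new `eval_freq` and `terminate` fields interact: scores are only produced on evaluation epochs, and the terminate flag is checked once per epoch. All objects below are stand-ins for illustration.

# Toy stand-ins; in the package, ctx is the ContextTorch instance.
total_epochs = 10L
eval_freq = 2L
ctx = new.env()
ctx$terminate = FALSE
ctx$last_scores_valid = NULL

for (epoch in seq_len(total_epochs)) {
  # ... one epoch of training would run here ...
  if (epoch %% eval_freq == 0) {
    # validation only runs on these epochs and fills the scores
    ctx$last_scores_valid = list(classif.acc = min(1, 0.6 + 0.05 * epoch))
    # a callback (e.g. early stopping) may now decide to set the flag
    if (ctx$last_scores_valid$classif.acc >= 0.9) ctx$terminate = TRUE
  } else {
    ctx$last_scores_valid = NULL  # nothing was evaluated this epoch
  }
  if (ctx$terminate) break  # the flag is checked once at the end of every epoch
}
epoch  # 6 in this toy run: the first evaluation epoch reaching 0.9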

R/DataBackendLazy.R (+7 -6)

@@ -173,7 +173,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
     backend = function(rhs) {
       assert_ro_binding(rhs)
       if (is.null(private$.backend)) {
-        private$.backend = assert_backend(private$.constructor(self))
+        backend = assert_backend(private$.constructor(self))

         f = function(test, x, y, var_name) {
           if (!test(x, y)) {
@@ -185,12 +185,13 @@ DataBackendLazy = R6Class("DataBackendLazy",
           }
         }

-        f(identical, private$.backend$primary_key, self$primary_key, "primary key")
-        f(test_permutation, private$.backend$rownames, self$rownames, "row identifiers")
-        f(test_permutation, private$.backend$colnames, private$.colnames, "column names")
-        f(test_equal_col_info, col_info(private$.backend), private$.col_info, "column information")
+        f(identical, backend$primary_key, self$primary_key, "primary key")
+        f(test_permutation, backend$rownames, self$rownames, "row identifiers")
+        f(test_permutation, backend$colnames, private$.colnames, "column names")
+        f(test_equal_col_info, col_info(backend), private$.col_info, "column information")
         # need to reverse the order for correct error message
-        f(function(x, y) test_subset(y, x), private$.backend$data_formats, self$data_formats, "data formats")
+        f(function(x, y) test_subset(y, x), backend$data_formats, self$data_formats, "data formats")
+        private$.backend = backend
       }
       private$.backend
     },

R/LearnerTorch.R (+46 -10)

@@ -41,6 +41,7 @@
 #' * `loss_fn` :: The `$state_dict()` of the [loss][torch::nn_module] used to train the network.
 #' * `callbacks` :: The [callbacks][mlr3torch::mlr_callback_set] used to train the network.
 #' * `seed` :: The seed that was / is used for training and prediction.
+#' * `epochs` :: How many epochs the model was trained for (early stopping).
 #' * `task_col_info` :: A `data.table()` containing information about the train-task.
 #'
 #' @template paramset_torchlearner
@@ -141,7 +142,7 @@ LearnerTorch = R6Class("LearnerTorch",


       assert_subset(properties, mlr_reflections$learner_properties[[task_type]])
-      properties = union(properties, "marshal")
+      properties = union(properties, c("marshal", "validation", "internal_tuning"))
       assert_subset(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]))
       if (any(grepl("^(loss\\.|opt\\.|cb\\.)", param_set$ids()))) {
         stopf("Prefixes 'loss.', 'opt.', and 'cb.' are reserved for dynamically constructed parameters.")
@@ -210,6 +211,30 @@ LearnerTorch = R6Class("LearnerTorch",
     }
   ),
   active = list(
+    #' @field validate
+    #' How to construct the internal validation data. This parameter can be either `NULL`,
+    #' a ratio in $(0, 1)$, `"test"`, or `"predefined"`.
+    validate = function(rhs) {
+      if (!missing(rhs)) {
+        private$.validate = assert_validate(rhs)
+      }
+      private$.validate
+    },
+
+    #' @field internal_valid_scores
+    #' Retrieves the internal validation scores as a named `list()`.
+    #' Specify the `$validate` field and the `measures_valid` parameter to configure this.
+    #' Returns `NULL` if learner is not trained yet.
+    internal_valid_scores = function() {
+      self$state$internal_valid_scores
+    },
+    #' @field internal_tuned_values
+    #' When early stopping is activate, this returns a named list with the early-stopped epochs,
+    #' otherwise an empty list is returned.
+    #' Returns `NULL` if learner is not trained yet.
+    internal_tuned_values = function() {
+      self$state$internal_tuned_values
+    },
     #' @field marshaled (`logical(1)`)\cr
     #'   Whether the learner is marshaled.
     marshaled = function(rhs) {
@@ -257,6 +282,21 @@ LearnerTorch = R6Class("LearnerTorch",
     }
   ),
   private = list(
+    .extract_internal_tuned_values = function() {
+      if (self$state$param_vals$patience == 0) {
+        named_list()
+      } else {
+        list(epochs = self$model$epochs)
+      }
+    },
+    .extract_internal_valid_scores = function() {
+      if (is.null(self$model$internal_valid_scores)) {
+        named_list()
+      } else {
+        self$model$internal_valid_scores
+      }
+    },
+    .validate = NULL,
     .additional_phash_input = function() {
       if (is.null(self$initialize)) return(NULL)
       initformals = names(formals(args(self$initialize)))
@@ -372,20 +412,16 @@ LearnerTorch = R6Class("LearnerTorch",
       model = value$model
       value["model"] = list(NULL)
       value = super$deep_clone(name, value)
-      value[["model"]] = set_class(list(
-        network = model$network$clone(deep = TRUE),
-        loss_fn = clone_recurse(model$loss_fn),
-        optimizer = clone_recurse(model$optimizer),
-        callbacks = map(model$callbacks, function(x) {
+      model$network = model$network$clone(deep = TRUE)
+      model$loss_fn = clone_recurse(model$loss_fn)
+      model$callbacks = map(model$callbacks, function(x) {
        if (is.R6(x)) {
          x$clone(deep = TRUE)
        } else {
          x
        }
-      }),
-      seed = model$seed,
-      task_col_info = copy(model$task_col_info)
-      ), c("learner_torch_model", "list"))
+      })
+      value$model = model
     }
     return(value)
   } else if (name == ".param_set") {
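Putting the pieces together, here is a hedged end-to-end sketch of how a torch learner might be configured after this commit. The learner id, task, and parameter values are assumptions for illustration; `validate`, `measures_valid`, `eval_freq`, `patience`, `$internal_valid_scores`, and `$internal_tuned_values` are the pieces introduced or documented in the diffs above.

library(mlr3)
library(mlr3torch)

# Sketch under assumptions (learner id, task, and parameter values are examples).
learner = lrn("classif.mlp",
  epochs = 100,                       # upper bound; early stopping may end training sooner
  batch_size = 32,
  eval_freq = 2,                      # evaluate the measures every second epoch
  patience = 5,                       # evaluations without improvement before stopping
  measures_valid = msr("classif.acc")
)
learner$validate = 0.2                # ratio in (0, 1): hold out 20% as internal validation data

learner$train(tsk("sonar"))

learner$internal_valid_scores         # named list with the final validation score(s)
learner$internal_tuned_values         # list(epochs = ...) when early stopping was active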
