
Commit d41d116

fix: deep-cloning, callback states, hashes (#198)
* fixes the deep clone methods of objects
* callbacks now store a state and not themselves in the learner's model
* fix some hashes
* some other smaller improvements and fixes
1 parent 1a77da8 · commit d41d116


54 files changed · +846 −506 lines changed

DESCRIPTION

+8 −6

@@ -34,18 +34,18 @@ Description: Deep Learning library that extends the mlr3 framework by building
 License: LGPL (>= 3)
 Depends:
   mlr3 (>= 0.19.0),
-  mlr3pipelines (>= 0.5.2),
+  mlr3pipelines,
   torch (>= 0.13.0),
   R (>= 3.5.0)
 Imports:
   backports,
   checkmate (>= 2.2.0),
   coro,
+  data.table,
   lgr,
-  mlr3misc (>= 0.14.0),
   methods,
-  data.table,
-  paradox (>= 0.11.0),
+  mlr3misc (>= 0.14.0),
+  paradox (>= 1.0.0),
   R6,
   withr
 Suggests:
@@ -60,10 +60,12 @@ Suggests:
   rmarkdown,
   rpart,
   viridis,
+  testthat (>= 3.0.0),
   torchvision,
-  testthat (>= 3.0.0)
+  waldo
 Remotes:
-  mlr-org/paradox,
+  mlr-org/mlr3,
+  mlr-org/mlr3pipelines,
   mlverse/torchvision
 Config/testthat/edition: 3
 NeedsCompilation: no

NAMESPACE

+4 −2

@@ -28,8 +28,10 @@ S3method(as_torch_optimizer,torch_optimizer_generator)
 S3method(c,lazy_tensor)
 S3method(col_info,DataBackendLazy)
 S3method(format,lazy_tensor)
+S3method(hash_input,TorchIngressToken)
 S3method(hash_input,lazy_tensor)
-S3method(marshal_model,learner_torch_state)
+S3method(hash_input,nn_module)
+S3method(marshal_model,learner_torch_model)
 S3method(materialize,data.frame)
 S3method(materialize,lazy_tensor)
 S3method(materialize,list)
@@ -52,7 +54,7 @@ S3method(t_opt,"NULL")
 S3method(t_opt,character)
 S3method(t_opts,"NULL")
 S3method(t_opts,character)
-S3method(unmarshal_model,learner_torch_state_marshaled)
+S3method(unmarshal_model,learner_torch_model_marshaled)
 export(CallbackSet)
 export(CallbackSetCheckpoint)
 export(CallbackSetHistory)

NEWS.md

+1 −1

@@ -1 +1 @@
-# mlr3torch 0.0.0-900
+# mlr3torch dev

R/CallbackSet.R

+49 −3

@@ -18,11 +18,20 @@
 #' This context is assigned at the beginning of the training loop and removed afterwards.
 #' Different stages of a callback can communicate with each other by assigning values to `$self`.
 #'
+#' *State*:
+#' To be able to store information in the `$model` slot of a [`LearnerTorch`], callbacks support a state API.
+#' You can overload the `$state_dict()` public method to define what will be stored in `learner$model$callbacks$<id>`
+#' after training finishes.
+#' This then also requires to implement a `$load_state_dict(state_dict)` method that defines how to load a previously saved
+#' callback state into a different callback.
+#' Note that the `$state_dict()` should not include the parameter values that were used to initialize the callback.
+#'
 #' For creating custom callbacks, the function [`torch_callback()`] is recommended, which creates a
 #' `CallbackSet` and then wraps it in a [`TorchCallback`].
 #' To create a `CallbackSet` the convenience function [`callback_set()`] can be used.
 #' These functions perform checks such as that the stages are not accidentally misspelled.
 #'
+#'
 #' @section Stages:
 #' * `begin` :: Run before the training loop begins.
 #' * `epoch_begin` :: Run he beginning of each epoch.
@@ -33,7 +42,8 @@
 #' * `batch_valid_begin` :: Run before the forward call in the validation loop.
 #' * `batch_valid_end` :: Run after the forward call in the validation loop.
 #' * `epoch_end` :: Run at the end of each epoch.
-#' * `end` :: Run at last, using `on.exit()`.
+#' * `end` :: Run after last epoch.
+#' * `exit` :: Run at last, using `on.exit()`.
 #' @family Callback
 #' @export
 CallbackSet = R6Class("CallbackSet",
@@ -50,6 +60,21 @@ CallbackSet = R6Class("CallbackSet",
     print = function(...) {
       catn(sprintf("<%s>", class(self)[[1L]]))
       catn(str_indent("* Stages:", self$stages))
+    },
+    #' @description
+    #' Returns information that is kept in the the [`LearnerTorch`]'s state after training.
+    #' This information should be loadable into the callback using `$load_state_dict()` to be able to continue training.
+    #' This returns `NULL` by default.
+    state_dict = function() {
+      NULL
+    },
+    #' @description
+    #' Loads the state dict into the callback to continue training.
+    #' @param state_dict (any)\cr
+    #'   The state dict as retrieved via `$state_dict()`.
+    load_state_dict = function(state_dict) {
+      assert_true(is.null(state_dict))
+      NULL
     }
   ),
   active = list(
@@ -71,6 +96,10 @@ CallbackSet = R6Class("CallbackSet",
     deep_clone = function(name, value) {
       if (name == "ctx" && !is.null(value)) {
         stopf("CallbackSet instances can only be cloned when the 'ctx' is NULL.")
+      } else if (is.R6(value)) {
+        value$clone(deep = TRUE)
+      } else if (is.data.table(value)) {
+        copy(value)
       } else {
         value
       }
@@ -90,7 +119,7 @@ CallbackSet = R6Class("CallbackSet",
 #'
 #' @param classname (`character(1)`)\cr
 #'   The class name.
-#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end (`function`)\cr
+#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end,on_exit (`function`)\cr
 #'   Function to execute at the given stage, see section *Stages*.
 #' @param initialize (`function()`)\cr
 #'   The initialization method of the callback.
@@ -101,6 +130,11 @@ CallbackSet = R6Class("CallbackSet",
 #' @param inherit (`R6ClassGenerator`)\cr
 #'   From which class to inherit.
 #'   This class must either be [`CallbackSet`] (default) or inherit from it.
+#' @param state_dict (`function()`)\cr
+#'   The function that retrieves the state dict from the callback.
+#'   This is what will be available in the learner after training.
+#' @param load_state_dict (`function(state_dict)`)\cr
+#'   Function that loads a callback state.
 #' @param lock_objects (`logical(1)`)\cr
 #'   Whether to lock the objects of the resulting [`R6Class`].
 #'   If `FALSE` (default), values can be freely assigned to `self` without declaring them in the
@@ -115,6 +149,7 @@ callback_set = function(
   # training
   on_begin = NULL,
   on_end = NULL,
+  on_exit = NULL,
   on_epoch_begin = NULL,
   on_before_valid = NULL,
   on_epoch_end = NULL,
@@ -125,11 +160,16 @@ callback_set = function(
   on_batch_valid_begin = NULL,
   on_batch_valid_end = NULL,
   # other methods
+  state_dict = NULL,
+  load_state_dict = NULL,
   initialize = NULL,
   public = NULL, private = NULL, active = NULL, parent_env = parent.frame(), inherit = CallbackSet,
   lock_objects = FALSE
 ) {
   assert_true(startsWith(classname, "CallbackSet"))
+  assert_false(xor(is.null(state_dict), is.null(load_state_dict)))
+  assert_function(state_dict, nargs = 0, null.ok = TRUE)
+  assert_function(load_state_dict, args = "state_dict", nargs = 1, null.ok = TRUE)
   more_public = list(
     on_begin = assert_function(on_begin, nargs = 0, null.ok = TRUE),
     on_end = assert_function(on_end, nargs = 0, null.ok = TRUE),
@@ -140,7 +180,8 @@ callback_set = function(
     on_batch_end = assert_function(on_batch_end, nargs = 0, null.ok = TRUE),
     on_after_backward = assert_function(on_after_backward, nargs = 0, null.ok = TRUE),
     on_batch_valid_begin = assert_function(on_batch_valid_begin, nargs = 0, null.ok = TRUE),
-    on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE)
+    on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE),
+    on_exit = assert_function(on_exit, nargs = 0, null.ok = TRUE)
   )

   assert_function(initialize, null.ok = TRUE)
@@ -153,6 +194,11 @@ callback_set = function(
   assert_list(public, null.ok = TRUE, names = "unique")
   if (length(public)) assert_names(names(public), disjunct.from = names(more_public))

+  if (!is.null(state_dict)) {
+    public$state_dict = state_dict
+    public$load_state_dict = load_state_dict
+  }
+
   invalid_stages = names(public)[grepl("^on_", names(public))]

   if (length(invalid_stages)) {
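
The state API documented above can be exercised through `callback_set()`. The following is a minimal sketch, not part of this commit (the class name `CallbackSetEpochCounter` and the field `n` are illustrative): a callback that counts epochs and stores only that count in the learner's model via `$state_dict()` and `$load_state_dict()`.

CallbackSetEpochCounter = callback_set(
  classname = "CallbackSetEpochCounter",
  initialize = function() {
    self$n = 0L
  },
  on_epoch_end = function() {
    self$n = self$n + 1L
  },
  # what ends up in learner$model$callbacks$<id> after training
  state_dict = function() {
    self$n
  },
  # restores a previously saved state into a fresh callback
  load_state_dict = function(state_dict) {
    self$n = state_dict
  }
)

As noted in the docs above, `torch_callback()` is the recommended one-step alternative that also wraps the result in a `TorchCallback`.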

R/CallbackSetCheckpoint.R

+58 −19

@@ -3,12 +3,19 @@
 #' @name mlr_callback_set.checkpoint
 #'
 #' @description
-#' Saves the model during training.
+#' Saves the optimizer and network states during training.
+#' The final network and optimizer are always stored.
+#' @details
+#' Saving the learner itself in the callback with a trained model is impossible,
+#' as the model slot is set *after* the last callback step is executed.
+#'
 #' @param path (`character(1)`)\cr
-#'   The path to a folder where the models are saved. This path must not exist before.
+#'   The path to a folder where the models are saved.
 #' @param freq (`integer(1)`)\cr
-#'   The frequency how often the model is saved (epoch frequency).
-#'
+#'   The frequency how often the model is saved.
+#'   Frequency is either per step or epoch, which can be configured through the `freq_type` parameter.
+#' @param freq_type (`character(1)`)\cr
+#'   Can be be either `"epoch"` (default) or `"step"`.
 #' @family Callback
 #' @export
 #' @include CallbackSet.R
@@ -19,27 +26,58 @@ CallbackSetCheckpoint = R6Class("CallbackSetCheckpoint",
   public = list(
     #' @description
     #' Creates a new instance of this [R6][R6::R6Class] class.
-    initialize = function(path, freq) {
-      # TODO: Maybe we want to be able to give gradient steps here instead of epochs?
-      assert_path_for_output(path)
-      dir.create(path, recursive = TRUE)
-      self$path = path
+    initialize = function(path, freq, freq_type = "epoch") {
       self$freq = assert_int(freq, lower = 1L)
+      self$path = assert_path_for_output(path)
+      self$freq_type = assert_choice(freq_type, c("epoch", "step"))
+      if (!dir.exists(path)) {
+        dir.create(path, recursive = TRUE)
+      }
     },
     #' @description
-    #' Saves the network state dict.
+    #' Saves the network and optimizer state dict.
+    #' Does nothing if `freq_type` or `freq` are not met.
     on_epoch_end = function() {
-      if ((self$ctx$epoch %% self$freq) == 0) {
-        torch::torch_save(self$ctx$network, file.path(self$path, paste0("network", self$ctx$epoch, ".pt")))
+      if (self$freq_type == "step" || (self$ctx$epoch %% self$freq != 0)) {
+        return(NULL)
+      }
+      private$.save(self$ctx$epoch)
+    },
+    #' @description
+    #' Saves the selected objects defined in `save`.
+    #' Does nothing if freq_type or freq are not met.
+    on_batch_end = function() {
+      if (self$freq_type == "epoch" || (self$ctx$step %% self$freq != 0)) {
+        return(NULL)
       }
+      private$.save(self$ctx$step)
     },
     #' @description
-    #' Saves the final network.
-    on_end = function() {
-      path = file.path(self$path, paste0("network", self$ctx$epoch, ".pt"))
-      if (!file.exists(path)) { # no need to save the last network twice if it was already saved.
-        torch::torch_save(self$ctx$network, path)
+    #' Saves the learner.
+    on_exit = function() {
+      if (self$ctx$epoch == 0) return(NULL)
+      if (self$freq_type == "epoch") {
+        if (self$ctx$epoch %% self$freq == 0) {
+          # already saved
+          return(NULL)
+        } else {
+          private$.save(self$ctx$epoch)
+        }
       }
+      if (self$freq_type == "step") {
+        if (self$ctx$step %% self$freq == 0) {
+          # already saved
+          return(NULL)
+        } else {
+          private$.save(self$ctx$epoch)
+        }
+      }
+    }
+  ),
+  private = list(
+    .save = function(suffix) {
+      torch_save(self$ctx$network$state_dict(), file.path(self$path, paste0("network", suffix, ".pt")))
+      torch_save(self$ctx$optimizer$state_dict(), file.path(self$path, paste0("optimizer", suffix, ".pt")))
     }
   )
 )
@@ -49,8 +87,9 @@ mlr3torch_callbacks$add("checkpoint", function() {
   TorchCallback$new(
     callback_generator = CallbackSetCheckpoint,
     param_set = ps(
-      path = p_uty(),
-      freq = p_int(lower = 1L)
+      path = p_uty(tags = c("train", "required")),
+      freq = p_int(lower = 1L, tags = c("train", "required")),
+      freq_type = p_fct(default = "epoch", c("epoch", "step"), tags = "train")
     ),
     id = "checkpoint",
     label = "Checkpoint",
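
A hedged usage sketch of the revised callback, with an illustrative path: with `freq_type = "step"`, state dicts are written every `freq` training steps by `private$.save()`, and `on_exit()` additionally stores the final state when the last step was not already covered.

# Instantiate the checkpoint callback directly (the path is illustrative).
cb = CallbackSetCheckpoint$new(path = tempfile("checkpoints"), freq = 5, freq_type = "step")
# During training this produces files such as network5.pt / optimizer5.pt,
# network10.pt / optimizer10.pt, ..., each holding a state_dict() rather than
# a fully serialized module. Assuming `net` is an nn_module with a matching
# architecture, a saved network state can later be restored with, e.g.:
# net$load_state_dict(torch::torch_load(file.path(cb$path, "network5.pt")))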

R/CallbackSetHistory.R

+14 −55

@@ -21,9 +21,20 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
     },
     #' @description
     #' Converts the lists to data.tables.
-    on_end = function() {
-      self$train = rbindlist(self$train, fill = TRUE)
-      self$valid = rbindlist(self$valid, fill = TRUE)
+    state_dict = function() {
+      structure(list(
+        train = rbindlist(self$train, fill = TRUE),
+        valid = rbindlist(self$valid, fill = TRUE)
+      ), class = "callback_state_history")
+    },
+    #' @description
+    #' Sets the field `$train` and `$valid` to those contained in the state dict.
+    #' @param state_dict (`callback_state_history`)\cr
+    #'   The state dict as retrieved via `$state_dict()`.
+    load_state_dict = function(state_dict) {
+      assert_class(state_dict, "callback_state_history")
+      self$train = state_dict$train
+      self$valid = state_dict$valid
     },
     #' @description
     #' Add the latest training scores to the history.
@@ -42,58 +53,6 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
         list(epoch = self$ctx$epoch), self$ctx$last_scores_valid
      )
     }
-    },
-    #' @description Plots the history.
-    #' @param measures (`character()`)\cr
-    #'   Which measures to plot. No default.
-    #' @param set (`character(1)`)\cr
-    #'   Which set to plot. Either `"train"` or `"valid"`. Default is `"valid"`.
-    #' @param epochs (`integer()`)\cr
-    #'   An integer vector restricting which epochs to plot. Default is `NULL`, which plots all epochs.
-    #' @param theme ([ggplot2::theme()])\cr
-    #'   The theme, [ggplot2::theme_minimal()] is the default.
-    #' @param ... (any)\cr
-    #'   Currently unused.
-    plot = function(measures, set = "valid", epochs = NULL, theme = ggplot2::theme_minimal(), ...) {
-      assert_choice(set, c("valid", "train"))
-      data = self[[set]]
-      assert_subset(measures, colnames(data))
-
-      if (is.null(epochs)) {
-        data = data[, c("epoch", measures), with = FALSE]
-      } else {
-        assert_integerish(epochs, unique = TRUE)
-        data = data[get("epoch") %in% epochs, c("epoch", measures), with = FALSE]
-      }
-
-      if ((!nrow(data)) || (ncol(data) < 2)) {
-        stopf("No eligible measures to plot for set '%s'.", set)
-      }
-
-      epoch = score = measure = .data = NULL
-      if (ncol(data) == 2L) {
-        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = .data[[measures]])) +
-          ggplot2::geom_line() +
-          ggplot2::geom_point() +
-          ggplot2::labs(
-            x = "Epoch",
-            y = measures,
-            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
-          ) +
-          theme
-      } else {
-        data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
-        ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) +
-          viridis::scale_color_viridis(discrete = TRUE) +
-          ggplot2::geom_line() +
-          ggplot2::geom_point() +
-          ggplot2::labs(
-            x = "Epoch",
-            y = "Score",
-            title = sprintf("%s Loss", switch(set, valid = "Validation", train = "Training"))
-          ) +
-          theme
-      }
     }
   ),
   private = list(
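
A self-contained sketch of the new history state (the measure column `classif.acc` is illustrative, and the constructor is assumed to take no arguments): after this change, the object stored in `learner$model$callbacks$<id>` is this classed list of data.tables rather than the callback itself, and it can be loaded back into a fresh callback to continue from it.

library(mlr3torch)
library(data.table)
# Build a state shaped like the one state_dict() returns.
state = structure(
  list(
    train = data.table(epoch = 1:2, classif.acc = c(0.70, 0.80)),
    valid = data.table(epoch = 1:2, classif.acc = c(0.65, 0.75))
  ),
  class = "callback_state_history"
)
cb = CallbackSetHistory$new()  # assumption: no constructor arguments required
cb$load_state_dict(state)
cb$train                       # restored data.table of per-epoch training scores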

R/ContextTorch.R

+3 −3

@@ -101,9 +101,9 @@ ContextTorch = R6Class("ContextTorch",
     #' @field epoch (`integer(1)`)\cr
     #'   The current epoch.
     epoch = NULL,
-    #' @field batch (`integer(1)`)\cr
-    #'   The current iteration of the batch.
-    batch = NULL,
+    #' @field step (`integer(1)`)\cr
+    #'   The current iteration.
+    step = NULL,
     #' @field prediction_encoder (`function()`)\cr
     #'   The learner's prediction encoder.
     prediction_encoder = NULL
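
Downstream of this rename, callbacks read `ctx$step` instead of `ctx$batch`. A small sketch, with an illustrative class name, of a callback that uses the renamed field to log progress every 100 training steps:

CallbackSetStepLogger = callback_set(
  classname = "CallbackSetStepLogger",
  on_batch_end = function() {
    # ctx$step counts training iterations across the epoch (formerly ctx$batch)
    if (self$ctx$step %% 100 == 0) {
      cat(sprintf("epoch %i, step %i\n", self$ctx$epoch, self$ctx$step))
    }
  }
)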
