18
18
# ' This context is assigned at the beginning of the training loop and removed afterwards.
19
19
# ' Different stages of a callback can communicate with each other by assigning values to `$self`.
20
20
# '
21
+ # ' *State*:
22
+ # ' To be able to store information in the `$model` slot of a [`LearnerTorch`], callbacks support a state API.
23
+ # ' You can overload the `$state_dict()` public method to define what will be stored in `learner$model$callbacks$<id>`
24
+ # ' after training finishes.
25
+ # ' This also requires implementing a `$load_state_dict(state_dict)` method that defines how to load a previously saved
26
+ # ' callback state into a different callback.
27
+ # ' Note that `$state_dict()` should not include the parameter values that were used to initialize the callback.
28
+ # '
21
29
# ' For creating custom callbacks, the function [`torch_callback()`] is recommended, which creates a
22
30
# ' `CallbackSet` and then wraps it in a [`TorchCallback`].
23
31
# ' To create a `CallbackSet` the convenience function [`callback_set()`] can be used.
24
32
# ' These functions perform checks, for example that the stage names are not accidentally misspelled.
25
33
# '
34
+ # '
26
35
# ' @section Stages:
27
36
# ' * `begin` :: Run before the training loop begins.
28
37
# ' * `epoch_begin` :: Run at the beginning of each epoch.
33
42
# ' * `batch_valid_begin` :: Run before the forward call in the validation loop.
34
43
# ' * `batch_valid_end` :: Run after the forward call in the validation loop.
35
44
# ' * `epoch_end` :: Run at the end of each epoch.
36
- # ' * `end` :: Run at last, using `on.exit()`.
45
+ # ' * `end` :: Run after the last epoch.
46
+ # ' * `exit` :: Run at last, using `on.exit()`.
37
47
# ' @family Callback
38
48
# ' @export
39
49
CallbackSet = R6Class(" CallbackSet" ,
@@ -50,6 +60,21 @@ CallbackSet = R6Class("CallbackSet",
50
60
print = function (... ) {
51
61
catn(sprintf(" <%s>" , class(self )[[1L ]]))
52
62
catn(str_indent(" * Stages:" , self $ stages ))
63
+ },
64
+ # ' @description
65
+ # ' Returns information that is kept in the [`LearnerTorch`]'s state after training.
66
+ # ' This information should be loadable into the callback using `$load_state_dict()` to be able to continue training.
67
+ # ' This returns `NULL` by default.
68
+ state_dict = function () {
69
+ NULL
70
+ },
71
+ # ' @description
72
+ # ' Loads the state dict into the callback to continue training.
73
+ # ' @param state_dict (any)\cr
74
+ # ' The state dict as retrieved via `$state_dict()`.
75
+ load_state_dict = function (state_dict ) {
76
+ assert_true(is.null(state_dict ))
77
+ NULL
53
78
}
54
79
),
55
80
active = list (
@@ -71,6 +96,10 @@ CallbackSet = R6Class("CallbackSet",
71
96
deep_clone = function (name , value ) {
72
97
if (name == " ctx" && ! is.null(value )) {
73
98
stopf(" CallbackSet instances can only be cloned when the 'ctx' is NULL." )
99
+ } else if (is.R6(value )) {
100
+ value $ clone(deep = TRUE )
101
+ } else if (is.data.table(value )) {
102
+ copy(value )
74
103
} else {
75
104
value
76
105
}
@@ -90,7 +119,7 @@ CallbackSet = R6Class("CallbackSet",
90
119
# '
91
120
# ' @param classname (`character(1)`)\cr
92
121
# ' The class name.
93
- # ' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end (`function`)\cr
122
+ # ' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end,on_exit (`function`)\cr
94
123
# ' Function to execute at the given stage, see section *Stages*.
95
124
# ' @param initialize (`function()`)\cr
96
125
# ' The initialization method of the callback.
@@ -101,6 +130,11 @@ CallbackSet = R6Class("CallbackSet",
101
130
# ' @param inherit (`R6ClassGenerator`)\cr
102
131
# ' From which class to inherit.
103
132
# ' This class must either be [`CallbackSet`] (default) or inherit from it.
133
+ # ' @param state_dict (`function()`)\cr
134
+ # ' The function that retrieves the state dict from the callback.
135
+ # ' This is what will be available in the learner after training.
136
+ # ' @param load_state_dict (`function(state_dict)`)\cr
137
+ # ' Function that loads a callback state.
104
138
# ' @param lock_objects (`logical(1)`)\cr
105
139
# ' Whether to lock the objects of the resulting [`R6Class`].
106
140
# ' If `FALSE` (default), values can be freely assigned to `self` without declaring them in the
@@ -115,6 +149,7 @@ callback_set = function(
115
149
# training
116
150
on_begin = NULL ,
117
151
on_end = NULL ,
152
+ on_exit = NULL ,
118
153
on_epoch_begin = NULL ,
119
154
on_before_valid = NULL ,
120
155
on_epoch_end = NULL ,
@@ -125,11 +160,16 @@ callback_set = function(
125
160
on_batch_valid_begin = NULL ,
126
161
on_batch_valid_end = NULL ,
127
162
# other methods
163
+ state_dict = NULL ,
164
+ load_state_dict = NULL ,
128
165
initialize = NULL ,
129
166
public = NULL , private = NULL , active = NULL , parent_env = parent.frame(), inherit = CallbackSet ,
130
167
lock_objects = FALSE
131
168
) {
132
169
assert_true(startsWith(classname , " CallbackSet" ))
170
+ assert_false(xor(is.null(state_dict ), is.null(load_state_dict )))
171
+ assert_function(state_dict , nargs = 0 , null.ok = TRUE )
172
+ assert_function(load_state_dict , args = " state_dict" , nargs = 1 , null.ok = TRUE )
133
173
more_public = list (
134
174
on_begin = assert_function(on_begin , nargs = 0 , null.ok = TRUE ),
135
175
on_end = assert_function(on_end , nargs = 0 , null.ok = TRUE ),
@@ -140,7 +180,8 @@ callback_set = function(
140
180
on_batch_end = assert_function(on_batch_end , nargs = 0 , null.ok = TRUE ),
141
181
on_after_backward = assert_function(on_after_backward , nargs = 0 , null.ok = TRUE ),
142
182
on_batch_valid_begin = assert_function(on_batch_valid_begin , nargs = 0 , null.ok = TRUE ),
143
- on_batch_valid_end = assert_function(on_batch_valid_end , nargs = 0 , null.ok = TRUE )
183
+ on_batch_valid_end = assert_function(on_batch_valid_end , nargs = 0 , null.ok = TRUE ),
184
+ on_exit = assert_function(on_exit , nargs = 0 , null.ok = TRUE )
144
185
)
145
186
146
187
assert_function(initialize , null.ok = TRUE )
@@ -153,6 +194,11 @@ callback_set = function(
153
194
assert_list(public , null.ok = TRUE , names = " unique" )
154
195
if (length(public )) assert_names(names(public ), disjunct.from = names(more_public ))
155
196
197
+ if (! is.null(state_dict )) {
198
+ public $ state_dict = state_dict
199
+ public $ load_state_dict = load_state_dict
200
+ }
201
+
156
202
invalid_stages = names(public )[grepl(" ^on_" , names(public ))]
157
203
158
204
if (length(invalid_stages )) {
0 commit comments