diff --git a/DESCRIPTION b/DESCRIPTION index c904532..52ba3ac 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,7 +23,7 @@ Imports: checkmate, data.table, lgr, - mlr3 (>= 0.17.0), + mlr3 (>= 0.19.0), mlr3misc, uuid Suggests: diff --git a/NEWS.md b/NEWS.md index 093c464..8d11f80 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ * feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`, useful for deleting model data to avoid running out of memory, https://github.com/mlr-org/mlr3batchmark/issues/18 Thanks to Toby Dylan Hocking @tdhock for the PR. * docs: A warning is now given when the loaded mlr3 version differs from the mlr3 version stored in the trained learners +* Support marshaling # mlr3batchmark 0.1.1 diff --git a/R/reduceResultsBatchmark.R b/R/reduceResultsBatchmark.R index 7ba85e8..441c1f2 100644 --- a/R/reduceResultsBatchmark.R +++ b/R/reduceResultsBatchmark.R @@ -14,7 +14,8 @@ #' #' @return [mlr3::BenchmarkResult]. #' @export -reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), fun=NULL) { # nolint +reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), fun = NULL, unmarshal = TRUE) { # nolint + assert_flag(unmarshal) if (is.null(ids)) { ids = batchtools::findDone(ids, reg = reg) } else { @@ -87,5 +88,9 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch bmr$combine(mlr3::BenchmarkResult$new(rdata)) } + if (unmarshal) { + bmr$unmarshal() + } + return(bmr) } diff --git a/R/worker.R b/R/worker.R index 600d414..613587c 100644 --- a/R/worker.R +++ b/R/worker.R @@ -9,6 +9,7 @@ run_learner = function(job, data, learner_hash, store_models, ...) { learner = learner, resampling = resampling, store_models = store_models, - lgr_threshold = lgr::get_logger("mlr3")$threshold + lgr_threshold = lgr::get_logger("mlr3")$threshold, + is_sequential = FALSE ) } diff --git a/man/reduceResultsBatchmark.Rd b/man/reduceResultsBatchmark.Rd index d1964e1..7eb04e2 100644 --- a/man/reduceResultsBatchmark.Rd +++ b/man/reduceResultsBatchmark.Rd @@ -8,7 +8,8 @@ reduceResultsBatchmark( ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), - fun = NULL + fun = NULL, + unmarshal = TRUE ) } \arguments{ @@ -34,6 +35,12 @@ Registry. If not explicitly passed, uses the default registry (see \code{\link[b \item{fun}{[\code{function}]\cr Function to apply to each result. The result is passed unnamed as first argument. If \code{NULL}, the identity is used. If the function has the formal argument \dQuote{job}, the \code{\link[batchtools]{Job}}/\code{\link[batchtools]{Experiment}} is also passed to the function.} + +\item{unmarshal}{\code{\link[mlr3]{Learner}}\cr +Whether to unmarshal learners that were marshaled during the execution. +Setting this to \code{FALSE} does not guarantee that the learners are marshaled. +For example, with sequential execution and no encapsulation, marshaling is not necessary. +If you want to ensure that all learners are in marshaled form, you need to call \verb{$marshal()} on the result object.} } \value{ \link[mlr3:BenchmarkResult]{mlr3::BenchmarkResult}. diff --git a/tests/testthat/test_batchmark.R b/tests/testthat/test_batchmark.R index 86b0153..b00e11e 100644 --- a/tests/testthat/test_batchmark.R +++ b/tests/testthat/test_batchmark.R @@ -86,3 +86,14 @@ test_that("failing jobs", { expect_data_table(as.data.table(results), nrow = 8L) expect_error(reduceResultsBatchmark(reg = reg, ids = ids), "successfully computed") }) + +test_that("marshaling", { + reg = batchtools::makeExperimentRegistry(NA) + batchmark(benchmark_grid(tsk("iris"), lrn("classif.debug"), rsmp("holdout")), store_models = TRUE) + submitJobs() + bmr_unmarshaled = reduceResultsBatchmark(unmarshal = TRUE) + bmr_marshaled = reduceResultsBatchmark(unmarshal = FALSE) + + expect_true(bmr_marshaled$resample_result(1)$learners[[1]]$marshaled) + expect_false(bmr_unmarshaled$resample_result(1)$learners[[1]]$marshaled) +})