Skip to content

Commit

Permalink
add marshalling (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Apr 25, 2024
1 parent 5b2b081 commit a8820b2
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Imports:
checkmate,
data.table,
lgr,
mlr3 (>= 0.17.0),
mlr3 (>= 0.19.0),
mlr3misc,
uuid
Suggests:
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`, useful for deleting model data to avoid running out of memory, https://github.com/mlr-org/mlr3batchmark/issues/18 Thanks to Toby Dylan Hocking @tdhock for the PR.
* docs: A warning is now given when the loaded mlr3 version differs from the
mlr3 version stored in the trained learners
* Support marshaling

# mlr3batchmark 0.1.1

Expand Down
7 changes: 6 additions & 1 deletion R/reduceResultsBatchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#'
#' @return [mlr3::BenchmarkResult].
#' @export
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), fun=NULL) { # nolint
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), fun = NULL, unmarshal = TRUE) { # nolint
assert_flag(unmarshal)
if (is.null(ids)) {
ids = batchtools::findDone(ids, reg = reg)
} else {
Expand Down Expand Up @@ -87,5 +88,9 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch
bmr$combine(mlr3::BenchmarkResult$new(rdata))
}

if (unmarshal) {
bmr$unmarshal()
}

return(bmr)
}
3 changes: 2 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ run_learner = function(job, data, learner_hash, store_models, ...) {
learner = learner,
resampling = resampling,
store_models = store_models,
lgr_threshold = lgr::get_logger("mlr3")$threshold
lgr_threshold = lgr::get_logger("mlr3")$threshold,
is_sequential = FALSE
)
}
9 changes: 8 additions & 1 deletion man/reduceResultsBatchmark.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_batchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,14 @@ test_that("failing jobs", {
expect_data_table(as.data.table(results), nrow = 8L)
expect_error(reduceResultsBatchmark(reg = reg, ids = ids), "successfully computed")
})

test_that("marshaling", {
reg = batchtools::makeExperimentRegistry(NA)
batchmark(benchmark_grid(tsk("iris"), lrn("classif.debug"), rsmp("holdout")), store_models = TRUE)
submitJobs()
bmr_unmarshaled = reduceResultsBatchmark(unmarshal = TRUE)
bmr_marshaled = reduceResultsBatchmark(unmarshal = FALSE)

expect_true(bmr_marshaled$resample_result(1)$learners[[1]]$marshaled)
expect_false(bmr_unmarshaled$resample_result(1)$learners[[1]]$marshaled)
})

0 comments on commit a8820b2

Please sign in to comment.