Skip to content

Commit 42ac56c

Browse files
committed
add marshalling
1 parent 1aed0c0 commit 42ac56c

File tree

5 files changed

+23
-3
lines changed

5 files changed

+23
-3
lines changed

DESCRIPTION

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Imports:
2727
Suggests:
2828
rpart,
2929
testthat
30+
Remotes:
31+
mlr-org/mlr3@bundle
3032
Encoding: UTF-8
3133
Roxygen: list(markdown = TRUE)
3234
RoxygenNote: 7.2.3.9000

R/reduceResultsBatchmark.R

+6-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
#'
1515
#' @return [mlr3::BenchmarkResult].
1616
#' @export
17-
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry()) { # nolint
17+
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), unmarshal = TRUE) { # nolint
18+
assert_flag(unmarshal)
1819
if (is.null(ids)) {
1920
ids = batchtools::findDone(ids, reg = reg)
2021
} else {
@@ -87,5 +88,9 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch
8788
bmr$combine(mlr3::BenchmarkResult$new(rdata))
8889
}
8990

91+
if (unmarshal) {
92+
bmr$unmarshal()
93+
}
94+
9095
return(bmr)
9196
}

R/worker.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ run_learner = function(job, data, learner_hash, store_models, ...) {
99
learner = learner,
1010
resampling = resampling,
1111
store_models = store_models,
12-
lgr_threshold = lgr::get_logger("mlr3")$threshold
12+
lgr_threshold = lgr::get_logger("mlr3")$threshold,
13+
is_sequential = FALSE
1314
)
1415
}

man/reduceResultsBatchmark.Rd

+2-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_batchmark.R

+11
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,14 @@ test_that("failing jobs", {
8686
expect_data_table(as.data.table(results), nrow = 8L)
8787
expect_error(reduceResultsBatchmark(reg = reg, ids = ids), "successfully computed")
8888
})
89+
90+
test_that("marshalling", {
91+
reg = batchtools::makeExperimentRegistry(NA)
92+
batchmark(benchmark_grid(tsk("iris"), lrn("classif.lily"), rsmp("holdout")), store_models = TRUE)
93+
submitJobs()
94+
bmr_unmarshalled = reduceResultsBatchmark(unmarshal = TRUE)
95+
bmr_marshalled = reduceResultsBatchmark(unmarshal = FALSE)
96+
97+
expect_true(bmr_marshalled$resample_result(1)$learners[[1]]$marshalled)
98+
expect_false(bmr_unmarshalled$resample_result(1)$learners[[1]]$marshalled)
99+
})

0 commit comments

Comments
 (0)