Skip to content

Commit 6023b3e

Browse files
authored
tests(marshaling): add parallelization tests (#281)
1 parent 2f48ac7 commit 6023b3e

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

tests/testthat/test_LearnerTorch.R

+17-2
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,7 @@ test_that("resample() works", {
417417
expect_r6(rr, "ResampleResult")
418418
})
419419

420-
test_that("callr encapsulation and marshaling", {
421-
skip_if_not_installed("callr")
420+
test_that("marshaling", {
422421
task = tsk("mtcars")$filter(1:5)
423422
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"),
424423
neurons = 20
@@ -427,14 +426,30 @@ test_that("callr encapsulation and marshaling", {
427426
expect_false(learner$marshaled)
428427
learner$marshal()$unmarshal()
429428
expect_prediction(learner$predict(task))
429+
})
430430

431+
test_that("callr encapsulation and marshaling", {
432+
skip_if_not_installed("callr")
433+
task = tsk("mtcars")$filter(1:5)
431434
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"),
432435
neurons = 20
433436
)
434437
learner$train(task)
435438
expect_prediction(learner$predict(task))
436439
})
437440

441+
test_that("future and marshaling", {
442+
skip_if_not_installed("future")
443+
task = tsk("mtcars")$filter(1:5)
444+
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu",
445+
neurons = 20
446+
)
447+
rr = with_future(future::multisession, {
448+
resample(task, learner, rsmp("holdout"))
449+
})
450+
expect_class(rr, "ResampleResult")
451+
})
452+
438453
test_that("Input verification works during `$train()` (train-predict shapes work together)", {
439454
task = nano_mnist()
440455

0 commit comments

Comments
 (0)