Skip to content

Commit e5f95bd

Browse files
committed
Issue #156
1 parent 44f2477 commit e5f95bd

9 files changed

+65
-7
lines changed

NAMESPACE

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
S3method("[",workflow_set)
44
S3method("names<-",workflow_set)
55
S3method(autoplot,workflow_set)
6+
S3method(collect_extracts,workflow_set)
67
S3method(collect_metrics,workflow_set)
78
S3method(collect_notes,workflow_set)
89
S3method(collect_predictions,workflow_set)
@@ -38,6 +39,7 @@ S3method(vec_restore,workflow_set)
3839
export("%>%")
3940
export(as_workflow_set)
4041
export(autoplot)
42+
export(collect_extracts)
4143
export(collect_metrics)
4244
export(collect_notes)
4345
export(collect_predictions)
@@ -93,6 +95,7 @@ importFrom(stats,as.formula)
9395
importFrom(stats,model.frame)
9496
importFrom(stats,predict)
9597
importFrom(stats,qnorm)
98+
importFrom(tune,collect_extracts)
9699
importFrom(tune,collect_metrics)
97100
importFrom(tune,collect_notes)
98101
importFrom(tune,collect_predictions)

R/0_imports.R

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ tune::collect_predictions
3737
#' @export
3838
tune::collect_notes
3939

40+
#' @importFrom tune collect_extracts
41+
#' @export
42+
tune::collect_extracts
43+
4044
#' @importFrom dplyr %>%
4145
#' @export
4246
dplyr::`%>%`

R/collect.R

+15
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,18 @@ collect_notes.workflow_set <- function(x, ...) {
166166

167167
res
168168
}
169+
170+
#'
171+
#' @export
172+
#' @rdname collect_metrics.workflow_set
173+
collect_extracts.workflow_set <- function(x, ...) {
174+
check_incompete(x)
175+
176+
res <- dplyr::rowwise(x)
177+
res <- dplyr::mutate(res, extracts = list(collect_extracts(result)))
178+
res <- dplyr::ungroup(res)
179+
res <- dplyr::select(res, wflow_id, extracts)
180+
res <- tidyr::unnest(res, cols = extracts)
181+
182+
res
183+
}

man/collect_metrics.workflow_set.Rd

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

+4-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/workflow_set.Rd

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
skip_on_cran()
2+
3+
test_that("collect_extracts works", {
4+
set.seed(1)
5+
folds <- rsample::vfold_cv(mtcars, v = 3)
6+
7+
wflow_set <-
8+
workflow_set(
9+
list(reg = mpg ~ ., nonlin = mpg ~ wt + 1 / sqrt(disp)),
10+
list(lm = parsnip::linear_reg())
11+
)
12+
13+
wflow_set_trained <-
14+
wflow_set %>%
15+
workflow_map("fit_resamples",
16+
resamples = folds,
17+
control = tune::control_resamples(extract = function(x) { x })
18+
)
19+
20+
21+
extracts <- collect_extracts(wflow_set_trained)
22+
23+
expect_equal(nrow(extracts), 6)
24+
expect_contains(
25+
class(extracts$.extracts[[1]]), "workflow"
26+
)
27+
expect_named(extracts, c("wflow_id", "id", ".extracts", ".config"))
28+
})

tests/testthat/test-collect-notes.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ test_that("collect_notes works", {
2222

2323
expect_equal(nrow(notes), 6)
2424
expect_contains(notes$note, "hey!")
25-
expect_named(notes, c("wflow_id", "id", "location", "type", "note"))
25+
expect_named(notes, c("wflow_id", "id", "location", "type", "note", "trace"))
2626
})

tests/testthat/test-fit_best.R

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ test_that("fit_best fits with correct hyperparameters", {
4040

4141
manual_wf$pre$mold$blueprint$recipe$fit_times <-
4242
fit_best_wf$pre$mold$blueprint$recipe$fit_times
43+
manual_wf$fit$fit$elapsed$elapsed <-
44+
fit_best_wf$fit$fit$elapsed$elapsed
4345
expect_equal(manual_wf, fit_best_wf)
4446

4547
# metric: iic
@@ -53,6 +55,8 @@ test_that("fit_best fits with correct hyperparameters", {
5355

5456
manual_wf_2$pre$mold$blueprint$recipe$fit_times <-
5557
fit_best_wf_2$pre$mold$blueprint$recipe$fit_times
58+
manual_wf_2$fit$fit$elapsed$elapsed <-
59+
fit_best_wf_2$fit$fit$elapsed$elapsed
5660
expect_equal(manual_wf_2, fit_best_wf_2)
5761
})
5862

0 commit comments

Comments
 (0)