From 4e74521d9166d5d468e1a26280abee522aee991a Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Thu, 16 Jun 2022 17:18:46 +1000 Subject: [PATCH 1/6] Add tidyselect and across multiple responses --- R/parse.R | 38 ++++++++++++++++++++++++++++++++++- R/utils.R | 4 ++++ tests/testthat/setup-models.R | 2 +- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/R/parse.R b/R/parse.R index 742cc094..5e9ce511 100644 --- a/R/parse.R +++ b/R/parse.R @@ -122,13 +122,18 @@ parse_model_rhs <- function(model){ #' @export parse_model_lhs <- function(model){ model_lhs <- model_lhs(model) - if(is_call(model_lhs) && call_name(model_lhs) == "vars"){ + if(is_call_name(model_lhs, c("vars", "c"))){ model_lhs[[1]] <- sym("exprs") model_lhs <- eval(model_lhs) } else{ model_lhs <- list(model_lhs) } + # Store response variable in a list + model_lhs <- model_lhs %>% + map(parse_tidyselect, model$data) %>% + map(parse_across, model$data) %>% + rlang::squash() is_resp <- function(x) is_call(x) && x[[1]] == sym("resp") # Traverse call removing all resp() usage @@ -253,4 +258,35 @@ Please specify a valid form of your transformation using `new_transformation()`. response = syms(responses), transformation = transformations ) +} + +parse_tidyselect <- function(lhs, data){ + pos <- try(tidyselect::eval_select(lhs, data), silent = TRUE) + if(class(pos) == "try-error"){ + if(is_call_name(lhs, "c")) { + warning(sprintf("Fail to parse %s. Check that the formula are specified correctly", deparse(lhs))) + } + return(lhs) + } + syms(names(pos)) +} + +parse_across <- function(lhs, data){ + if(!is_call_name(lhs, "across")) + return(lhs) + + unname_args <- lhs[names(lhs) == ""] %||% lhs + .cols <- lhs[[".cols"]] %||% unname_args[[2]] + .fns <- lhs[[".fns"]] %||% if(is.null(lhs[[".cols"]])) unname_args[[3]] else unname_args[[2]] + + .cols <- parse_tidyselect(.cols, data) + + if(is_call_name(.fns, "list")){ + .fns <- as.list(.fns)[-1] + } else { + .fns <- list(.fns) + } + + c(outer(.cols, .fns,FUN = function(cols, fns) + map2(cols, fns, function(col, fn) eval(expr( call2(fn, col)))))) } \ No newline at end of file diff --git a/R/utils.R b/R/utils.R index 1aedd6d8..1cd97405 100644 --- a/R/utils.R +++ b/R/utils.R @@ -2,6 +2,10 @@ names_no_null <- function(x){ names(x) %||% rep_along(x, "") } +is_call_name <- function(x, name){ + is.call(x) && call_name(x) %in% name +} + # Small function to combine named lists merge_named_list <- function(...){ flat <- flatten(list(...)) diff --git a/tests/testthat/setup-models.R b/tests/testthat/setup-models.R index bd877661..ee3ae448 100644 --- a/tests/testthat/setup-models.R +++ b/tests/testthat/setup-models.R @@ -25,4 +25,4 @@ no_specials <- function(formula, ...){ specials <- function(formula, ...){ specials_model <- new_model_class(model = "test model", train = test_train, specials = test_specials) new_model_definition(specials_model, !!enquo(formula), ...) -} \ No newline at end of file +} From d99ab89a6e3c46b98749fbac20541ea18050becd Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Thu, 16 Jun 2022 18:05:47 +1000 Subject: [PATCH 2/6] Fix VAR forecast if length > 1 error since R 4.2.0 --- R/forecast.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/forecast.R b/R/forecast.R index 022f0737..299562f5 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -220,14 +220,13 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL # )) # } }) - is_transformed <- vapply(bt, function(x) !is_symbol(body(x[[1]])), logical(1L)) if(length(bt) > 1) { if(any(is_transformed)){ abort("Transformations of multivariate forecasts are not yet supported") } } - if(is_transformed) { + if(all(is_transformed)) { if (identical(unique(dist_types(fc)), "dist_sample")) { fc <- vec_c(!!!mapply(exec, bt[[1]], fc)) } else { From 6e36be36cbc2a9559d154d4572887760b06c75a0 Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Thu, 16 Jun 2022 21:04:03 +1000 Subject: [PATCH 3/6] Fix tidyselect to choose without index --- R/parse.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/parse.R b/R/parse.R index 5e9ce511..4692a454 100644 --- a/R/parse.R +++ b/R/parse.R @@ -261,6 +261,8 @@ Please specify a valid form of your transformation using `new_transformation()`. } parse_tidyselect <- function(lhs, data){ + data <- as_tibble(data) %>% + select(-sym(index_var(data))) pos <- try(tidyselect::eval_select(lhs, data), silent = TRUE) if(class(pos) == "try-error"){ if(is_call_name(lhs, "c")) { From 89de6c231a315dff9ec0c756f82e0f58eaad568d Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Thu, 16 Jun 2022 22:08:47 +1000 Subject: [PATCH 4/6] Add tidyselect test --- tests/testthat/test-parser.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/testthat/test-parser.R b/tests/testthat/test-parser.R index 2604d5e3..6deeed43 100644 --- a/tests/testthat/test-parser.R +++ b/tests/testthat/test-parser.R @@ -55,6 +55,20 @@ test_that("Model parsing variety", { mdl3_trans <- parse_log3[[1]][[1]]$transformation[[1]] expect_identical(capture.output(mdl1_trans), capture.output(mdl3_trans)) expect_identical(response_vars(parse_log1), response_vars(parse_log3)) + + # Parse tidyselect multivariate lhs + skip_if_not_installed("fable") + parse_resp_multivariate <- response_vars(mbl_mv) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(vars(mdeaths, fdeaths) ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(mdeaths:fdeaths ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(everything() ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(ends_with("deaths") ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(all_of(c("mdeaths", "fdeaths")) ~ AR(3))))) + + # Parse tidyselect lhs + parse_resp_tidyselect <- response_vars(model(lung_deaths_wide_tr, fable::VAR(fdeaths ~ AR(3)))) + expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(!mdeaths ~ AR(3))))) + expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(2 ~ AR(3))))) }) From ce5441b49112b694029c9131f3c882cd29676e8c Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Fri, 17 Jun 2022 12:19:48 +1000 Subject: [PATCH 5/6] Fix #349 documentation link --- R/generate.R | 2 +- man/generate.mdl_df.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/generate.R b/R/generate.R index 843ef159..b7769fb4 100644 --- a/R/generate.R +++ b/R/generate.R @@ -2,7 +2,7 @@ #' #' Use a model's fitted distribution to simulate additional data with similar #' behaviour to the response. This is a tidy implementation of -#' `\link[stats]{simulate}`. +#' \code{\link[stats]{simulate}}. #' #' Innovations are sampled by the model's assumed error distribution. #' If `bootstrap` is `TRUE`, innovations will be sampled from the model's diff --git a/man/generate.mdl_df.Rd b/man/generate.mdl_df.Rd index b59a739f..c21f2834 100644 --- a/man/generate.mdl_df.Rd +++ b/man/generate.mdl_df.Rd @@ -39,7 +39,7 @@ time series with no exogenous regressors).} \description{ Use a model's fitted distribution to simulate additional data with similar behaviour to the response. This is a tidy implementation of -\verb{\link[stats]\{simulate\}}. +\code{\link[stats]{simulate}}. } \details{ Innovations are sampled by the model's assumed error distribution. From b5917f3ade4f7b7c0f4e5f442d013b08b40fd588 Mon Sep 17 00:00:00 2001 From: Fin Yang Date: Fri, 17 Jun 2022 14:42:42 +1000 Subject: [PATCH 6/6] across: Better arguments matching; Bug fix --- R/parse.R | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/R/parse.R b/R/parse.R index 4692a454..18a99acb 100644 --- a/R/parse.R +++ b/R/parse.R @@ -264,7 +264,7 @@ parse_tidyselect <- function(lhs, data){ data <- as_tibble(data) %>% select(-sym(index_var(data))) pos <- try(tidyselect::eval_select(lhs, data), silent = TRUE) - if(class(pos) == "try-error"){ + if(class(pos) == "try-error" || length(pos) == 0){ if(is_call_name(lhs, "c")) { warning(sprintf("Fail to parse %s. Check that the formula are specified correctly", deparse(lhs))) } @@ -277,9 +277,14 @@ parse_across <- function(lhs, data){ if(!is_call_name(lhs, "across")) return(lhs) - unname_args <- lhs[names(lhs) == ""] %||% lhs - .cols <- lhs[[".cols"]] %||% unname_args[[2]] - .fns <- lhs[[".fns"]] %||% if(is.null(lhs[[".cols"]])) unname_args[[3]] else unname_args[[2]] + across <- function(.cols, .fns = `(`){} + lhs <- rlang::call_match(lhs, across, defaults = TRUE) + + .cols <- lhs[[".cols"]] + .fns <- lhs[[".fns"]] + if(deparse(.cols) == "") { + abort("No variable selected in `across`.") + } .cols <- parse_tidyselect(.cols, data) @@ -291,4 +296,5 @@ parse_across <- function(lhs, data){ c(outer(.cols, .fns,FUN = function(cols, fns) map2(cols, fns, function(col, fn) eval(expr( call2(fn, col)))))) -} \ No newline at end of file + +}