diff --git a/R/forecast.R b/R/forecast.R index cb88acbb..21d2e41a 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -220,7 +220,6 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL # )) # } }) - is_transformed <- vapply(bt, function(x) !is_symbol(body(x[[1]])), logical(1L)) if(length(bt) > 1) { if(any(is_transformed)){ diff --git a/R/generate.R b/R/generate.R index 843ef159..b7769fb4 100644 --- a/R/generate.R +++ b/R/generate.R @@ -2,7 +2,7 @@ #' #' Use a model's fitted distribution to simulate additional data with similar #' behaviour to the response. This is a tidy implementation of -#' `\link[stats]{simulate}`. +#' \code{\link[stats]{simulate}}. #' #' Innovations are sampled by the model's assumed error distribution. #' If `bootstrap` is `TRUE`, innovations will be sampled from the model's diff --git a/R/parse.R b/R/parse.R index 1ffa539b..a456c9d0 100644 --- a/R/parse.R +++ b/R/parse.R @@ -120,13 +120,18 @@ parse_model_rhs <- function(model){ #' @keywords internal parse_model_lhs <- function(model){ model_lhs <- model_lhs(model) - if(is_call(model_lhs) && call_name(model_lhs) == "vars"){ + if(is_call_name(model_lhs, c("vars", "c"))){ model_lhs[[1]] <- sym("exprs") model_lhs <- eval(model_lhs) } else{ model_lhs <- list(model_lhs) } + # Store response variable in a list + model_lhs <- model_lhs %>% + map(parse_tidyselect, model$data) %>% + map(parse_across, model$data) %>% + rlang::squash() is_resp <- function(x) is_call(x) && x[[1]] == sym("resp") # Traverse call removing all resp() usage @@ -251,4 +256,43 @@ Please specify a valid form of your transformation using `new_transformation()`. response = syms(responses), transformation = transformations ) -} \ No newline at end of file +} + +parse_tidyselect <- function(lhs, data){ + data <- as_tibble(data) %>% + select(-sym(index_var(data))) + pos <- try(tidyselect::eval_select(lhs, data), silent = TRUE) + if(class(pos) == "try-error" || length(pos) == 0){ + if(is_call_name(lhs, "c")) { + warning(sprintf("Fail to parse %s. Check that the formula are specified correctly", deparse(lhs))) + } + return(lhs) + } + syms(names(pos)) +} + +parse_across <- function(lhs, data){ + if(!is_call_name(lhs, "across")) + return(lhs) + + across <- function(.cols, .fns = `(`){} + lhs <- rlang::call_match(lhs, across, defaults = TRUE) + + .cols <- lhs[[".cols"]] + .fns <- lhs[[".fns"]] + if(deparse(.cols) == "") { + abort("No variable selected in `across`.") + } + + .cols <- parse_tidyselect(.cols, data) + + if(is_call_name(.fns, "list")){ + .fns <- as.list(.fns)[-1] + } else { + .fns <- list(.fns) + } + + c(outer(.cols, .fns,FUN = function(cols, fns) + map2(cols, fns, function(col, fn) eval(expr( call2(fn, col)))))) + +} diff --git a/R/utils.R b/R/utils.R index 1aedd6d8..1cd97405 100644 --- a/R/utils.R +++ b/R/utils.R @@ -2,6 +2,10 @@ names_no_null <- function(x){ names(x) %||% rep_along(x, "") } +is_call_name <- function(x, name){ + is.call(x) && call_name(x) %in% name +} + # Small function to combine named lists merge_named_list <- function(...){ flat <- flatten(list(...)) diff --git a/man/generate.mdl_df.Rd b/man/generate.mdl_df.Rd index b59a739f..c21f2834 100644 --- a/man/generate.mdl_df.Rd +++ b/man/generate.mdl_df.Rd @@ -39,7 +39,7 @@ time series with no exogenous regressors).} \description{ Use a model's fitted distribution to simulate additional data with similar behaviour to the response. This is a tidy implementation of -\verb{\link[stats]\{simulate\}}. +\code{\link[stats]{simulate}}. } \details{ Innovations are sampled by the model's assumed error distribution. diff --git a/tests/testthat/setup-models.R b/tests/testthat/setup-models.R index bd877661..ee3ae448 100644 --- a/tests/testthat/setup-models.R +++ b/tests/testthat/setup-models.R @@ -25,4 +25,4 @@ no_specials <- function(formula, ...){ specials <- function(formula, ...){ specials_model <- new_model_class(model = "test model", train = test_train, specials = test_specials) new_model_definition(specials_model, !!enquo(formula), ...) -} \ No newline at end of file +} diff --git a/tests/testthat/test-parser.R b/tests/testthat/test-parser.R index 2604d5e3..6deeed43 100644 --- a/tests/testthat/test-parser.R +++ b/tests/testthat/test-parser.R @@ -55,6 +55,20 @@ test_that("Model parsing variety", { mdl3_trans <- parse_log3[[1]][[1]]$transformation[[1]] expect_identical(capture.output(mdl1_trans), capture.output(mdl3_trans)) expect_identical(response_vars(parse_log1), response_vars(parse_log3)) + + # Parse tidyselect multivariate lhs + skip_if_not_installed("fable") + parse_resp_multivariate <- response_vars(mbl_mv) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(vars(mdeaths, fdeaths) ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(mdeaths:fdeaths ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(everything() ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(ends_with("deaths") ~ AR(3))))) + expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(all_of(c("mdeaths", "fdeaths")) ~ AR(3))))) + + # Parse tidyselect lhs + parse_resp_tidyselect <- response_vars(model(lung_deaths_wide_tr, fable::VAR(fdeaths ~ AR(3)))) + expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(!mdeaths ~ AR(3))))) + expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(2 ~ AR(3))))) })