Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion R/forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL
# ))
# }
})

is_transformed <- vapply(bt, function(x) !is_symbol(body(x[[1]])), logical(1L))
if(length(bt) > 1) {
if(any(is_transformed)){
Expand Down
2 changes: 1 addition & 1 deletion R/generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Use a model's fitted distribution to simulate additional data with similar
#' behaviour to the response. This is a tidy implementation of
#' `\link[stats]{simulate}`.
#' \code{\link[stats]{simulate}}.
#'
#' Innovations are sampled by the model's assumed error distribution.
#' If `bootstrap` is `TRUE`, innovations will be sampled from the model's
Expand Down
48 changes: 46 additions & 2 deletions R/parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,18 @@ parse_model_rhs <- function(model){
#' @keywords internal
parse_model_lhs <- function(model){
model_lhs <- model_lhs(model)
if(is_call(model_lhs) && call_name(model_lhs) == "vars"){
if(is_call_name(model_lhs, c("vars", "c"))){
model_lhs[[1]] <- sym("exprs")
model_lhs <- eval(model_lhs)
}
else{
model_lhs <- list(model_lhs)
}
# Store response variable in a list
model_lhs <- model_lhs %>%
map(parse_tidyselect, model$data) %>%
map(parse_across, model$data) %>%
rlang::squash()
is_resp <- function(x) is_call(x) && x[[1]] == sym("resp")

# Traverse call removing all resp() usage
Expand Down Expand Up @@ -251,4 +256,43 @@ Please specify a valid form of your transformation using `new_transformation()`.
response = syms(responses),
transformation = transformations
)
}
}

parse_tidyselect <- function(lhs, data){
data <- as_tibble(data) %>%
select(-sym(index_var(data)))
pos <- try(tidyselect::eval_select(lhs, data), silent = TRUE)
if(class(pos) == "try-error" || length(pos) == 0){
if(is_call_name(lhs, "c")) {
warning(sprintf("Fail to parse %s. Check that the formula are specified correctly", deparse(lhs)))
}
return(lhs)
}
syms(names(pos))
}

parse_across <- function(lhs, data){
if(!is_call_name(lhs, "across"))
return(lhs)

across <- function(.cols, .fns = `(`){}
lhs <- rlang::call_match(lhs, across, defaults = TRUE)

.cols <- lhs[[".cols"]]
.fns <- lhs[[".fns"]]
if(deparse(.cols) == "") {
abort("No variable selected in `across`.")
}

.cols <- parse_tidyselect(.cols, data)

if(is_call_name(.fns, "list")){
.fns <- as.list(.fns)[-1]
} else {
.fns <- list(.fns)
}

c(outer(.cols, .fns,FUN = function(cols, fns)
map2(cols, fns, function(col, fn) eval(expr( call2(fn, col))))))

}
4 changes: 4 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ names_no_null <- function(x){
names(x) %||% rep_along(x, "")
}

is_call_name <- function(x, name){
is.call(x) && call_name(x) %in% name
}

# Small function to combine named lists
merge_named_list <- function(...){
flat <- flatten(list(...))
Expand Down
2 changes: 1 addition & 1 deletion man/generate.mdl_df.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/setup-models.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ no_specials <- function(formula, ...){
specials <- function(formula, ...){
specials_model <- new_model_class(model = "test model", train = test_train, specials = test_specials)
new_model_definition(specials_model, !!enquo(formula), ...)
}
}
14 changes: 14 additions & 0 deletions tests/testthat/test-parser.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ test_that("Model parsing variety", {
mdl3_trans <- parse_log3[[1]][[1]]$transformation[[1]]
expect_identical(capture.output(mdl1_trans), capture.output(mdl3_trans))
expect_identical(response_vars(parse_log1), response_vars(parse_log3))

# Parse tidyselect multivariate lhs
skip_if_not_installed("fable")
parse_resp_multivariate <- response_vars(mbl_mv)
expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(vars(mdeaths, fdeaths) ~ AR(3)))))
expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(mdeaths:fdeaths ~ AR(3)))))
expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(everything() ~ AR(3)))))
expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(ends_with("deaths") ~ AR(3)))))
expect_identical(parse_resp_multivariate, response_vars(model(lung_deaths_wide_tr, fable::VAR(all_of(c("mdeaths", "fdeaths")) ~ AR(3)))))

# Parse tidyselect lhs
parse_resp_tidyselect <- response_vars(model(lung_deaths_wide_tr, fable::VAR(fdeaths ~ AR(3))))
expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(!mdeaths ~ AR(3)))))
expect_identical(parse_resp_tidyselect, response_vars(model(lung_deaths_wide_tr, fable::VAR(2 ~ AR(3)))))
})


Expand Down