Hi there!
The winter wind howls... and with it, a fresh bug has slipped under the door.
I was working on forecasting hierarchical time series and used the cross-validation feature. However, when computing accuracy metrics via accuracy(), the results look wrong when the by argument does not include the keys of the data.
Here's the reproducible example:
library(tidyverse)
library(fpp3)
# --- FIT AND FORECAST ---
tourism_full <- tourism %>%
  aggregate_key((State/Region), Trips = sum(Trips))
fit <- tourism_full %>%
  slide_tsibble(.size = 60) %>%
  relocate(.id) %>%
  filter(.id <= 2) %>% # only size 2 for Cross-validation
  # modelling
  model(base = ETS(Trips)) %>%
  reconcile(
    ols = min_trace(base, method = "ols"),
    mint = min_trace(base, method = "mint_shrink")
  )
fc <- fit %>%
  forecast(h = 4) %>%
  group_by(.id, .model, State, Region) %>%
  mutate(h = row_number()) %>%
  ungroup() %>%
  as_fable(response = "Trips", distribution = Trips)
# --- ACCURACY ---
# Directly group_by h and model
accuracy_1 <- fc %>%
  accuracy(
    tourism_full,
    by = c("h", ".model"),
    measures = list(RMSE = RMSE)
  )
# Group_by h, model, series; then summarise later
accuracy_2 <- fc %>%
  accuracy(
    tourism_full,
    by = c("h", ".model", "State", "Region"),
    measures = list(RMSE = RMSE)
  ) %>%
  group_by(.model, h) %>%
  summarise(RMSE = sqrt(mean(RMSE^2)), .groups = "drop")
# Plotting
accuracy_1 %>%
  ggplot(aes(x = h, y = RMSE, color = .model)) +
  geom_line()
accuracy_2 %>%
  ggplot(aes(x = h, y = RMSE, color = .model)) +
  geom_line()
Theoretically the numbers should be the same. Diving into the code of accuracy(), there's an issue in the first call.
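As a quick check on why the two results should agree: when every series contributes the same number of errors per group (here, one per fold), combining per-series RMSEs as sqrt(mean(RMSE^2)) gives the pooled RMSE. A minimal base R sketch, using made-up errors and hypothetical series names rather than the tourism data:

```r
# Toy forecast errors: 3 series, 2 folds each (equal group sizes).
# The values and series names are made up for illustration.
set.seed(1)
errors <- data.frame(
  series = rep(c("A", "B", "C"), each = 2),
  e = rnorm(6)
)

# Pooled RMSE over all errors at once
rmse_pooled <- sqrt(mean(errors$e^2))

# RMSE per series first, then combined as sqrt(mean(RMSE^2))
rmse_by_series <- tapply(errors$e, errors$series, function(x) sqrt(mean(x^2)))
rmse_combined <- sqrt(mean(rmse_by_series^2))

# The two agree because all groups have the same size
all.equal(rmse_pooled, rmse_combined)
```

With unequal group sizes the two would generally differ, which is not the case in this example.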
Since we group by h and .model, there's a point (here) where the function left-joins the forecast data with the actual data based on c("h", ".model"), the by argument we passed.
But the keys of the forecast data are .id, State, Region, .model, while the keys of the actual data are State, Region, and accuracy() doesn't take that into account:
# --- This is forecast data
# A tsibble: 2,040 x 8 [1Q]
# Key: .id, State, Region, .model [510]
.id State Region .fc .dist Quarter h .model
<int> <chr*> <chr*> <dbl> <dist> <qtr> <int> <chr>
1 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base
2 1 ACT Canberra 486. N(486, 3706) 2013 Q2 2 base
3 1 ACT Canberra 486. N(486, 3706) 2013 Q3 3 base
# ℹ 2,037 more rows
# --- This is actual data for calculating resid
# A tsibble: 6,800 x 4 [1Q]
# Key: State, Region [85]
State Region Quarter .actual
<chr*> <chr*> <qtr> <dbl>
1 <aggregated> <aggregated> 1998 Q1 23182.
2 <aggregated> <aggregated> 1998 Q2 20323.
3 <aggregated> <aggregated> 1998 Q3 19827.
# ℹ 6,797 more rows
# --- After left-join
# A tibble: 173,400 × 11
.id State Region .fc .dist Quarter h .model State.y Region.y .actual
<int> <chr*> <chr*> <dbl> <dist> <qtr> <int> <chr> <chr*> <chr*> <dbl>
1 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base <aggregated> <aggregated> 21984.
2 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base ACT <aggregated> 525.
3 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base New South Wales <aggregated> 6971.
4 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base Northern Territory <aggregated> 133.
5 1 ACT Canberra 486. N(486, 3706) 2013 Q1 1 base Queensland <aggregated> 4779.
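# ℹ 173,395 more rows
So the left join is effectively done on Quarter alone: each of the 2,040 forecast rows is matched with all 85 series in the actual data, giving a tibble with 2,040 × 85 = 173,400 rows.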
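The same row multiplication can be reproduced on a toy example. This is only a sketch using base R merge() on made-up data frames (fc_toy and actual_toy are hypothetical stand-ins), not the code path inside accuracy():

```r
# Toy stand-ins (hypothetical data): 2 series, 2 quarters each
fc_toy <- data.frame(
  series  = rep(c("A", "B"), each = 2),
  quarter = rep(c("2013 Q1", "2013 Q2"), times = 2),
  .fc     = c(10, 11, 20, 21)
)
actual_toy <- data.frame(
  series  = rep(c("A", "B"), each = 2),
  quarter = rep(c("2013 Q1", "2013 Q2"), times = 2),
  .actual = c(9, 12, 19, 23)
)

# Left join on the time index only: every forecast row is paired
# with the actuals of every series observed in that quarter
joined_time_only <- merge(fc_toy, actual_toy, by = "quarter", all.x = TRUE)
nrow(joined_time_only)  # 4 forecast rows x 2 series = 8

# Left join on the time index and the key: one actual per forecast row
joined_with_key <- merge(fc_toy, actual_toy, by = c("quarter", "series"), all.x = TRUE)
nrow(joined_with_key)   # 4
```

In the time-only join, forecasts for series A are compared against actuals from series B as well, which mirrors the State.y / Region.y columns in the output above.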
This is just my own investigation, so please let me know if I'm missing anything! :)