Skip to content

Commit

Permalink
show how to do this with fitted_values
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinsimpson committed Nov 24, 2023
1 parent db120bd commit d5556dc
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions day-5/distributional-gam-example.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ vapply(pkgs, library, logical(1), character.only = TRUE, logical.return = TRUE,
# Load the mcycle data
data(mcycle, package = "MASS")

# plot the data
mcycle |>
ggplot(aes(x = times, y = accel)) +
geom_point()

# Fit a GAM with linear predictors for the mean and the variance of accel
m_dist1 <- gam(list(accel ~ s(times, k = 20, bs = "ad"),
~ s(times, k = 10)),
Expand All @@ -17,8 +22,7 @@ m_dist1 <- gam(list(accel ~ s(times, k = 20, bs = "ad"),
family = gaulss())

## we can use draw() and appraise() as usual
## don't want the uncertainty in the intercept here
draw(m_dist1, overall_uncertainty = FALSE)
draw(m_dist1, overall_uncertainty = TRUE)

## model diagnostics
appraise(m_dist1)
Expand Down Expand Up @@ -52,10 +56,38 @@ AIC(m_dist1, m_dist2, m_dist3)
## new data to predict at
new_df <- with(mcycle,
tibble(times = seq(round(min(times)), round(max(times)),
length.out = 200)))
length.out = 200),
.row = seq_len(200)))

## sorry, doesn't work yet - actually it does now for selected distributions
fv <- fitted_values(m_dist1, data = new_df)

mu_plt <- fv |>
filter(.parameter == "location") |>
left_join(new_df, by = join_by(".row" == ".row")) |>
ggplot(aes(x = times, y = .fitted)) +
geom_point(data = mcycle, aes(x = times, y = accel)) +
geom_ribbon(aes(ymin = .lower_ci, ymax = .upper_ci), alpha = 0.2) +
geom_line() +
labs(y = "Acceleration", x = "Milliseconds after impact")

## have something to plot for "data" for the std dev plot
## take the absolute value of the response residual
res_data <- mcycle %>%
mutate(abs_residual = abs(accel - fitted(m_dist1)[,1]))

## sorry, doesn't work yet
fitted_values(m_dist1)
sd_plt <- fv |>
filter(.parameter == "scale") |>
left_join(new_df, by = join_by(".row" == ".row")) |>
mutate(across(all_of(c(".fitted", ".lower_ci", ".upper_ci")),
.fns = ~ 1 / .x)) |>
ggplot(aes(x = times, y = .fitted)) +
geom_point(data = res_data, aes(x = times, y = abs_residual)) +
geom_ribbon(aes(ymin = .lower_ci, ymax = .upper_ci), alpha = 0.2) +
geom_line(aes(y = .fitted)) +
labs(y = "Std. Deviation", x = "Milliseconds after impact")

mu_plt + sd_plt

## so we go back to our recipe
pred <- predict(m_dist1, newdata = new_df, se.fit = TRUE, type = "link")
Expand Down

0 comments on commit d5556dc

Please sign in to comment.