Skip to content

Commit

Permalink
add matt's remaining ggplot additions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Nov 22, 2024
1 parent ac7d749 commit d176ee5
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 473 deletions.
4 changes: 3 additions & 1 deletion R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ utils::globalVariables(c("y", "year", "smooth_vals", "smooth_num",
"total_evd", "smooth_label", "by_variable",
"gr", "tot_subgrs", "subgr", "lambda",
"level", "sim_hilbert_gp", "trend_model",
"jags_path", "x"))
"jags_path", "x", "elpds", "pareto_ks",
"value", "threshold", "colour", "resids",
"c_dark", "eval_timepoints"))
116 changes: 51 additions & 65 deletions R/lfo_cv.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ lfo_cv.mvgam = function(object,
#' @importFrom graphics layout axis lines abline polygon points
#' @param x An object of class `mvgam_lfo`
#' @param ... Ignored
#' @return A base `R` plot of Pareto-k and ELPD values over the
#' @return A ggplot object of Pareto-k and ELPD values over the
#' evaluation timepoints. For the Pareto-k plot, a dashed red line indicates the
#' specified threshold chosen for triggering model refits. For the ELPD plot,
#' a dashed red line indicates the bottom 10% quantile of ELPD values. Points below
Expand All @@ -294,73 +294,59 @@ plot.mvgam_lfo = function(x, ...){

object <- x

# Graphical parameters
.pardefault <- par(no.readonly=T)
on.exit(par(.pardefault))
par(mfrow = c(2, 1))

# Plot Pareto-k values over time
object$pareto_ks[which(is.infinite(object$pareto_ks))] <-
max(object$pareto_ks[which(!is.infinite(object$pareto_ks))])
plot(1, type = "n", bty = 'L',
xlab = '',
ylab = 'Pareto k',
xaxt = 'n',
xlim = range(object$eval_timepoints),
ylim = c(min(object$pareto_ks) - 0.1,
max(object$pareto_ks) + 0.1))
axis(side = 1, labels = NA, lwd = 2)

lines(x = object$eval_timepoints,
y = object$pareto_ks,
lwd = 2.5)

abline(h = object$pareto_k_threshold, col = 'white', lwd = 2.85)
abline(h = object$pareto_k_threshold, col = "#A25050", lwd = 2.5, lty = 'dashed')

points(x = object$eval_timepoints,
y = object$pareto_ks, pch = 16, col = "white", cex = 1.25)
points(x = object$eval_timepoints,
y = object$pareto_ks, pch = 16, col = "black", cex = 1)

points(x = object$eval_timepoints[which(object$pareto_ks > object$pareto_k_threshold)],
y = object$pareto_ks[which(object$pareto_ks > object$pareto_k_threshold)],
pch = 16, col = "white", cex = 1.5)
points(x = object$eval_timepoints[which(object$pareto_ks > object$pareto_k_threshold)],
y = object$pareto_ks[which(object$pareto_ks > object$pareto_k_threshold)],
pch = 16, col = "#7C0000", cex = 1.25)

box(bty = 'l', lwd = 2)

# Plot ELPD values over time
plot(1, type = "n", bty = 'L',
xlab = 'Time point',
ylab = 'ELPD',
xlim = range(object$eval_timepoints),
ylim = c(min(object$elpds) - 0.1,
max(object$elpds) + 0.1))

lines(x = object$eval_timepoints,
y = object$elpds,
lwd = 2.5)

lower_vals <- quantile(object$elpds, probs = c(0.15))
abline(h = lower_vals, col = 'white', lwd = 2.85)
abline(h = lower_vals, col = "#A25050", lwd = 2.5, lty = 'dashed')

points(x = object$eval_timepoints,
y = object$elpds, pch = 16, col = "white", cex = 1.25)
points(x = object$eval_timepoints,
y = object$elpds, pch = 16, col = "black", cex = 1)

points(x = object$eval_timepoints[which(object$elpds < lower_vals)],
y = object$elpds[which(object$elpds < lower_vals)],
pch = 16, col = "white", cex = 1.5)
points(x = object$eval_timepoints[which(object$elpds < lower_vals)],
y = object$elpds[which(object$elpds < lower_vals)],
pch = 16, col = "#7C0000", cex = 1.25)

box(bty = 'l', lwd = 2)

dplyr::tibble(eval_timepoints = object$eval_timepoints,
elpds = object$elpds,
pareto_ks = object$pareto_ks) -> obj_tribble

# Hack so we don't have to import tidyr just to use pivot_longer once
dplyr::bind_rows(obj_tribble %>%
dplyr::select(eval_timepoints, elpds) %>%
dplyr::mutate(name = 'elpds', value = elpds) %>%
dplyr::select(-elpds),
obj_tribble %>%
dplyr::select(eval_timepoints, pareto_ks) %>%
dplyr::mutate(name = 'pareto_ks', value = pareto_ks) %>%
dplyr::select(-pareto_ks)) %>%
dplyr::left_join(
dplyr::tribble(~name, ~threshold,
"elpds", quantile(object$elpds, probs = 0.15),
"pareto_ks", object$pareto_k_threshold),
by = "name"
) %>%
dplyr::rowwise() %>%
dplyr::mutate(colour = dplyr::case_when(
name == 'elpds' & value < threshold ~ "outlier",
name == 'pareto_ks' & value > threshold ~ "outlier",
TRUE ~ "inlier"
)) %>%
dplyr::ungroup() %>%
ggplot2::ggplot(ggplot2::aes(eval_timepoints, value)) +
ggplot2::facet_wrap(~ factor(name,
levels = c("pareto_ks", "elpds"),
labels = c("Pareto K", "ELPD")),
ncol = 1,
scales = "free_y") +
ggplot2::geom_hline(ggplot2::aes(yintercept = threshold),
colour = "#A25050",
linetype = "dashed",
linewidth = 1) +
ggplot2::geom_line(linewidth = 0.5,
col = "grey30") +
ggplot2::geom_point(shape = 16,
colour = 'white',
size = 2) +
ggplot2::geom_point(ggplot2::aes(colour = colour),
shape = 16,
show.legend = F,
size = 1.5) +
ggplot2::scale_colour_manual(values = c("grey30", "#8F2727")) +
ggplot2::labs(x = "Evaluation time",
y = NULL) +
ggplot2::theme_bw()
}

#' Function to generate training and testing splits
Expand Down
3 changes: 2 additions & 1 deletion R/plot.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ plot.mvgam = function(x, type = 'residuals',
}

if(type == 'residuals'){
plot_mvgam_resids(object, series = series, data_test = data_test, ...)
suppressWarnings(plot(plot_mvgam_resids(object, series = series,
newdata = data_test, ...)))
}

if(type == 'factors'){
Expand Down
Loading

0 comments on commit d176ee5

Please sign in to comment.