Skip to content

Commit

Permalink
port robustness plot to ggplot2 (#214)
Browse files Browse the repository at this point in the history
* port robustness plot to ggplot2

* cleanup

* update description

* skip test that fails on only macos
  • Loading branch information
vandenman authored Oct 22, 2024
1 parent fc9e553 commit 792bed3
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 15 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Imports:
conting,
multibridge,
ggplot2,
interp,
jaspBase,
jaspGraphs,
plyr,
Expand Down
173 changes: 160 additions & 13 deletions R/abtestbayesian.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,19 +358,14 @@ ABTestBayesianInternal <- function(jaspResults, dataset = NULL, options) {
mu_range = c(options$bfRobustnessPlotLowerPriorMean, options$bfRobustnessPlotUpperPriorMean)
sigma_range = c(options$bfRobustnessPlotLowerPriorSd, options$bfRobustnessPlotUpperPriorSd)

plotFunc <- function() {

abtest::plot_robustness(
x = ab_obj,
mu_steps = options$bfRobustnessPlotStepsPriorMean,
sigma_steps = options$bfRobustnessPlotStepsPriorSd,
mu_range = mu_range,
sigma_range = sigma_range,
bftype = options$bfRobustnessPlotType
)
}

return (plotFunc)
return(plot_robustness_ggplot2(
x = ab_obj,
mu_steps = options$bfRobustnessPlotStepsPriorMean,
sigma_steps = options$bfRobustnessPlotStepsPriorSd,
mu_range = mu_range,
sigma_range = sigma_range,
bftype = options$bfRobustnessPlotType
))
}


Expand Down Expand Up @@ -817,3 +812,155 @@ abtest_plot_posterior <- function (x, what = "logor", hypothesis = "H1", ci = 0.
}
par(op)
}



plot_robustness_ggplot2 <- function(x,
bftype = "BF10",
log = FALSE,
mu_range = c(0, 0.3),
sigma_range = c(0.25, 1),
mu_steps = 40,
sigma_steps = 40,
cores = 1,
...) {

# make sure that object is of class ab
if ( ! inherits(x, "ab")) {
stop("x needs to be of class 'ab'", call. = FALSE)
}

# check bftype
if ( ! bftype %in% c("BF10", "BF01", "BF+0", "BF0+", "BF-0", "BF0-")) {
stop("bftype needs to be either 'BF10', 'BF01', 'BF+0', 'BF0+', 'BF-0', or 'BF0-'",
call. = FALSE)
}

# check that sigma_range is positive
if (any(sigma_range <= 0)) {
stop("sigma_range may not contain values smaller or equal to 0",
call. = FALSE)
}

# conduct robustness check
mu <- seq(mu_range[1], mu_range[2], length.out = mu_steps)
sigma <- seq(sigma_range[1], sigma_range[2], length.out = sigma_steps)

prior_par_matrix <- expand.grid(mu, sigma)
colnames(prior_par_matrix) <- c("mu_psi", "sigma_psi")

# the function is identical to abtest::compute_logbf
logbfs <- apply(prior_par_matrix, 1, function(y, ab) {
prior_par_i <- list(mu_psi = y[["mu_psi"]],
sigma_psi = y[["sigma_psi"]],
mu_beta = ab$input$prior_par$mu_beta,
sigma_beta = ab$input$prior_par$sigma_beta)
abtest::ab_test(data = ab$input$data,
prior_par = prior_par_i,
prior_prob = ab$input$prior_prob,
nsamples = ab$input$nsamples,
is_df = ab$input$is_df,
posterior = FALSE)$logbf
}, ab = x)


if (bftype %in% c("BF10", "BF+0", "BF-0")) {

bfname <- switch(bftype,
"BF10" = "bf10",
"BF+0" = "bfplus0",
"BF-0" = "bfminus0")
bf <- vapply(logbfs, function(y) y[[bfname]], 1)

} else if (bftype %in% c("BF01", "BF0+", "BF0-")) {

bfname <- switch(bftype,
"BF01" = "bf10",
"BF0+" = "bfplus0",
"BF0-" = "bfminus0")
bf <- 1 / exp(vapply(logbfs, function(y) y[[bfname]], 1))
bf <- log(bf)

}

subscripts <- strsplit(bftype, split = "")[[1]][3:4]
if (log) {
colorbarName <- bquote("Log(" ~BF[.(subscripts[1])][.(subscripts[2])]~")")
} else {
bf <- exp(bf)
colorbarName <- bquote(BF[.(subscripts[1])][.(subscripts[2])])
}

df <- interp::interp(
x = unname(prior_par_matrix[, 1]),
y = unname(prior_par_matrix[, 2]),
z = c(bf),
nx = 1000, ny = 1000,
# the default method, "linear" is a bit too jagged in my opinion
method = "akima") |>
interp::interp2xyz() |>
as.data.frame()

xBreaks <- if (length(mu) <= 6) mu else jaspGraphs::getPrettyAxisBreaks(mu_range)
xLimits <- range(xBreaks)
xName <- expression(mu[psi])

yBreaks <- if (length(sigma) <= 6) sigma else jaspGraphs::getPrettyAxisBreaks(sigma_range)
yLimits <- range(yBreaks)
yName <- expression(sigma[psi])


rbf <- range(bf, na.rm = TRUE)

# for testing the labels
# rbf <- c(.18, 1.2) # 0.05
# rbf <- c(.18, 2.2) # 0.1
# rbf <- c(.18, 4.2) # 0.2

nTicks <- 20 # same as graphics::filled.contour
colorBreaks <- jaspGraphs::getPrettyAxisBreaks(rbf, nTicks)



deltaColorBreaks <- colorBreaks[2] - colorBreaks[1]
bases <- c(1, 2, 5)
whichbase <- which(vapply(bases, \(b) is.wholenumber(log10(deltaColorBreaks / b)), logical(1L)))
multipliers <- c(5, 5, 4)
multiplier <- multipliers[whichbase] * deltaColorBreaks

start <- mfloor(colorBreaks[1], multiplier)

colorLabels <- rep("", length(colorBreaks))

idx <- is.wholenumber((colorBreaks - start) / multiplier)
colorLabels[idx] <- scales::label_number()(colorBreaks[idx])

colorLimits <- range(colorBreaks)


ggplot2::ggplot(df, ggplot2::aes(x = x, y = y, z = z)) +
ggplot2::geom_contour_filled(ggplot2::aes(fill = ggplot2::after_stat(level_mid)), bins = 20) +
ggplot2::scale_fill_stepsn(
colorbarName,
breaks = colorBreaks,
limits = colorLimits,
labels = colorLabels,
colors = grDevices::hcl.colors(nTicks, "YlOrRd", rev = TRUE)) +
ggplot2::scale_x_continuous(name = xName, breaks = xBreaks, limits = xLimits, expand = c(0,0)) +
ggplot2::scale_y_continuous(name = yName, breaks = yBreaks, limits = yLimits, expand = c(0,0)) +
ggplot2::guides(fill = ggplot2::guide_colorbar(barheight = ggplot2::unit(1, "null"))) +
jaspGraphs::geom_rangeframe() +
jaspGraphs::themeJaspRaw() +
ggplot2::theme(
legend.position = "right",
axis.title.y.left = ggplot2::element_text(angle = 0, vjust = .5),
legend.frame = ggplot2::element_rect(color = "black", fill = NA, size = .4),
legend.ticks = ggplot2::element_line(size = .5),
legend.ticks.length = ggplot2::unit(1.0, "cm"),
panel.border = ggplot2::element_rect(colour = "black", fill = NA, size = .5),
plot.margin = ggplot2::margin(10, 10, 10, 10, "pt")
)
}

is.wholenumber <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
mfloor <- function(x, m) m * floor(x / m)
Loading

0 comments on commit 792bed3

Please sign in to comment.