From 792bed350af059430a4dc3cd5c187c8ccf54e8f5 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 22 Oct 2024 16:18:17 +0200 Subject: [PATCH] port robustness plot to ggplot2 (#214) * port robustness plot to ggplot2 * cleanup * update description * skip test that fails on only macos --- DESCRIPTION | 1 + R/abtestbayesian.R | 173 ++++++++++++++++-- .../_snaps/abtestbayesian/robustness.svg | 147 +++++++++++++++ tests/testthat/test-abtestbayesian.R | 5 +- 4 files changed, 311 insertions(+), 15 deletions(-) create mode 100644 tests/testthat/_snaps/abtestbayesian/robustness.svg diff --git a/DESCRIPTION b/DESCRIPTION index 1fea33a0..78ccf90a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -16,6 +16,7 @@ Imports: conting, multibridge, ggplot2, + interp, jaspBase, jaspGraphs, plyr, diff --git a/R/abtestbayesian.R b/R/abtestbayesian.R index a6c40ed2..a54ebc1d 100644 --- a/R/abtestbayesian.R +++ b/R/abtestbayesian.R @@ -358,19 +358,14 @@ ABTestBayesianInternal <- function(jaspResults, dataset = NULL, options) { mu_range = c(options$bfRobustnessPlotLowerPriorMean, options$bfRobustnessPlotUpperPriorMean) sigma_range = c(options$bfRobustnessPlotLowerPriorSd, options$bfRobustnessPlotUpperPriorSd) - plotFunc <- function() { - - abtest::plot_robustness( - x = ab_obj, - mu_steps = options$bfRobustnessPlotStepsPriorMean, - sigma_steps = options$bfRobustnessPlotStepsPriorSd, - mu_range = mu_range, - sigma_range = sigma_range, - bftype = options$bfRobustnessPlotType - ) - } - - return (plotFunc) + return(plot_robustness_ggplot2( + x = ab_obj, + mu_steps = options$bfRobustnessPlotStepsPriorMean, + sigma_steps = options$bfRobustnessPlotStepsPriorSd, + mu_range = mu_range, + sigma_range = sigma_range, + bftype = options$bfRobustnessPlotType + )) } @@ -817,3 +812,155 @@ abtest_plot_posterior <- function (x, what = "logor", hypothesis = "H1", ci = 0. } par(op) } + + + +plot_robustness_ggplot2 <- function(x, + bftype = "BF10", + log = FALSE, + mu_range = c(0, 0.3), + sigma_range = c(0.25, 1), + mu_steps = 40, + sigma_steps = 40, + cores = 1, + ...) { + + # make sure that object is of class ab + if ( ! inherits(x, "ab")) { + stop("x needs to be of class 'ab'", call. = FALSE) + } + + # check bftype + if ( ! bftype %in% c("BF10", "BF01", "BF+0", "BF0+", "BF-0", "BF0-")) { + stop("bftype needs to be either 'BF10', 'BF01', 'BF+0', 'BF0+', 'BF-0', or 'BF0-'", + call. = FALSE) + } + + # check that sigma_range is positive + if (any(sigma_range <= 0)) { + stop("sigma_range may not contain values smaller or equal to 0", + call. = FALSE) + } + + # conduct robustness check + mu <- seq(mu_range[1], mu_range[2], length.out = mu_steps) + sigma <- seq(sigma_range[1], sigma_range[2], length.out = sigma_steps) + + prior_par_matrix <- expand.grid(mu, sigma) + colnames(prior_par_matrix) <- c("mu_psi", "sigma_psi") + + # the function is identical to abtest::compute_logbf + logbfs <- apply(prior_par_matrix, 1, function(y, ab) { + prior_par_i <- list(mu_psi = y[["mu_psi"]], + sigma_psi = y[["sigma_psi"]], + mu_beta = ab$input$prior_par$mu_beta, + sigma_beta = ab$input$prior_par$sigma_beta) + abtest::ab_test(data = ab$input$data, + prior_par = prior_par_i, + prior_prob = ab$input$prior_prob, + nsamples = ab$input$nsamples, + is_df = ab$input$is_df, + posterior = FALSE)$logbf + }, ab = x) + + + if (bftype %in% c("BF10", "BF+0", "BF-0")) { + + bfname <- switch(bftype, + "BF10" = "bf10", + "BF+0" = "bfplus0", + "BF-0" = "bfminus0") + bf <- vapply(logbfs, function(y) y[[bfname]], 1) + + } else if (bftype %in% c("BF01", "BF0+", "BF0-")) { + + bfname <- switch(bftype, + "BF01" = "bf10", + "BF0+" = "bfplus0", + "BF0-" = "bfminus0") + bf <- 1 / exp(vapply(logbfs, function(y) y[[bfname]], 1)) + bf <- log(bf) + + } + + subscripts <- strsplit(bftype, split = "")[[1]][3:4] + if (log) { + colorbarName <- bquote("Log(" ~BF[.(subscripts[1])][.(subscripts[2])]~")") + } else { + bf <- exp(bf) + colorbarName <- bquote(BF[.(subscripts[1])][.(subscripts[2])]) + } + + df <- interp::interp( + x = unname(prior_par_matrix[, 1]), + y = unname(prior_par_matrix[, 2]), + z = c(bf), + nx = 1000, ny = 1000, + # the default method, "linear" is a bit too jagged in my opinion + method = "akima") |> + interp::interp2xyz() |> + as.data.frame() + + xBreaks <- if (length(mu) <= 6) mu else jaspGraphs::getPrettyAxisBreaks(mu_range) + xLimits <- range(xBreaks) + xName <- expression(mu[psi]) + + yBreaks <- if (length(sigma) <= 6) sigma else jaspGraphs::getPrettyAxisBreaks(sigma_range) + yLimits <- range(yBreaks) + yName <- expression(sigma[psi]) + + + rbf <- range(bf, na.rm = TRUE) + + # for testing the labels + # rbf <- c(.18, 1.2) # 0.05 + # rbf <- c(.18, 2.2) # 0.1 + # rbf <- c(.18, 4.2) # 0.2 + + nTicks <- 20 # same as graphics::filled.contour + colorBreaks <- jaspGraphs::getPrettyAxisBreaks(rbf, nTicks) + + + + deltaColorBreaks <- colorBreaks[2] - colorBreaks[1] + bases <- c(1, 2, 5) + whichbase <- which(vapply(bases, \(b) is.wholenumber(log10(deltaColorBreaks / b)), logical(1L))) + multipliers <- c(5, 5, 4) + multiplier <- multipliers[whichbase] * deltaColorBreaks + + start <- mfloor(colorBreaks[1], multiplier) + + colorLabels <- rep("", length(colorBreaks)) + + idx <- is.wholenumber((colorBreaks - start) / multiplier) + colorLabels[idx] <- scales::label_number()(colorBreaks[idx]) + + colorLimits <- range(colorBreaks) + + + ggplot2::ggplot(df, ggplot2::aes(x = x, y = y, z = z)) + + ggplot2::geom_contour_filled(ggplot2::aes(fill = ggplot2::after_stat(level_mid)), bins = 20) + + ggplot2::scale_fill_stepsn( + colorbarName, + breaks = colorBreaks, + limits = colorLimits, + labels = colorLabels, + colors = grDevices::hcl.colors(nTicks, "YlOrRd", rev = TRUE)) + + ggplot2::scale_x_continuous(name = xName, breaks = xBreaks, limits = xLimits, expand = c(0,0)) + + ggplot2::scale_y_continuous(name = yName, breaks = yBreaks, limits = yLimits, expand = c(0,0)) + + ggplot2::guides(fill = ggplot2::guide_colorbar(barheight = ggplot2::unit(1, "null"))) + + jaspGraphs::geom_rangeframe() + + jaspGraphs::themeJaspRaw() + + ggplot2::theme( + legend.position = "right", + axis.title.y.left = ggplot2::element_text(angle = 0, vjust = .5), + legend.frame = ggplot2::element_rect(color = "black", fill = NA, size = .4), + legend.ticks = ggplot2::element_line(size = .5), + legend.ticks.length = ggplot2::unit(1.0, "cm"), + panel.border = ggplot2::element_rect(colour = "black", fill = NA, size = .5), + plot.margin = ggplot2::margin(10, 10, 10, 10, "pt") + ) +} + +is.wholenumber <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol +mfloor <- function(x, m) m * floor(x / m) diff --git a/tests/testthat/_snaps/abtestbayesian/robustness.svg b/tests/testthat/_snaps/abtestbayesian/robustness.svg new file mode 100644 index 00000000..4b5d2402 --- /dev/null +++ b/tests/testthat/_snaps/abtestbayesian/robustness.svg @@ -0,0 +1,147 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0.100 +0.325 +0.550 +0.775 +1.000 + + + + + + + + + + + +-0.50 +-0.25 +0.00 +0.25 +0.50 +μ +ψ +σ +ψ + + +B +F +1 +0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0.80 +0.90 +1.00 +1.10 +1.20 + + diff --git a/tests/testthat/test-abtestbayesian.R b/tests/testthat/test-abtestbayesian.R index 3b7c9734..5e940e7d 100644 --- a/tests/testthat/test-abtestbayesian.R +++ b/tests/testthat/test-abtestbayesian.R @@ -160,8 +160,8 @@ test_that("Sequential plot matches", { test_that("plotRobustness plot matches", { - skip("Have to set a global theme.") set.seed(0) + testthat::skip_on_os("mac") options <- jaspTools::analysisOptions("ABTestBayesian") options$n1 <- "n1" @@ -173,8 +173,9 @@ test_that("plotRobustness plot matches", { options$priorModelProbabilityLess <- 0 options$priorModelProbabilityTwoSided <- 0 options$bfRobustnessPlot <- TRUE + options$bfRobustnessPlotType <- "BF10" - results <- jaspTools::runAnalysis("ABTestBayesian", "ab_data.csv", options) + results <- jaspTools::runAnalysis("ABTestBayesian", testthat::test_path("ab_data.csv"), options) testPlot <- results[["state"]][["figures"]][[1]] jaspTools::expect_equal_plots(testPlot, "robustness")