Skip to content

Commit

Permalink
Added parameter to allow threshold compositing #84
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin-Jung committed Jan 13, 2024
1 parent 2a6819e commit dc5cde0
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 37 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Suggests:
xgboost
URL: https://iiasa.github.io/ibis.iSDM/
BugReports: https://github.com/iiasa/ibis.iSDM/issues
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
Config/testthat/edition: 3
Roxygen: list(markdown = TRUE)
Biarch: true
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#### New features

#### Minor improvements and bug fixes
* Added a logical parameter to `ensemble()` enabling compositing of thresholds if set #84
* Support of multi-band rasters in `ensemble()` for convenience.
* Fix of bug in `threshold()` for supplied point data and improved error messages.
* Cleaner docs and structure
Expand Down
15 changes: 15 additions & 0 deletions R/bdproto-distributionmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,21 @@ DistributionModel <- bdproto(
attr(obj, "threshold")
)
},
# Get threshold type and format if calculated
get_thresholdtype = function(self){
# Determines whether a threshold exists and plots it
rl <- self$show_rasters()
if(length(grep('threshold',rl))==0) return( new_waiver() )

# Get the thresholded layer and return the respective attribute
obj <- self$get_data( grep('threshold',rl,value = TRUE) )
assertthat::assert_that(
assertthat::has_attr(obj, "format"),
assertthat::has_attr(obj, "method"))
return(
c("method" = attr(obj, "method"), "format" = attr(obj, "format"))
)
},
# List all rasters in object
show_rasters = function(self){
rn <- names(self$fits)
Expand Down
57 changes: 47 additions & 10 deletions R/ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
#' standard deviation (\code{"sd"}), the average of all PCA axes except the
#' first \code{"pca"}, the coefficient of variation (\code{"cv"}, Default) or
#' the range between the lowest and highest value (\code{"range"}).
#' @param apply_threshold A [`logical`] flag (Default: \code{TRUE}) specifying
#' whether threshold values should also be created via \code{"method"}. Only
#' applies and works for [`DistributionModel`] and thresholds found.
#'
#' @details Possible options for creating an ensemble includes:
#' * \code{'mean'} - Calculates the mean of several predictions.
Expand Down Expand Up @@ -71,14 +74,17 @@
#' @keywords train
#'
#' @examples
#' \dontrun{
#' # Assumes previously computed predictions
#' ex <- ensemble(mod1, mod2, mod3, method = "mean")
#' names(ex)
#' # Method works for fitted models as well as as rasters
#' r1 <- terra::rast(nrows = 10, ncols = 10, res = 0.05, xmin = -1.5,
#' xmax = 1.5, ymin = -1.5, ymax = 1.5, vals = rnorm(3600,mean = .5,sd = .1))
#' r2 <- terra::rast(nrows = 10, ncols = 10, res = 0.05, xmin = -1.5,
#' xmax = 1.5, ymin = -1.5, ymax = 1.5, vals = rnorm(3600,mean = .5,sd = .5))
#' names(r1) <- names(r2) <- "mean"
#'
#' # Make a bivariate plot (might require other packages)
#' bivplot(ex)
#' }
#' # Assumes previously computed predictions
#' ex <- ensemble(r1, r2, method = "mean")
#'
#' terra::plot(ex)
#'
#' @name ensemble
NULL
Expand All @@ -88,14 +94,14 @@ NULL
methods::setGeneric("ensemble",
signature = methods::signature("..."),
function(..., method = "mean", weights = NULL, min.value = NULL, layer = "mean",
normalize = FALSE, uncertainty = "cv") standardGeneric("ensemble"))
normalize = FALSE, uncertainty = "cv", apply_threshold = TRUE) standardGeneric("ensemble"))

#' @rdname ensemble
methods::setMethod(
"ensemble",
methods::signature("ANY"),
function(..., method = "mean", weights = NULL, min.value = NULL, layer = "mean",
normalize = FALSE, uncertainty = "cv"){
normalize = FALSE, uncertainty = "cv", apply_threshold = TRUE){
if(length(list(...))>1) {
mc <- list(...)
} else {
Expand Down Expand Up @@ -124,7 +130,8 @@ methods::setMethod(
is.null(layer) || ( is.character(layer) && length(layer) == 1 ),
is.null(weights) || is.vector(weights),
is.logical(normalize),
is.character(uncertainty)
is.character(uncertainty),
is.logical(apply_threshold)
)

# Check the method
Expand Down Expand Up @@ -156,6 +163,7 @@ methods::setMethod(
if(getOption('ibis.setupmessages')) myLog('[Ensemble]','red','Rasters need to be aligned. Check.')
ll_ras[[2]] <- terra::resample(ll_ras[[2]], ll_ras[[1]], method = "bilinear")
}

# Now ensemble per layer entry
out <- terra::rast()
for(lyr in layer){
Expand Down Expand Up @@ -248,6 +256,35 @@ methods::setMethod(
}
}

# Check for threshold values and collate
if(apply_threshold){
ll_val <- sapply(mods, function(x) x$get_thresholdvalue())
# Incase no thresholds are found, ignore entirely
if(!all(any(sapply(ll_val, is.Waiver)))){
# Respecify weights as otherwise call below fails
if(any(sapply(ll_val, is.Waiver))){
if(getOption('ibis.setupmessages')) myLog('[Ensemble]','yellow','Threshold values not found for all objects')
ll_val <- ll_val[-which(sapply(ll_val, is.Waiver))]
ll_val <- ll_val |> as.numeric()
}
if(is.null(weights)) weights <- rep(1, length(ll_val))

# Composite threshold
tr <- dplyr::case_when(
method == "mean" ~ mean(ll_val, na.rm = TRUE),
method == "median" ~ median(ll_val, na.rm = TRUE),
method == "max" ~ max(ll_val, na.rm = TRUE),
method == "min" ~ min(ll_val, na.rm = TRUE),
method == "weighted.mean" ~ weighted.mean(ll_val, w = weights, na.rm = TRUE),
.default = mean(ll_val, na.rm = TRUE)
)

# Ensemble the first layer
out <- c(out,
threshold(out[[1]], method = "fixed", value = tr)
)
}
}
assertthat::assert_that(is.Raster(out))

return(out)
Expand Down
5 changes: 3 additions & 2 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ methods::setMethod(
)
# Check whether object is a raster, otherwise extract object
if(is.Raster(mod)){
assertthat::assert_that(terra::nlyr(mod)>1)
assertthat::assert_that(terra::nlyr(mod)>1,msg = "SpatRaster object has less than 2 layers. Use plot().")
obj <- mod
# If number of layers equal to 2 (output from ensemble?), change xvar and yvar
if(terra::nlyr(mod)==2 && !(xvar %in% names(obj))){
Expand Down Expand Up @@ -170,7 +170,8 @@ methods::setMethod(

# Define default title
if(is.null(title)){
title <- paste("Bivariate plot of prediction\n (",mod$model$runname,')')
if(is.Raster(mod)) tt <- "" else tt <- paste0("\n (",mod$model$runname,")")
title <- paste("Bivariate plot of prediction",tt)
}

# Create dimensions
Expand Down
2 changes: 1 addition & 1 deletion man/DistributionModel-class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 18 additions & 9 deletions man/ensemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/ibis.iSDM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion tests/testthat/test_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ test_that('Custom functions - Test gridded transformations and ensembles', {
pp <- ensemble(ras, method = "weighted.mean", weights = runif(3, 0.5,1))
)


# Check centroid calculation
expect_s3_class(raster_centroid(r1), "sf")
expect_s3_class(raster_centroid(r1,patch = TRUE), "sf")
Expand All @@ -155,6 +154,11 @@ test_that('Custom functions - Test gridded transformations and ensembles', {
expect_no_error(tr <- threshold(o,method = "perc",point = pp,return_threshold = TRUE))
expect_type(tr, "double")

# Check attributes
expect_no_error(tr1 <- threshold(r1, method = "perc",point = pp) )
expect_match(attr(tr1, "method"), "percentile")
expect_match(attr(tr1, "format"), "binary")

# --- #
})

Expand Down
29 changes: 17 additions & 12 deletions tests/testthat/test_modelFits.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Further tests for model fits
test_that('Add further tests for model fits', {

skip_if_not_installed("pdp")

# Set to verbose
options("ibis.setupmessages" = FALSE)

Expand Down Expand Up @@ -50,9 +48,14 @@ test_that('Add further tests for model fits', {
suppressMessages(
mod <- threshold(mod, method = "perc", format = "bin")
)
expect_no_error(mod_poipo <- threshold(mod_poipo, method = "perc", value = .2))
expect_gt(mod$get_thresholdvalue(),0)
expect_length(mod$show_rasters(), 2)

# Make an ensemble and check that thresholds are also present
expect_no_error( ens <- ensemble(mod, mod_poipo, method = "mean"))
expect_length(names(ens), 3)

# Summarize model
expect_s3_class( summary(mod), "data.frame" )
expect_s3_class( coef(mod), "data.frame" )
Expand All @@ -73,15 +76,6 @@ test_that('Add further tests for model fits', {
pp <- mod$project(predictors |> as.data.frame(xy = TRUE, na.rm = FALSE))
expect_s4_class(pp, "SpatRaster")

# ----------- #
# Partial stuff
pp <- partial(mod, x.var = "bio19_mean_50km",plot = FALSE)
expect_s3_class(pp, "data.frame")

# Spartial
pp <- spartial(mod,x.var = "bio19_mean_50km",plot = FALSE)
expect_s4_class(pp, "SpatRaster")

# ----------- #
# Create a suitability index
o <- mod$calc_suitabilityindex()
Expand Down Expand Up @@ -119,9 +113,20 @@ test_that('Add further tests for model fits', {

# Make an ensemble
expect_no_error(
o <- ensemble(mod1, mod2, mod3, method = "median")
o <- ensemble(mod1, mod2, mod3, method = "median",uncertainty = "range")
)
expect_s4_class(o, "SpatRaster")
expect_length(names(o), 2) # Should be at maximum 2 layers

# ----------- #
# Partial stuff
skip_if_not_installed("pdp")
pp <- partial(mod, x.var = "bio19_mean_50km",plot = FALSE)
expect_s3_class(pp, "data.frame")

# Spartial
pp <- spartial(mod,x.var = "bio19_mean_50km",plot = FALSE)
expect_s4_class(pp, "SpatRaster")

# ----------- #
# Write model outputs
Expand Down

0 comments on commit dc5cde0

Please sign in to comment.