Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use automatic or user-supplied symbolic derivatives for the density of transformed distributions #101

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Makefile
inst/doc
tests/testthat/Rplots.pdf
revdep
local/
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Imports:
stats,
numDeriv,
utils,
lifecycle
lifecycle,
Deriv
Suggests:
testthat (>= 2.1.0),
covr,
Expand Down
14 changes: 12 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ S3method(Math,dist_lognormal)
S3method(Math,dist_na)
S3method(Math,dist_normal)
S3method(Math,dist_sample)
S3method(Math,dist_transformed)
S3method(Ops,dist_default)
S3method(Ops,dist_na)
S3method(Ops,dist_normal)
S3method(Ops,dist_sample)
S3method(Ops,dist_transformed)
S3method(cdf,dist_bernoulli)
S3method(cdf,dist_beta)
S3method(cdf,dist_binomial)
Expand Down Expand Up @@ -138,6 +136,15 @@ S3method(dim,dist_default)
S3method(dim,dist_multinomial)
S3method(dim,dist_mvnorm)
S3method(dimnames,distribution)
S3method(eval_deriv,dist_default)
S3method(eval_deriv,dist_transformed)
S3method(eval_deriv,distribution)
S3method(eval_inverse,dist_default)
S3method(eval_inverse,dist_transformed)
S3method(eval_inverse,distribution)
S3method(eval_transform,dist_default)
S3method(eval_transform,dist_transformed)
S3method(eval_transform,distribution)
S3method(family,dist_default)
S3method(family,distribution)
S3method(format,dist_bernoulli)
Expand Down Expand Up @@ -477,6 +484,9 @@ export(dist_truncated)
export(dist_uniform)
export(dist_weibull)
export(dist_wrap)
export(eval_deriv)
export(eval_inverse)
export(eval_transform)
export(generate)
export(hdr)
export(hilo)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

## New features

* automatic symbolic derivatives for transformed distributions (@venpopov, #101)
* dist_transformed() now accepts a `d_inverse` argument for a user-supplied derivative
function on the inverse transformation (#101)
* `support()` now shows whether the interval of support is open or
closed (@venpopov, #97)

Expand Down
104 changes: 55 additions & 49 deletions R/default.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,29 +189,32 @@ invert_fail <- function(...) stop("Inverting transformations for distributions i
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @noRd
get_unary_inverse <- function(f) {
get_unary_inverse <- function(f, ...) {
switch(f,
sqrt = function(x) x^2,
exp = log,
log = function(x, base = exp(1)) base ^ x,
log2 = function(x) 2^x,
log10 = function(x) 10^x,
expm1 = log1p,
log1p = expm1,
cos = acos,
sin = asin,
tan = atan,
acos = cos,
asin = sin,
atan = tan,
cosh = acosh,
sinh = asinh,
tanh = atanh,
acosh = cosh,
asinh = sinh,
atanh = tanh,

invert_fail
sqrt = function(x) x^2,
exp = function(x) log(x),
log = (function(x, base) {
if (missing(base)) function(x) exp(x)
else new_function(exprs(x = ), expr((!!base)^x))
})(x, ...),
log2 = function(x) 2^x,
log10 = function(x) 10^x,
expm1 = function(x) log1p(x),
log1p = function(x) expm1(x),
cos = function(x) acos(x),
sin = function(x) asin(x),
tan = function(x) atan(x),
acos = function(x) cos(x),
asin = function(x) sin(x),
atan = function(x) tan(x),
cosh = function(x) acosh(x),
sinh = function(x) asinh(x),
tanh = function(x) atanh(x),
acosh = function(x) cosh(x),
asinh = function(x) sinh(x),
atanh = function(x) tanh(x),

invert_fail
)
}

Expand All @@ -221,16 +224,15 @@ get_unary_inverse <- function(f) {
#' @param constant a constant value
#' @noRd
get_binary_inverse_1 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) x + constant,
`*` = function(x) x / constant,
`/` = function(x) x * constant,
`^` = function(x) x ^ (1/constant),
`+` = new_function(exprs(x = ), body = expr(x - !!constant)),
`-` = new_function(exprs(x = ), body = expr(x + !!constant)),
`*` = new_function(exprs(x = ), body = expr(x / !!constant)),
`/` = new_function(exprs(x = ), body = expr(x * !!constant)),
`^` = new_function(exprs(x = ), body = expr(x ^ (1/!!constant))),

invert_fail
invert_fail
)
}

Expand All @@ -243,61 +245,65 @@ get_binary_inverse_2 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) constant - x,
`*` = function(x) x / constant,
`/` = function(x) constant / x,
`^` = function(x) log(x, base = constant),
`+` = new_function(exprs(x = ), body = expr(x - !!constant)),
`-` = new_function(exprs(x = ), body = expr(!!constant - x)),
`*` = new_function(exprs(x = ), body = expr(x / !!constant)),
`/` = new_function(exprs(x = ), body = expr(!!constant / x)),
`^` = new_function(exprs(x = ), body = expr(log(x, base = !!constant))),

invert_fail
invert_fail
)
}

#' @method Math dist_default
#' @export
Math.dist_default <- function(x, ...) {
if(dim(x) > 1) stop("Transformations of multivariate distributions are not yet supported.")
if (dim(x) > 1) stop("Transformations of multivariate distributions are not yet supported.")

trans <- new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!!dots_list(...))))
transform <- new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!!dots_list(...))))
inverse <- get_unary_inverse(.Generic, ...)
d_inverse <- symbolic_derivative(inverse, fallback_numderiv = TRUE)

inverse_fun <- get_unary_inverse(.Generic)
inverse <- new_function(exprs(x = ), body = expr((!!inverse_fun)(x, !!!dots_list(...))))

vec_data(dist_transformed(wrap_dist(list(x)), trans, inverse))[[1]]
vec_data(dist_transformed(wrap_dist(list(x)), transform, inverse, d_inverse))[[1]]
}

#' @method Ops dist_default
#' @export
Ops.dist_default <- function(e1, e2) {
if(.Generic %in% c("-", "+") && missing(e2)){
if (.Generic %in% c("-", "+") && missing(e2)){
e2 <- e1
e1 <- if(.Generic == "+") 1 else -1
.Generic <- "*"
}
is_dist <- c(inherits(e1, "dist_default"), inherits(e2, "dist_default"))
if(any(vapply(list(e1, e2)[is_dist], dim, numeric(1L)) > 1)){
if (any(vapply(list(e1, e2)[is_dist], dim, numeric(1L)) > 1)){
stop("Transformations of multivariate distributions are not yet supported.")
}

trans <- if(all(is_dist)) {
if(identical(e1$dist, e2$dist)){
transform <- if (all(is_dist)) {
if (identical(e1$dist, e2$dist)){
new_function(exprs(x = ), expr((!!sym(.Generic))((!!e1$transform)(x), (!!e2$transform)(x))))
} else {
stop(sprintf("The %s operation is not supported for <%s> and <%s>", .Generic, class(e1)[1], class(e2)[1]))
stop(sprintf("The %s operation is not supported for <%s> and <%s>",
.Generic, class(e1)[1], class(e2)[1]))
}
} else if(is_dist[1]){
} else if (is_dist[1]){
new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!e2)))
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, x)))
}

inverse <- if(all(is_dist)) {
inverse <- if (all(is_dist)) {
invert_fail
} else if(is_dist[1]){
get_binary_inverse_1(.Generic, e2)
} else {
get_binary_inverse_2(.Generic, e1)
}

vec_data(dist_transformed(wrap_dist(list(e1,e2)[which(is_dist)]), trans, inverse))[[1]]
d_inverse <- symbolic_derivative(inverse, fallback_numderiv = TRUE)

dist <- list(e1,e2)[which(is_dist)]

vec_data(dist_transformed(wrap_dist(dist), transform, inverse, d_inverse))[[1]]
}
24 changes: 24 additions & 0 deletions R/derivative.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
numderiv <- function(f) {
function(., ...) {
vapply(., numDeriv::jacobian, numeric(1L), func = f, ...)
}
}

symbolic_derivative <- function(inverse, fallback_numderiv = TRUE) {
if (!fallback_numderiv) return(Deriv::Deriv(inverse, x = 'x'))

tryCatch(
Deriv::Deriv(inverse, x = 'x'),
error = function(...) {
if (getOption('dist.verbose', FALSE)) {
message('Cannot compute the derivative of the inverse function symbolicly.')
}
numderiv(inverse)
}
)
}

# Chain rule
chain_rule <- function(x, y, d_x = symbolic_derivative(x), d_y = symbolic_derivative(y)) {
function(x) d_x(y(x)) * d_y(x)
}
Loading
Loading