Skip to content

Commit

Permalink
Merge pull request #375 from tidymodels/respect-inclusive-347
Browse files Browse the repository at this point in the history
`value_seq()` and `value_sample()` respect `inclusive`
  • Loading branch information
hfrick authored Feb 12, 2025
2 parents da8400f + 2919395 commit 78fab13
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 9 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

* For space-filling designs for $p$ parameters, there is a higher likelihood of finding a space-filling design for `1 < size <= p`. Also, single-point designs now default to a random grid (#363).

* `value_seq()` and `value_sample()` now respect the `inclusive` argument of quantitative parameters (#347).

* The constructors, `new_*_parameter()`, now label unlabeled parameter (i.e., constructed with `label = NULL`) as such (#349).

## Breaking changes
Expand Down
69 changes: 60 additions & 9 deletions R/aaa_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,19 @@ value_seq_dbl <- function(object, n, original = TRUE) {
n_safely <- min(length(object$values), n)
res <- object$values[seq_len(n_safely)]
} else {
range_lower <- min(unlist(object$range))
if (!object$inclusive["lower"]) {
range_lower <- range_lower + .Machine$double.eps
}

range_upper <- max(unlist(object$range))
if (!object$inclusive["upper"]) {
range_upper <- range_upper - .Machine$double.eps
}

res <- seq(
from = min(unlist(object$range)),
to = max(unlist(object$range)),
from = range_lower,
to = range_upper,
length.out = n
)
}
Expand All @@ -161,9 +171,19 @@ value_seq_int <- function(object, n, original = TRUE) {
n_safely <- min(length(object$values), n)
res <- object$values[seq_len(n_safely)]
} else {
range_lower <- min(unlist(object$range))
if (!object$inclusive["lower"]) {
range_lower <- range_lower + 1L
}

range_upper <- max(unlist(object$range))
if (!object$inclusive["upper"]) {
range_upper <- range_upper - 1L
}

res <- seq(
from = min(unlist(object$range)),
to = max(unlist(object$range)),
from = range_lower,
to = range_upper,
length.out = n
)
}
Expand Down Expand Up @@ -202,10 +222,20 @@ value_sample <- function(object, n, original = TRUE) {

value_samp_dbl <- function(object, n, original = TRUE) {
if (is.null(object$values)) {
range_lower <- min(unlist(object$range))
if (!object$inclusive["lower"]) {
range_lower <- range_lower + .Machine$double.eps
}

range_upper <- max(unlist(object$range))
if (!object$inclusive["upper"]) {
range_upper <- range_upper - .Machine$double.eps
}

res <- runif(
n,
min = min(unlist(object$range)),
max = max(unlist(object$range))
min = range_lower,
max = range_upper
)
} else {
res <- sample(
Expand All @@ -223,11 +253,22 @@ value_samp_dbl <- function(object, n, original = TRUE) {
value_samp_int <- function(object, n, original = TRUE) {
if (is.null(object$trans)) {
if (is.null(object$values)) {
range_lower <- min(unlist(object$range))
if (!object$inclusive["lower"]) {
range_lower <- range_lower + 1L
}

range_upper <- max(unlist(object$range))
if (!object$inclusive["upper"]) {
range_upper <- range_upper - 1L
}

res <- sample(
min(unlist(object$range)):max(unlist(object$range)),
seq(from = range_lower, to = range_upper),
size = n,
replace = TRUE
)
res <- as.integer(res)
} else {
res <- sample(
object$values,
Expand All @@ -237,10 +278,20 @@ value_samp_int <- function(object, n, original = TRUE) {
}
} else {
if (is.null(object$values)) {
range_lower <- min(unlist(object$range))
if (!object$inclusive["lower"]) {
range_lower <- range_lower + .Machine$double.eps
}

range_upper <- max(unlist(object$range))
if (!object$inclusive["upper"]) {
range_upper <- range_upper - .Machine$double.eps
}

res <- runif(
n,
min = min(unlist(object$range)),
max = max(unlist(object$range))
min = range_lower,
max = range_upper
)
} else {
res <- sample(
Expand Down
56 changes: 56 additions & 0 deletions tests/testthat/test-aaa_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,59 @@ test_that("value_set() checks inputs", {
value_set(cost_complexity(), numeric(0))
})
})

test_that("`value_seq()` respects `inclusive` #347", {
double_non_incl <- new_quant_param(
type = "double",
range = c(0, 1),
inclusive = c(FALSE, FALSE),
trans = NULL,
label = c(param_non_incl = "some label"),
finalize = NULL
)

vals_double <- value_seq(double_non_incl, 10)
expect_gt(min(vals_double), 0)
expect_lt(max(vals_double), 1)

int_non_incl <- new_quant_param(
type = "integer",
range = c(0, 2),
inclusive = c(FALSE, FALSE),
trans = NULL,
label = c(param_non_incl = "some label"),
finalize = NULL
)

vals_int <- value_seq(int_non_incl, 10)
expect_gt(min(vals_int), 0)
expect_lt(max(vals_int), 2)
})

test_that("`value_sample()` respects `inclusive` #347", {
int_non_incl <- new_quant_param(
type = "integer",
range = c(0, 2),
inclusive = c(FALSE, FALSE),
trans = NULL,
label = c(param_non_incl = "some label"),
finalize = NULL
)

vals_int <- value_sample(int_non_incl, 10)
expect_gt(min(vals_int), 0)
expect_lt(max(vals_int), 2)

int_non_incl_trans <- new_quant_param(
type = "integer",
range = c(0, 2),
inclusive = c(FALSE, FALSE),
trans = scales::transform_log(),
label = c(param_non_incl = "some label"),
finalize = NULL
)

vals_int <- value_sample(int_non_incl_trans, n = 10, original = FALSE)
expect_gt(min(vals_int), 0)
expect_lt(max(vals_int), 2)
})

0 comments on commit 78fab13

Please sign in to comment.