Skip to content

Commit

Permalink
refactor: extract internal tuned values in instance (#453)
Browse files Browse the repository at this point in the history
* refactor: extract internal tuned values in instance

* ...

* ...

* ...
  • Loading branch information
be-marc authored Oct 14, 2024
1 parent c21bae8 commit bb614ac
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 26 deletions.
11 changes: 8 additions & 3 deletions R/TuningInstanceAsyncMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,14 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
#'
#' @param ydt (`data.table::data.table()`)\cr
#'   Optimal outcomes, e.g. the Pareto front.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, ydt, learner_param_vals = NULL, ...) {
#' @param xydt (`data.table::data.table()`)\cr
#' Point, outcome, and additional information.
assign_result = function(xdt, ydt, learner_param_vals = NULL, xydt = NULL) {
# extract internal tuned values
if ("internal_tuned_values" %in% names(xydt)) {
set(xdt, j = "internal_tuned_values", value = list(xydt[["internal_tuned_values"]]))
}

# set the column with the learner param_vals that were not optimized over but set implicitly
if (is.null(learner_param_vals)) {
learner_param_vals = self$objective$learner$param_set$values
Expand Down
11 changes: 8 additions & 3 deletions R/TuningInstanceAsyncSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,17 @@ TuningInstanceAsyncSingleCrit = R6Class("TuningInstanceAsyncSingleCrit",
#'
#' @param y (`numeric(1)`)\cr
#' Optimal outcome.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, y, learner_param_vals = NULL, ...) {
#' @param xydt (`data.table::data.table()`)\cr
#' Point, outcome, and additional information.
assign_result = function(xdt, y, learner_param_vals = NULL, xydt = NULL) {
# set the column with the learner param_vals that were not optimized over but set implicitly
assert_list(learner_param_vals, null.ok = TRUE, names = "named")

# extract internal tuned values
if ("internal_tuned_values" %in% names(xydt)) {
set(xdt, j = "internal_tuned_values", value = list(xydt[["internal_tuned_values"]]))
}

if (is.null(learner_param_vals)) {
learner_param_vals = self$objective$learner$param_set$values
}
Expand Down
11 changes: 8 additions & 3 deletions R/TuningInstanceBatchMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,14 @@ TuningInstanceBatchMultiCrit = R6Class("TuningInstanceBatchMultiCrit",
#'
#' @param ydt (`data.table::data.table()`)\cr
#' Optimal outcomes, e.g. the Pareto front.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, ydt, learner_param_vals = NULL) {
#' @param xydt (`data.table::data.table()`)\cr
#' Point, outcome, and additional information.
assign_result = function(xdt, ydt, learner_param_vals = NULL, xydt = NULL) {
# extract internal tuned values
if ("internal_tuned_values" %in% names(xydt)) {
set(xdt, j = "internal_tuned_values", value = list(xydt[["internal_tuned_values"]]))
}

# set the column with the learner param_vals that were not optimized over but set implicitly
if (is.null(learner_param_vals)) {
learner_param_vals = self$objective$learner$param_set$values
Expand Down
12 changes: 9 additions & 3 deletions R/TuningInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,18 @@ TuningInstanceBatchSingleCrit = R6Class("TuningInstanceBatchSingleCrit",
#'
#' @param y (`numeric(1)`)\cr
#' Optimal outcome.
#' @param ... (`any`)\cr
#' ignored.
assign_result = function(xdt, y, learner_param_vals = NULL, ...) {
#' @param xydt (`data.table::data.table()`)\cr
#' Point, outcome, and additional information.
assign_result = function(xdt, y, learner_param_vals = NULL, xydt = NULL) {

# set the column with the learner param_vals that were not optimized over but set implicitly
assert_list(learner_param_vals, null.ok = TRUE, names = "named")

# extract internal tuned values
if ("internal_tuned_values" %in% names(xydt)) {
set(xdt, j = "internal_tuned_values", value = list(xydt[["internal_tuned_values"]]))
}

# learner param values
if (is.null(learner_param_vals)) {
learner_param_vals = self$objective$learner$param_set$values
Expand Down
2 changes: 1 addition & 1 deletion man/ContextAsyncTuning.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/TuningInstanceAsyncMultiCrit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/TuningInstanceAsyncSingleCrit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 8 additions & 3 deletions man/TuningInstanceBatchMultiCrit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/TuningInstanceBatchSingleCrit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions tests/testthat/test_Tuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ test_that("we get a result when some subordinate params are not fulfilled", {
})

test_that("print method works", {
skip_if_not_installed("GenSA")

param_set = ps(p1 = p_lgl())
param_set$values$p1 = TRUE
param_classes = "ParamLgl"
Expand Down Expand Up @@ -125,6 +127,8 @@ test_that("Tuner works with instantiated resampling", {
})

test_that("Tuner active bindings work", {
skip_if_not_installed("GenSA")

param_set = ps(p1 = p_lgl())
param_set$values$p1 = TRUE
param_classes = "ParamLgl"
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test_TunerBatchCmaes.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
test_that("TunerBatchCmaes", {
skip_if_not_installed("adagio")

expect_tuner(tnr("cmaes"))
expect_tuner(tnr("cmaes"))

learner = lrn("classif.rpart",
cp = to_tune(1e-04, 1e-1, logscale = TRUE),
Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test_TunerBatchNLoptr.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
test_that("TunerNLoptr", {
skip_on_os("windows")
skip_if_not_installed("nloptr")

test_tuner("nloptr", algorithm = "NLOPT_LN_BOBYQA", term_evals = 4)
})
4 changes: 4 additions & 0 deletions tests/testthat/test_mlr_tuners.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
test_that("mlr_tuners", {
skip_if_not_installed(c("rush", "adagio", "GenSA", "irace", "nloptr"))

expect_dictionary(mlr_tuners, min_items = 1L)
keys = mlr_tuners$keys()

Expand All @@ -14,6 +16,8 @@ test_that("mlr_tuners sugar", {
})

test_that("as.data.table objects parameter", {
skip_if_not_installed(c("rush", "adagio", "GenSA", "irace", "nloptr"))

tab = as.data.table(mlr_tuners, objects = TRUE)
expect_data_table(tab)
expect_list(tab$object, "Tuner", any.missing = FALSE)
Expand Down

0 comments on commit bb614ac

Please sign in to comment.