Skip to content

Commit

Permalink
fix and test (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Nov 29, 2024
1 parent b0f74e2 commit a7a4af4
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 19 deletions.
9 changes: 5 additions & 4 deletions R/TuningInstanceAsyncMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
search_space = search_space$subset(setdiff(sids, internal_tune_ids))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

if (!is.null(self$internal_search_space) && self$internal_search_space$has_trafo) {
stopf("Internal tuning and parameter transformations are currently not supported.
Expand All @@ -115,6 +111,11 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
learner$param_set$set_values(.values = learner$param_set$convert_internal_search_space(self$internal_search_space))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

if (is.null(rush)) rush = rush::rsh()

# create codomain from measure
Expand Down
10 changes: 5 additions & 5 deletions R/TuningInstanceAsyncSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ TuningInstanceAsyncSingleCrit = R6Class("TuningInstanceAsyncSingleCrit",
search_space = search_space$subset(setdiff(sids, internal_tune_ids))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

if (!is.null(self$internal_search_space) && self$internal_search_space$has_trafo) {
stopf("Internal tuning and parameter transformations are currently not supported.
If you manually provided a search space that has a trafo and parameters tagged with 'internal_tuning',
Expand All @@ -125,6 +120,11 @@ TuningInstanceAsyncSingleCrit = R6Class("TuningInstanceAsyncSingleCrit",
learner$param_set$set_values(.values = learner$param_set$convert_internal_search_space(self$internal_search_space))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

if (is.null(rush)) rush = rush::rsh()

# create codomain from measure
Expand Down
10 changes: 5 additions & 5 deletions R/TuningInstanceBatchMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,6 @@ TuningInstanceBatchMultiCrit = R6Class("TuningInstanceBatchMultiCrit",
search_space = search_space$subset(setdiff(sids, internal_tune_ids))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token", check_required = TRUE)
}

if (!is.null(self$internal_search_space) && self$internal_search_space$has_trafo) {
stopf("Internal tuning and parameter transformations are currently not supported.
If you manually provided a search space that has a trafo and parameters tagged with 'internal_tuning',
Expand All @@ -144,6 +139,11 @@ TuningInstanceBatchMultiCrit = R6Class("TuningInstanceBatchMultiCrit",
learner$param_set$set_values(.values = learner$param_set$convert_internal_search_space(self$internal_search_space))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token", check_required = TRUE)
}

# create codomain from measure
measures = assert_measures(as_measures(measures, task_type = task$task_type), task = task, learner = learner)
codomain = measures_to_codomain(measures)
Expand Down
10 changes: 5 additions & 5 deletions R/TuningInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,6 @@ TuningInstanceBatchSingleCrit = R6Class("TuningInstanceBatchSingleCrit",
search_space = search_space$subset(setdiff(sids, internal_tune_ids))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

if (!is.null(self$internal_search_space) && self$internal_search_space$has_trafo) {
stopf("Internal tuning and parameter transformations are currently not supported.
If you manually provided a search space that has a trafo and parameters tagged with 'internal_tuning',
Expand All @@ -194,6 +189,11 @@ TuningInstanceBatchSingleCrit = R6Class("TuningInstanceBatchSingleCrit",
learner$param_set$set_values(.values = learner$param_set$convert_internal_search_space(self$internal_search_space))
}

# set learner parameter values
if (search_space_from_tokens) {
learner$param_set$values = learner$param_set$get_values(type = "without_token")
}

# create codomain from measure
measures = assert_measures(as_measures(measure, task_type = task$task_type), task = task, learner = learner)
codomain = measures_to_codomain(measures)
Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test_TuningInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -463,3 +463,40 @@ test_that("Batch single-crit internal tuning works", {
expect_equal(instance$result_learner_param_vals$iter, 99)
})

test_that("required parameter can be tuned internally without having a value set", {
learner = lrn("classif.debug")
tags = learner$param_set$tags
tags$iter = union(tags$iter, "required")
learner$param_set$tags = tags

learner$param_set$set_values(
early_stopping = TRUE,
iter = NULL
)
learner$validate = "test"

internal_search_space = ps(
iter = p_int(upper = 1000, aggr = function(x) as.integer(mean(unlist(x))))
)


expect_error(tune(
task = tsk("iris"),
tuner = tnr("internal"),
learner = learner,
internal_search_space = internal_search_space,
resampling = rsmp("holdout"),
store_benchmark_result = TRUE
), regexp = NA)

learner$param_set$set_values(
iter = to_tune(upper = 1000, internal = TRUE)
)
expect_error(tune(
task = tsk("iris"),
tuner = tnr("internal"),
learner = learner,
resampling = rsmp("holdout"),
store_benchmark_result = TRUE
), regexp = NA)
})

0 comments on commit a7a4af4

Please sign in to comment.