Skip to content

Commit a489897

Browse files
committed
LLM-generated fn for neuron search space
1 parent 31b3964 commit a489897

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

benchmarks/rf_use_case/run_benchmark.R

+27-12
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,28 @@ task_list = mlr3misc::pmap(cc18_small, function(data_id, name, NumberOfFeatures,
2020
task_list
2121

2222
# define the learners
23+
neurons = function(n_layers, latent_dim) {
24+
rep(latent_dim, n_layers)
25+
}
26+
27+
n_layers_values <- 1:10
28+
latent_dim_values <- seq(10, 500, by = 10)
29+
neurons_search_space <- mapply(
30+
neurons,
31+
expand.grid(n_layers = n_layers_values, latent_dim = latent_dim_values)$n_layers,
32+
expand.grid(n_layers = n_layers_values, latent_dim = latent_dim_values)$latent_dim,
33+
SIMPLIFY = FALSE
34+
)
35+
2336
mlp = lrn("classif.mlp",
2437
activation = nn_relu,
25-
neurons = to_tune(ps(
26-
n_layers = p_int(lower = 1, upper = 10), latent = p_int(10, 500),
27-
.extra_trafo = function(x, param_set) {
28-
list(neurons = rep(x$latent, x$n_layers))
29-
})
30-
),
38+
# neurons = to_tune(ps(
39+
# n_layers = p_int(lower = 1, upper = 10), latent = p_int(10, 500),
40+
# .extra_trafo = function(x, param_set) {
41+
# list(neurons = rep(x$latent, x$n_layers))
42+
# })
43+
# ),
44+
neurons = to_tune(neurons_search_space),
3145
batch_size = to_tune(c(16, 32, 64, 128, 256)),
3246
p = to_tune(0.1, 0.9),
3347
epochs = to_tune(upper = 1000L, internal = TRUE),
@@ -37,7 +51,7 @@ mlp = lrn("classif.mlp",
3751
device = "cpu"
3852
)
3953

40-
# define the optimizatio nstrategy
54+
# define the optimization strategy
4155
bayesopt_ego = mlr_loop_functions$get("bayesopt_ego")
4256
surrogate = srlrn(lrn("regr.km", covtype = "matern5_2",
4357
optim.method = "BFGS", control = list(trace = FALSE)))
@@ -54,19 +68,20 @@ tnr_mbo = tnr("mbo",
5468
# define an AutoTuner that wraps the classif.mlp
5569
at = auto_tuner(
5670
learner = mlp,
57-
tuner = tnr("grid_search"),
71+
tuner = tnr_mbo,
5872
resampling = rsmp("cv"),
5973
measure = msr("classif.acc"),
60-
term_evals = 1000
74+
term_evals = 10
6175
)
6276

63-
future::plan("multisession", workers = 8)
77+
future::plan("multisession", workers = 64)
6478

6579
lrn_rf = lrn("classif.ranger")
6680
design = benchmark_grid(
6781
task_list,
6882
learners = list(at, lrn_rf),
69-
resampling = rsmp("cv", folds = 10))
83+
resampling = rsmp("cv", folds = 10)
84+
)
7085

7186
time = bench::system_time(
7287
bmr <- benchmark(design)
@@ -75,4 +90,4 @@ time = bench::system_time(
7590
bmrdt = as.data.table(bmr)
7691

7792
fwrite(bmrdt, here("R", "rf_use_case", "results", "bmrdt.csv"))
78-
fwrite(time, here("R", "rf_use_case", "results", "time.csv"))
93+
fwrite(time, here("R", "rf_use_case", "results", "time.csv"))

0 commit comments

Comments
 (0)