Skip to content

Commit 6b18e9b

Browse files
initial logic update
Signed-off-by: Shashank Mittal <[email protected]>
1 parent fd460b6 commit 6b18e9b

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

pkg/suggestion/v1beta1/optuna/base_service.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import optuna
1818

19+
from pkg.apis.manager.v1beta1.python import api_pb2
1920
from pkg.suggestion.v1beta1.internal.constant import (
2021
CATEGORICAL,
2122
DISCRETE,
@@ -110,13 +111,37 @@ def _get_optuna_search_space(self):
110111
search_space = {}
111112
for param in self.search_space.params:
112113
if param.type == INTEGER:
113-
search_space[param.name] = optuna.distributions.IntDistribution(
114-
int(param.min), int(param.max)
115-
)
114+
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
115+
if param.step:
116+
search_space[param.name] = optuna.distributions.IntDistribution(
117+
int(param.min), int(param.max), False, param.step
118+
)
119+
else:
120+
search_space[param.name] = optuna.distributions.IntDistribution(
121+
int(param.min), int(param.max)
122+
)
123+
if param.distribution == api_pb2.LOG_UNIFORM:
124+
search_space[param.name] = optuna.distributions.IntDistribution(
125+
int(param.min), int(param.max), True, param.step
126+
)
116127
elif param.type == DOUBLE:
117-
search_space[param.name] = optuna.distributions.FloatDistribution(
118-
float(param.min), float(param.max)
119-
)
128+
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
129+
if param.step:
130+
search_space[param.name] = (
131+
optuna.distributions.FloatDistribution(
132+
int(param.min), int(param.max), False, param.step
133+
)
134+
)
135+
else:
136+
search_space[param.name] = (
137+
optuna.distributions.FloatDistribution(
138+
int(param.min), int(param.max)
139+
)
140+
)
141+
if param.distribution == api_pb2.LOG_UNIFORM:
142+
search_space[param.name] = optuna.distributions.FloatDistribution(
143+
int(param.min), int(param.max), True, param.step
144+
)
120145
elif param.type == CATEGORICAL or param.type == DISCRETE:
121146
search_space[param.name] = optuna.distributions.CategoricalDistribution(
122147
param.list

0 commit comments

Comments
 (0)