Skip to content

Commit 7bbfc50

Browse files
refactored code
Signed-off-by: Shashank Mittal <[email protected]>
1 parent 2c56864 commit 7bbfc50

File tree

2 files changed

+28
-65
lines changed

2 files changed

+28
-65
lines changed

pkg/suggestion/v1beta1/hyperopt/base_service.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ def create_hyperopt_domain(self):
6363
# Construct search space, example: {"x": hyperopt.hp.uniform('x', -10, 10), "x2":
6464
# hyperopt.hp.uniform('x2', -10, 10)}
6565
hyperopt_search_space = {}
66+
6667
for param in self.search_space.params:
6768
if param.type in [INTEGER, DOUBLE]:
68-
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
69+
if param.distribution in [api_pb2.UNIFORM, None]:
6970
# Uniform distribution: values are sampled between min and max.
7071
# If step is defined, we use the quantized version quniform.
7172
if param.step:
@@ -80,14 +81,10 @@ def create_hyperopt_domain(self):
8081
param.name, float(param.min), float(param.max)
8182
)
8283
else:
83-
if param.type == INTEGER:
84-
hyperopt_search_space[param.name] = hyperopt.hp.uniformint(
85-
param.name, float(param.min), float(param.max)
86-
)
87-
else:
88-
hyperopt_search_space[param.name] = hyperopt.hp.uniform(
89-
param.name, float(param.min), float(param.max)
90-
)
84+
hyperopt_search_space[param.name] = hyperopt.hp.uniform(
85+
param.name, float(param.min), float(param.max)
86+
)
87+
9188
elif param.distribution == api_pb2.LOG_UNIFORM:
9289
# Log-uniform distribution: used for parameters that vary exponentially.
9390
# We convert min and max to their logarithmic scale using math.log, because
@@ -105,28 +102,23 @@ def create_hyperopt_domain(self):
105102
math.log(float(param.min)),
106103
math.log(float(param.max)),
107104
)
105+
108106
elif param.distribution == api_pb2.NORMAL:
109107
# Normal distribution: used when values are centered around the mean (mu)
110108
# and spread out by sigma. We calculate mu as the midpoint between
111109
# min and max, and sigma as (max - min) / 6. This is based on the assumption
112110
# that 99.7% of the values in a normal distribution fall within ±3 sigma.
113111
mu = (float(param.min) + float(param.max)) / 2
114-
# We consider the normal distribution based on the range of ±3 sigma.
115112
sigma = (float(param.max) - float(param.min)) / 6
116-
117113
if param.step:
118114
hyperopt_search_space[param.name] = hyperopt.hp.qnormal(
119-
param.name,
120-
mu,
121-
sigma,
122-
float(param.step),
115+
param.name, mu, sigma, float(param.step)
123116
)
124117
else:
125118
hyperopt_search_space[param.name] = hyperopt.hp.normal(
126-
param.name,
127-
mu,
128-
sigma,
119+
param.name, mu, sigma
129120
)
121+
130122
elif param.distribution == api_pb2.LOG_NORMAL:
131123
# Log-normal distribution: applies when the logarithm
132124
# of the parameter follows a normal distribution.
@@ -137,21 +129,16 @@ def create_hyperopt_domain(self):
137129
log_max = math.log(float(param.max))
138130
mu = (log_min + log_max) / 2
139131
sigma = (log_max - log_min) / 6
140-
141132
if param.step:
142133
hyperopt_search_space[param.name] = hyperopt.hp.qlognormal(
143-
param.name,
144-
mu,
145-
sigma,
146-
float(param.step),
134+
param.name, mu, sigma, float(param.step)
147135
)
148136
else:
149137
hyperopt_search_space[param.name] = hyperopt.hp.lognormal(
150-
param.name,
151-
mu,
152-
sigma,
138+
param.name, mu, sigma
153139
)
154-
elif param.type == CATEGORICAL or param.type == DISCRETE:
140+
141+
elif param.type in [CATEGORICAL, DISCRETE]:
155142
hyperopt_search_space[param.name] = hyperopt.hp.choice(
156143
param.name, param.list
157144
)

pkg/suggestion/v1beta1/optuna/base_service.py

+14-38
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,15 @@ def _get_optuna_search_space(self):
112112

113113
for param in self.search_space.params:
114114
if param.type == INTEGER:
115-
step = int(param.step) if param.step else None
116-
117-
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
115+
if param.distribution in [api_pb2.UNIFORM, None]:
118116
# Uniform integer distribution: samples integers between min and max.
119117
# If step is defined, use a quantized version.
120-
if step:
121-
search_space[param.name] = optuna.distributions.IntDistribution(
122-
low=int(param.min),
123-
high=int(param.max),
124-
log=False,
125-
step=step,
126-
)
127-
else:
128-
search_space[param.name] = optuna.distributions.IntDistribution(
129-
low=int(param.min),
130-
high=int(param.max),
131-
log=False,
132-
step=None,
133-
)
118+
search_space[param.name] = optuna.distributions.IntDistribution(
119+
low=int(param.min),
120+
high=int(param.max),
121+
log=False,
122+
step=int(param.step) if param.step else None,
123+
)
134124
elif param.distribution == api_pb2.LOG_UNIFORM:
135125
# Log-uniform integer distribution: used for exponentially varying integers.
136126
search_space[param.name] = optuna.distributions.IntDistribution(
@@ -141,29 +131,15 @@ def _get_optuna_search_space(self):
141131
)
142132

143133
elif param.type == DOUBLE:
144-
step = float(param.step) if param.step else None
145-
146-
if param.distribution == api_pb2.UNIFORM or param.distribution is None:
134+
if param.distribution in [api_pb2.UNIFORM, None]:
147135
# Uniform float distribution: samples values between min and max.
148136
# If step is provided, use a quantized version.
149-
if step:
150-
search_space[param.name] = (
151-
optuna.distributions.FloatDistribution(
152-
low=float(param.min),
153-
high=float(param.max),
154-
log=False,
155-
step=step,
156-
)
157-
)
158-
else:
159-
search_space[param.name] = (
160-
optuna.distributions.FloatDistribution(
161-
low=float(param.min),
162-
high=float(param.max),
163-
log=False,
164-
step=None,
165-
)
166-
)
137+
search_space[param.name] = optuna.distributions.FloatDistribution(
138+
low=float(param.min),
139+
high=float(param.max),
140+
log=False,
141+
step=float(param.step) if param.step else None,
142+
)
167143
elif param.distribution == api_pb2.LOG_UNIFORM:
168144
# Log-uniform float distribution: used for exponentially varying values.
169145
search_space[param.name] = optuna.distributions.FloatDistribution(

0 commit comments

Comments
 (0)