|
16 | 16 |
|
17 | 17 | import optuna
|
18 | 18 |
|
| 19 | +from pkg.apis.manager.v1beta1.python import api_pb2 |
19 | 20 | from pkg.suggestion.v1beta1.internal.constant import (
|
20 | 21 | CATEGORICAL,
|
21 | 22 | DISCRETE,
|
@@ -110,13 +111,37 @@ def _get_optuna_search_space(self):
|
110 | 111 | search_space = {}
|
111 | 112 | for param in self.search_space.params:
|
112 | 113 | if param.type == INTEGER:
|
113 |
| - search_space[param.name] = optuna.distributions.IntDistribution( |
114 |
| - int(param.min), int(param.max) |
115 |
| - ) |
| 114 | + if param.distribution == api_pb2.UNIFORM or param.distribution is None: |
| 115 | + if param.step: |
| 116 | + search_space[param.name] = optuna.distributions.IntDistribution( |
| 117 | + int(param.min), int(param.max), False, param.step |
| 118 | + ) |
| 119 | + else: |
| 120 | + search_space[param.name] = optuna.distributions.IntDistribution( |
| 121 | + int(param.min), int(param.max) |
| 122 | + ) |
| 123 | + if param.distribution == api_pb2.LOG_UNIFORM: |
| 124 | + search_space[param.name] = optuna.distributions.IntDistribution( |
| 125 | + int(param.min), int(param.max), True, param.step |
| 126 | + ) |
116 | 127 | elif param.type == DOUBLE:
|
117 |
| - search_space[param.name] = optuna.distributions.FloatDistribution( |
118 |
| - float(param.min), float(param.max) |
119 |
| - ) |
| 128 | + if param.distribution == api_pb2.UNIFORM or param.distribution is None: |
| 129 | + if param.step: |
| 130 | + search_space[param.name] = ( |
| 131 | + optuna.distributions.FloatDistribution( |
| 132 | + int(param.min), int(param.max), False, param.step |
| 133 | + ) |
| 134 | + ) |
| 135 | + else: |
| 136 | + search_space[param.name] = ( |
| 137 | + optuna.distributions.FloatDistribution( |
| 138 | + int(param.min), int(param.max) |
| 139 | + ) |
| 140 | + ) |
| 141 | + if param.distribution == api_pb2.LOG_UNIFORM: |
| 142 | + search_space[param.name] = optuna.distributions.FloatDistribution( |
| 143 | + int(param.min), int(param.max), True, param.step |
| 144 | + ) |
120 | 145 | elif param.type == CATEGORICAL or param.type == DISCRETE:
|
121 | 146 | search_space[param.name] = optuna.distributions.CategoricalDistribution(
|
122 | 147 | param.list
|
|
0 commit comments