|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 | # ------------------------------------------------------------------------------
|
3 | 3 | #
|
4 |
| -# Copyright 2022-2023 Valory AG |
| 4 | +# Copyright 2022-2024 Valory AG |
5 | 5 | # Copyright 2018-2019 Fetch.AI Limited
|
6 | 6 | #
|
7 | 7 | # Licensed under the Apache License, Version 2.0 (the "License");
|
|
34 | 34 |
|
35 | 35 |
|
36 | 36 | ENV_VARIABLE_RE = re.compile(r"^\$\{(([A-Z0-9_]+):?)?([a-z]+)?(:(.+))?}$")
|
| 37 | +MODELS = "models" |
| 38 | +ARGS = "args" |
| 39 | +ARGS_LEVEL_FROM_MODELS = 2 |
| 40 | +ARG_LEVEL_FROM_MODELS = ARGS_LEVEL_FROM_MODELS + 1 |
| 41 | +RESTRICTION_EXCEPTIONS = frozenset({"setup", "genesis_config"}) |
37 | 42 |
|
38 | 43 |
|
39 | 44 | def is_env_variable(value: Any) -> bool:
|
40 | 45 | """Check is variable string with env variable pattern."""
|
41 | 46 | return isinstance(value, str) and bool(ENV_VARIABLE_RE.match(value))
|
42 | 47 |
|
43 | 48 |
|
44 |
| -def export_path_to_env_var_string(export_path: List[str]) -> str: |
45 |
| - """Conver export path to environment variable string.""" |
| 49 | +def restrict_model_args(export_path: List[str]) -> Tuple[List[str], List[str]]: |
| 50 | + """Do not allow more levels than one for a model's argument.""" |
| 51 | + restricted = [] |
| 52 | + result = [] |
| 53 | + for i, current_path in enumerate(export_path): |
| 54 | + result.append(current_path) |
| 55 | + args_level = i + ARGS_LEVEL_FROM_MODELS |
| 56 | + arg_level = i + ARG_LEVEL_FROM_MODELS |
| 57 | + if ( |
| 58 | + current_path == MODELS |
| 59 | + and arg_level < len(export_path) |
| 60 | + and export_path[args_level] == ARGS |
| 61 | + and export_path[arg_level] not in RESTRICTION_EXCEPTIONS |
| 62 | + ): |
| 63 | + # do not allow more levels than one for a model's argument |
| 64 | + arg_content_level = arg_level + 1 |
| 65 | + result.extend(export_path[i + 1 : arg_content_level]) |
| 66 | + # store the restricted part of the path |
| 67 | + for j in range(arg_content_level, len(export_path)): |
| 68 | + restricted.append(export_path[j]) |
| 69 | + break |
| 70 | + return restricted, result |
| 71 | + |
| 72 | + |
| 73 | +def export_path_to_env_var_string(export_path: List[str]) -> Tuple[List[str], str]: |
| 74 | + """Convert export path to environment variable string.""" |
| 75 | + restricted, export_path = restrict_model_args(export_path) |
46 | 76 | env_var_string = "_".join(map(str, export_path))
|
47 |
| - return env_var_string.upper() |
| 77 | + return restricted, env_var_string.upper() |
48 | 78 |
|
49 | 79 |
|
50 | 80 | NotSet = object()
|
@@ -149,7 +179,7 @@ def apply_env_variables(
|
149 | 179 | data,
|
150 | 180 | env_variables,
|
151 | 181 | default_value,
|
152 |
| - default_var_name=export_path_to_env_var_string(export_path=path), |
| 182 | + default_var_name=export_path_to_env_var_string(export_path=path)[1], |
153 | 183 | )
|
154 | 184 |
|
155 | 185 | return data
|
@@ -242,35 +272,76 @@ def is_strict_list(data: Union[List, Tuple]) -> bool:
|
242 | 272 | return is_strict
|
243 | 273 |
|
244 | 274 |
|
| 275 | +def list_to_nested_dict(lst: list, val: Any) -> dict: |
| 276 | + """Convert a list to a nested dict.""" |
| 277 | + nested_dict = val |
| 278 | + for item in reversed(lst): |
| 279 | + nested_dict = {item: nested_dict} |
| 280 | + return nested_dict |
| 281 | + |
| 282 | + |
| 283 | +def ensure_dict(dict_: Dict[str, Union[dict, str]]) -> dict: |
| 284 | + """Return the given dictionary converting any values which are json strings as dicts.""" |
| 285 | + return {k: json.loads(v) for k, v in dict_.items() if isinstance(v, str)} |
| 286 | + |
| 287 | + |
| 288 | +def ensure_json_content(dict_: dict) -> dict: |
| 289 | + """Return the given dictionary converting any nested dictionary values as json strings.""" |
| 290 | + return {k: json.dumps(v) for k, v in dict_.items() if isinstance(v, dict)} |
| 291 | + |
| 292 | + |
| 293 | +def merge_dicts(a: dict, b: dict) -> dict: |
| 294 | + """Merge two dictionaries.""" |
| 295 | + # shallow copy of `a` |
| 296 | + merged = {**a} |
| 297 | + for key, value in b.items(): |
| 298 | + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): |
| 299 | + # recursively merge nested dictionaries |
| 300 | + merged[key] = merge_dicts(merged[key], value) |
| 301 | + else: |
| 302 | + # if not a nested dictionary, just take the value from `b` |
| 303 | + merged[key] = value |
| 304 | + return merged |
| 305 | + |
| 306 | + |
245 | 307 | def generate_env_vars_recursively(
|
246 | 308 | data: Union[Dict, List],
|
247 | 309 | export_path: List[str],
|
248 | 310 | ) -> Dict:
|
249 | 311 | """Generate environment variables recursively."""
|
250 |
| - env_var_dict = {} |
| 312 | + env_var_dict: Dict[str, Any] = {} |
251 | 313 |
|
252 | 314 | if isinstance(data, dict):
|
253 | 315 | for key, value in data.items():
|
254 |
| - env_var_dict.update( |
255 |
| - generate_env_vars_recursively( |
256 |
| - data=value, |
257 |
| - export_path=[*export_path, key], |
258 |
| - ) |
| 316 | + res = generate_env_vars_recursively( |
| 317 | + data=value, |
| 318 | + export_path=[*export_path, key], |
259 | 319 | )
|
| 320 | + if res: |
| 321 | + env_var = list(res.keys())[0] |
| 322 | + if env_var in env_var_dict: |
| 323 | + dicts = (ensure_dict(dict_) for dict_ in (env_var_dict, res)) |
| 324 | + res = ensure_json_content(merge_dicts(*dicts)) |
| 325 | + env_var_dict.update(res) |
260 | 326 | elif isinstance(data, list):
|
261 | 327 | if is_strict_list(data=data):
|
262 |
| - env_var_dict[ |
263 |
| - export_path_to_env_var_string(export_path=export_path) |
264 |
| - ] = json.dumps(data, separators=(",", ":")) |
| 328 | + restricted, path = export_path_to_env_var_string(export_path=export_path) |
| 329 | + if restricted: |
| 330 | + env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data)) |
| 331 | + else: |
| 332 | + env_var_dict[path] = json.dumps(data, separators=(",", ":")) |
265 | 333 | else:
|
266 | 334 | for key, value in enumerate(data):
|
267 |
| - env_var_dict.update( |
268 |
| - generate_env_vars_recursively( |
269 |
| - data=value, |
270 |
| - export_path=[*export_path, key], |
271 |
| - ) |
| 335 | + res = generate_env_vars_recursively( |
| 336 | + data=value, |
| 337 | + export_path=[*export_path, key], |
272 | 338 | )
|
| 339 | + env_var_dict.update(res) |
273 | 340 | else:
|
274 |
| - env_var_dict[export_path_to_env_var_string(export_path=export_path)] = data |
| 341 | + restricted, path = export_path_to_env_var_string(export_path=export_path) |
| 342 | + if restricted: |
| 343 | + env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data)) |
| 344 | + else: |
| 345 | + env_var_dict[path] = data |
275 | 346 |
|
276 | 347 | return env_var_dict
|
0 commit comments