diff --git a/changelog.md b/changelog.md index f14b0b0..4292b22 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,7 @@ ## Unreleased - Support interpolated seed in the config file (as a reminder, the seed is treated specifically by confit to initialize random generators **before** any object is resolved) +- Support if/else expressions in interpolation, and only resolve the relevant branch ## v0.7.2 (2024-11-23) diff --git a/confit/config.py b/confit/config.py index 16d4765..648cdef 100644 --- a/confit/config.py +++ b/confit/config.py @@ -1,5 +1,7 @@ import collections.abc +import keyword import re +from collections import UserDict from configparser import ConfigParser from io import StringIO from pathlib import Path @@ -32,6 +34,30 @@ PATH = rf"{UNQUOTED_ID}(?:[.]{PATH_PART})*" +class DynamicLocals(UserDict): + def __init__(self, mapping, root, resolved_locs, get): + super().__init__() + self.name_to_parts = {v: k for k, v in mapping.items()} + self.root = root + self.resolved_locs = resolved_locs + self.get = get + + def __getitem__(self, item): + if not isinstance(item, str): + raise KeyError(item) + if item.startswith("__"): + raise KeyError(item) + current = self.root + parts = self.name_to_parts[item] + for part in parts: + current = current[part] + if id(current) not in self.resolved_locs: + resolved = self.get(current, parts) + else: + resolved = self.resolved_locs[id(current)] + return resolved + + class Config(dict): """ The configuration system consists of a supercharged dict, the `Config` class, @@ -341,12 +367,15 @@ def replace(match: re.Match): path = match.group() parts = split_path(path.rstrip(":")) + + # Check if part is any special python keyword + if len(parts) == 1 and parts[0] in keyword.kwlist: + return match.group() try: return local_names[parts] + ("." if path.endswith(":") else "") except KeyError: raise KeyError(path) - local_leaves = {} local_names = {} for match in pat.finditer(ref.value): start = match.start() @@ -354,15 +383,11 @@ def replace(match: re.Match): continue path = match.group() parts = split_path(path.rstrip(":")) - current = root - for part in parts: - current = current[part] - if id(current) not in resolved_locs: - resolved = rec(current, parts) - else: - resolved = resolved_locs[id(current)] - local_names[parts] = f"var_{len(local_leaves)}" - local_leaves[f"var_{len(local_leaves)}"] = resolved + if len(parts) == 1 and parts[0] in keyword.kwlist: + continue + local_names.setdefault(parts, f"var_{len(local_names)}") + + local_leaves = DynamicLocals(local_names, root, resolved_locs, rec) replaced = pat.sub(replace, ref.value) @@ -439,7 +464,7 @@ def rec(obj, loc: Tuple[Union[str, int]] = ()): while resolved is None: try: resolved = resolve_reference(obj) - except KeyError: + except (KeyError, NameError): raise MissingReference(obj) else: resolved = obj diff --git a/confit/utils/eval.py b/confit/utils/eval.py index ab5594a..67daf03 100644 --- a/confit/utils/eval.py +++ b/confit/utils/eval.py @@ -50,6 +50,7 @@ class Transformer(ast.NodeTransformer): "In", "NotIn", "Starred", + "IfExp", } def generic_visit(self, node): diff --git a/tests/test_config_instance.py b/tests/test_config_instance.py index f7c6e0f..90dc6f0 100644 --- a/tests/test_config_instance.py +++ b/tests/test_config_instance.py @@ -657,3 +657,31 @@ def test_yaml_str_dump(): pollution: false """ ) + + +def test_if_else(): + config = Config.from_yaml_str( + """\ +params: + a: 1 + cond: true + c: ${params.a if params.cond else params.b} +""" + ).resolve() + assert config["params"]["c"] == 1 + + +def test_if_else_complex(): + config = Config.from_yaml_str( + """\ +model: + "@factory": submodel + value: 12 + +params: + a: 1 + cond: true + c: ${model:value if model:hidden_value == 5 else params.b} +""" + ).resolve(registry=registry) + assert config["params"]["c"] == 12