Skip to content

Commit

Permalink
feat: support if/else interpolation and falsy branch resolution skipping
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Dec 12, 2024
1 parent 04a1c6c commit 8c9ac3b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased

- Support interpolated seed in the config file (as a reminder, the seed is treated specifically by confit to initialize random generators **before** any object is resolved)
- Support if/else expressions in interpolation, and only resolve the relevant branch

## v0.7.2 (2024-11-23)

Expand Down
47 changes: 36 additions & 11 deletions confit/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections.abc
import keyword
import re
from collections import UserDict
from configparser import ConfigParser
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -32,6 +34,30 @@
PATH = rf"{UNQUOTED_ID}(?:[.]{PATH_PART})*"


class DynamicLocals(UserDict):
def __init__(self, mapping, root, resolved_locs, get):
super().__init__()
self.name_to_parts = {v: k for k, v in mapping.items()}
self.root = root
self.resolved_locs = resolved_locs
self.get = get

def __getitem__(self, item):
if not isinstance(item, str):
raise KeyError(item)
if item.startswith("__"):
raise KeyError(item)
current = self.root
parts = self.name_to_parts[item]
for part in parts:
current = current[part]
if id(current) not in self.resolved_locs:
resolved = self.get(current, parts)
else:
resolved = self.resolved_locs[id(current)]
return resolved


class Config(dict):
"""
The configuration system consists of a supercharged dict, the `Config` class,
Expand Down Expand Up @@ -341,28 +367,27 @@ def replace(match: re.Match):

path = match.group()
parts = split_path(path.rstrip(":"))

# Check if part is any special python keyword
if len(parts) == 1 and parts[0] in keyword.kwlist:
return match.group()
try:
return local_names[parts] + ("." if path.endswith(":") else "")
except KeyError:
raise KeyError(path)

local_leaves = {}
local_names = {}
for match in pat.finditer(ref.value):
start = match.start()
if start > 0 and ref.value[start - 1] == ":":
continue
path = match.group()
parts = split_path(path.rstrip(":"))
current = root
for part in parts:
current = current[part]
if id(current) not in resolved_locs:
resolved = rec(current, parts)
else:
resolved = resolved_locs[id(current)]
local_names[parts] = f"var_{len(local_leaves)}"
local_leaves[f"var_{len(local_leaves)}"] = resolved
if len(parts) == 1 and parts[0] in keyword.kwlist:
continue
local_names.setdefault(parts, f"var_{len(local_names)}")

local_leaves = DynamicLocals(local_names, root, resolved_locs, rec)

replaced = pat.sub(replace, ref.value)

Expand Down Expand Up @@ -439,7 +464,7 @@ def rec(obj, loc: Tuple[Union[str, int]] = ()):
while resolved is None:
try:
resolved = resolve_reference(obj)
except KeyError:
except (KeyError, NameError):
raise MissingReference(obj)
else:
resolved = obj
Expand Down
1 change: 1 addition & 0 deletions confit/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Transformer(ast.NodeTransformer):
"In",
"NotIn",
"Starred",
"IfExp",
}

def generic_visit(self, node):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_config_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,31 @@ def test_yaml_str_dump():
pollution: false
"""
)


def test_if_else():
config = Config.from_yaml_str(
"""\
params:
a: 1
cond: true
c: ${params.a if params.cond else params.b}
"""
).resolve()
assert config["params"]["c"] == 1


def test_if_else_complex():
config = Config.from_yaml_str(
"""\
model:
"@factory": submodel
value: 12
params:
a: 1
cond: true
c: ${model:value if model:hidden_value == 5 else params.b}
"""
).resolve(registry=registry)
assert config["params"]["c"] == 12

0 comments on commit 8c9ac3b

Please sign in to comment.