Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better interpolation #25

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

- Support interpolated seed in the config file (as a reminder, the seed is treated specifically by confit to initialize random generators **before** any object is resolved)
- Support if/else expressions in interpolation, and only resolve the relevant branch

## v0.7.2 (2024-11-23)

- Seed the program *BEFORE* the config file is resolved and components have been instantiated, to ensure reproducibility.
Expand Down
6 changes: 5 additions & 1 deletion confit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def command(ctx: Context, config: Optional[List[Path]] = None):
default_seed = model_fields.get("seed")
if default_seed is not None:
default_seed = default_seed.get_default()
seed = config.get(name, {}).get("seed", default_seed)
seed = Config.resolve(
config.get(name, {}).get("seed", default_seed),
registry=registry,
root=config,
)
if seed is not None:
set_seed(seed)
resolved_config = Config(config[name]).resolve(
Expand Down
47 changes: 36 additions & 11 deletions confit/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections.abc
import keyword
import re
from collections import UserDict
from configparser import ConfigParser
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -32,6 +34,30 @@
PATH = rf"{UNQUOTED_ID}(?:[.]{PATH_PART})*"


class DynamicLocals(UserDict):
def __init__(self, mapping, root, resolved_locs, get):
super().__init__()
self.name_to_parts = {v: k for k, v in mapping.items()}
self.root = root
self.resolved_locs = resolved_locs
self.get = get

def __getitem__(self, item):
if not isinstance(item, str):
raise KeyError(item)
if item.startswith("__"):
raise KeyError(item)
current = self.root
parts = self.name_to_parts[item]
for part in parts:
current = current[part]
if id(current) not in self.resolved_locs:
resolved = self.get(current, parts)
else:
resolved = self.resolved_locs[id(current)]
return resolved


class Config(dict):
"""
The configuration system consists of a supercharged dict, the `Config` class,
Expand Down Expand Up @@ -341,28 +367,27 @@ def replace(match: re.Match):

path = match.group()
parts = split_path(path.rstrip(":"))

# Check if part is any special python keyword
if len(parts) == 1 and parts[0] in keyword.kwlist:
return match.group()
try:
return local_names[parts] + ("." if path.endswith(":") else "")
except KeyError:
raise KeyError(path)

local_leaves = {}
local_names = {}
for match in pat.finditer(ref.value):
start = match.start()
if start > 0 and ref.value[start - 1] == ":":
continue
path = match.group()
parts = split_path(path.rstrip(":"))
current = root
for part in parts:
current = current[part]
if id(current) not in resolved_locs:
resolved = rec(current, parts)
else:
resolved = resolved_locs[id(current)]
local_names[parts] = f"var_{len(local_leaves)}"
local_leaves[f"var_{len(local_leaves)}"] = resolved
if len(parts) == 1 and parts[0] in keyword.kwlist:
continue
local_names.setdefault(parts, f"var_{len(local_names)}")

local_leaves = DynamicLocals(local_names, root, resolved_locs, rec)

replaced = pat.sub(replace, ref.value)

Expand Down Expand Up @@ -439,7 +464,7 @@ def rec(obj, loc: Tuple[Union[str, int]] = ()):
while resolved is None:
try:
resolved = resolve_reference(obj)
except KeyError:
except (KeyError, NameError):
raise MissingReference(obj)
else:
resolved = obj
Expand Down
1 change: 1 addition & 0 deletions confit/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Transformer(ast.NodeTransformer):
"In",
"NotIn",
"Starred",
"IfExp",
}

def generic_visit(self, node):
Expand Down
28 changes: 28 additions & 0 deletions tests/test_config_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,31 @@ def test_yaml_str_dump():
pollution: false
"""
)


def test_if_else():
config = Config.from_yaml_str(
"""\
params:
a: 1
cond: true
c: ${params.a if params.cond else params.b}
"""
).resolve()
assert config["params"]["c"] == 1


def test_if_else_complex():
config = Config.from_yaml_str(
"""\
model:
"@factory": submodel
value: 12

params:
a: 1
cond: true
c: ${model:value if model:hidden_value == 5 else params.b}
"""
).resolve(registry=registry)
assert config["params"]["c"] == 12
Loading