diff --git a/changelog.md b/changelog.md index d35278a..2834c12 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,9 @@ # Changelog +## Unreleased + +- Seed the program *BEFORE* the config file is resolved and components have been instantiated, to ensure reproducibility. + ## v0.7.1 (2024-11-21) - Force utf-8 encoding when writing a config file (ini or yaml) diff --git a/confit/cli.py b/confit/cli.py index de546fc..25ca718 100644 --- a/confit/cli.py +++ b/confit/cli.py @@ -130,15 +130,15 @@ def command(ctx: Context, config: Optional[List[Path]] = None): current = current.setdefault(part, Config()) current[parts[-1]] = v try: - resolved_config = Config(config[name]).resolve( - registry=registry, root=config - ) default_seed = model_fields.get("seed") if default_seed is not None: default_seed = default_seed.get_default() seed = config.get(name, {}).get("seed", default_seed) if seed is not None: set_seed(seed) + resolved_config = Config(config[name]).resolve( + registry=registry, root=config + ) if has_meta: config_meta = dict( config_path=config_path, diff --git a/tests/test_cli.py b/tests/test_cli.py index c7bb25f..4f16968 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import datetime +import random import pytest from typer.testing import CliRunner @@ -209,3 +210,31 @@ def test_fail_override(change_test_dir): assert result.exit_code == 1 # CLI detects bool param for other which is converted to 1 since we expect an int assert "does not match any existing section in config" in str(result.exception) + + +seed_app = Cli(pretty_exceptions_show_locals=False) + + +@registry.factory.register("randmodel") +class RandModel: + def __init__(self): + self.value = random.randint(0, 100000) + + +@seed_app.command(name="seed", registry=registry) +def print_seed(model: RandModel, seed: int): + print("Value:", model.value) + + +def test_seed(change_test_dir): + """Checks that the program running twice will generate the same random numbers""" + result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"]) + assert result.exit_code == 0, result.stdout + first_seed = int(result.stdout.split(":")[1].strip()) + + result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"]) + assert result.exit_code == 0, result.stdout + second_seed = int(result.stdout.split(":")[1].strip()) + print(first_seed, second_seed) + + assert first_seed == second_seed