Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seeded CLI #24

Merged
merged 2 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## v0.7.2 (2024-11-23)

- Seed the program *BEFORE* the config file is resolved and components have been instantiated, to ensure reproducibility.

## v0.7.1 (2024-11-21)

- Force utf-8 encoding when writing a config file (ini or yaml)
Expand Down
2 changes: 1 addition & 1 deletion confit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
VisibleDeprecationWarning,
)

__version__ = "0.7.1"
__version__ = "0.7.2"
6 changes: 3 additions & 3 deletions confit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def command(ctx: Context, config: Optional[List[Path]] = None):
current = current.setdefault(part, Config())
current[parts[-1]] = v
try:
resolved_config = Config(config[name]).resolve(
registry=registry, root=config
)
default_seed = model_fields.get("seed")
if default_seed is not None:
default_seed = default_seed.get_default()
seed = config.get(name, {}).get("seed", default_seed)
if seed is not None:
set_seed(seed)
resolved_config = Config(config[name]).resolve(
registry=registry, root=config
)
if has_meta:
config_meta = dict(
config_path=config_path,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import random

import pytest
from typer.testing import CliRunner
Expand Down Expand Up @@ -209,3 +210,31 @@ def test_fail_override(change_test_dir):
assert result.exit_code == 1
# CLI detects bool param for other which is converted to 1 since we expect an int
assert "does not match any existing section in config" in str(result.exception)


seed_app = Cli(pretty_exceptions_show_locals=False)


@registry.factory.register("randmodel")
class RandModel:
def __init__(self):
self.value = random.randint(0, 100000)


@seed_app.command(name="seed", registry=registry)
def print_seed(model: RandModel, seed: int):
print("Value:", model.value)


def test_seed(change_test_dir):
"""Checks that the program running twice will generate the same random numbers"""
result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"])
assert result.exit_code == 0, result.stdout
first_seed = int(result.stdout.split(":")[1].strip())

result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"])
assert result.exit_code == 0, result.stdout
second_seed = int(result.stdout.split(":")[1].strip())
print(first_seed, second_seed)

assert first_seed == second_seed
Loading