Skip to content

Commit

Permalink
fix: 🚨 seed the program *before* config fields are instantiated
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Nov 26, 2024
1 parent ed8116b commit a29e699
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleased

- Seed the program *BEFORE* the config file is resolved and components have been instantiated, to ensure reproducibility.

## v0.7.1 (2024-11-21)

- Force utf-8 encoding when writing a config file (ini or yaml)
Expand Down
6 changes: 3 additions & 3 deletions confit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def command(ctx: Context, config: Optional[List[Path]] = None):
current = current.setdefault(part, Config())
current[parts[-1]] = v
try:
resolved_config = Config(config[name]).resolve(
registry=registry, root=config
)
default_seed = model_fields.get("seed")
if default_seed is not None:
default_seed = default_seed.get_default()
seed = config.get(name, {}).get("seed", default_seed)
if seed is not None:
set_seed(seed)
resolved_config = Config(config[name]).resolve(
registry=registry, root=config
)
if has_meta:
config_meta = dict(
config_path=config_path,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import random

import pytest
from typer.testing import CliRunner
Expand Down Expand Up @@ -209,3 +210,31 @@ def test_fail_override(change_test_dir):
assert result.exit_code == 1
# CLI detects bool param for other which is converted to 1 since we expect an int
assert "does not match any existing section in config" in str(result.exception)


seed_app = Cli(pretty_exceptions_show_locals=False)


@registry.factory.register("randmodel")
class RandModel:
def __init__(self):
self.value = random.randint(0, 100000)


@seed_app.command(name="seed", registry=registry)
def print_seed(model: RandModel, seed: int):
print("Value:", model.value)


def test_seed(change_test_dir):
"""Checks that the program running twice will generate the same random numbers"""
result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"])
assert result.exit_code == 0, result.stdout
first_seed = int(result.stdout.split(":")[1].strip())

result = runner.invoke(seed_app, ["--seed", "42", "--model.@factory", "randmodel"])
assert result.exit_code == 0, result.stdout
second_seed = int(result.stdout.split(":")[1].strip())
print(first_seed, second_seed)

assert first_seed == second_seed

0 comments on commit a29e699

Please sign in to comment.