diff --git a/src/alignrl/config.py b/src/alignrl/config.py
index 2fd1ff6..f60bb9f 100644
--- a/src/alignrl/config.py
+++ b/src/alignrl/config.py
@@ -3,10 +3,14 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import yaml
 from pydantic import BaseModel, Field
 
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
 
 class BaseTrainConfig(BaseModel):
     """Shared training configuration."""
@@ -37,10 +41,14 @@ class BaseTrainConfig(BaseModel):
     load_in_4bit: bool = True
 
     @classmethod
-    def from_yaml(cls, path: Path) -> BaseTrainConfig:
-        with open(path) as f:
+    def from_yaml(cls, path: Path) -> Self:
+        """Load a config from a YAML file; an empty file yields all defaults."""
+        with open(path, encoding="utf-8") as f:
             data = yaml.safe_load(f)
+        # safe_load returns None for an empty file or an empty document.
+        if data is None:
+            data = {}
         return cls(**data)
 
 
 # ChatML template used as fallback when the tokenizer doesn't have one set.
diff --git a/tests/test_config.py b/tests/test_config.py
index f165401..af17b95 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -32,6 +32,18 @@ def test_from_yaml_partial(self, tmp_path: Path) -> None:
         assert cfg.model_name == "partial-model"
         assert cfg.learning_rate == 2e-4  # default preserved
 
+    def test_from_yaml_empty_file(self, tmp_path: Path) -> None:
+        yaml_path = tmp_path / "empty.yaml"
+        yaml_path.write_text("")
+        cfg = BaseTrainConfig.from_yaml(yaml_path)
+        assert cfg.model_name == "Qwen/Qwen2.5-3B"  # all defaults
+
+    def test_from_yaml_empty_doc(self, tmp_path: Path) -> None:
+        yaml_path = tmp_path / "empty_doc.yaml"
+        yaml_path.write_text("---\n")
+        cfg = BaseTrainConfig.from_yaml(yaml_path)
+        assert cfg.model_name == "Qwen/Qwen2.5-3B"
+
     def test_output_dir_is_path(self) -> None:
         cfg = BaseTrainConfig(output_dir="./my-output")
         assert isinstance(cfg.output_dir, Path)