Skip to content

Commit

Permalink
Merge pull request #38 from itrujnara/dev
Browse files Browse the repository at this point in the history
Add snapshot testing
  • Loading branch information
mathysgrapotte authored Jan 17, 2025
2 parents 60b0f2d + 3dae2b0 commit 8cfab78
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"safetensors>=0.4.5",
"scikit-learn>=1.5.0",
"scipy==1.14.1",
"syrupy>=4.8.0",
"torch>=2.2.2",
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'"
]
Expand Down
2 changes: 1 addition & 1 deletion src/stimulus/utils/yaml_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class YamlSplit(BaseModel):
split_input_columns: List[str]



class YamlConfigDict(BaseModel):
global_params: YamlGlobalParams
columns: List[YamlColumns]
Expand Down Expand Up @@ -321,7 +322,6 @@ def fix_params(input_dict):

dict_data = fix_params(dict_data)

# Write to file
with open(f"{directory_path}/{base_name}_{i}.yaml", "w") as f:
yaml.dump(
dict_data,
Expand Down
8 changes: 8 additions & 0 deletions tests/cli/__snapshots__/test_split_yaml.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# serializer version: 1
# name: test_split_yaml[correct_yaml_path-None]
list([
'455bac9343934e1ff40130ee94d5aa29',
'5a8a9dd96d15932d28254bde3949d7ea',
'a66d7aa1817e90ecdc81f02591f50289',
])
# ---
13 changes: 11 additions & 2 deletions tests/cli/test_split_yaml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import os
import tempfile

import pytest
Expand Down Expand Up @@ -27,12 +29,19 @@ def wrong_yaml_path() -> str:

# Tests
@pytest.mark.parametrize("yaml_type, error", test_cases)
def test_split_yaml(request: pytest.FixtureRequest, yaml_type: str, error: Exception | None) -> None:
def test_split_yaml(request: pytest.FixtureRequest, snapshot, yaml_type: str, error: Exception | None) -> None:
"""Tests the CLI command with correct and wrong YAML files."""
yaml_path = request.getfixturevalue(yaml_type)
tmpdir = tempfile.gettempdir()
if error:
with pytest.raises(error):
main(yaml_path, tmpdir)
else:
assert main(yaml_path, tmpdir) is None
assert main(yaml_path, tmpdir) is None # this is to assert that the function does not raise any exceptions
files = os.listdir(tmpdir)
test_out = [f for f in files if f.startswith("test_")]
hashes = []
for f in test_out:
with open(os.path.join(tmpdir, f)) as file:
hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324
assert hashes == snapshot
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Configuration for the pytest test suite."""

pytest_plugins = ("syrupy",)
3 changes: 0 additions & 3 deletions tests/data/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def test_dataset_manager_get_transform_logic(dump_single_split_config_to_disk):
assert transform_logic["transformation_name"] == "noise"
assert len(transform_logic["transformations"]) == 2


# Test EncodeManager
def test_encode_manager_init():
encoder_loader = experiments.EncoderLoader()
Expand Down Expand Up @@ -166,7 +165,6 @@ def test_split_manager_apply_split(split_loader):
assert len(split_indices[1]) == 15
assert len(split_indices[2]) == 15


# Test DatasetHandler


Expand All @@ -185,7 +183,6 @@ def test_dataset_handler_init(
assert isinstance(handler.transform_manager, TransformManager)
assert isinstance(handler.split_manager, SplitManager)


def test_dataset_hanlder_apply_split(
dump_single_split_config_to_disk, titanic_csv_path, encoder_loader, transform_loader, split_loader
):
Expand Down
1 change: 1 addition & 0 deletions tests/data/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def TextOneHotEncoder_name_and_params():
return "TextOneHotEncoder", {"alphabet": "acgt"}



def test_get_encoder(TextOneHotEncoder_name_and_params):
"""Test the get_encoder method of the AbstractExperiment class.
Expand Down

0 comments on commit 8cfab78

Please sign in to comment.