diff --git a/tests/test_csv_loader.py b/tests/test_csv_loader.py new file mode 100644 index 00000000..6d466309 --- /dev/null +++ b/tests/test_csv_loader.py @@ -0,0 +1,201 @@ +import os +from typing import Any + +import numpy as np +import pytest + +from src.stimulus.data.csv import CsvLoader +from src.stimulus.data.experiments import DnaToFloatExperiment, ProtDnaToFloatExperiment + + +class DataCsvLoader: + """Helper class to store CsvLoader objects and expected values for testing. + + This class initializes CsvLoader objects with given csv data and stores expected + values for testing purposes. + + Attributes: + experiment: An experiment instance to process the data. + csv_path (str): Absolute path to the CSV file. + csv_loader (CsvLoader): Initialized CsvLoader object. + data_length (int, optional): Expected length of the data. + shape_splits (dict, optional): Expected split indices and their lengths. + """ + + def __init__(self, filename: str, experiment: Any): + """Initialize DataCsvLoader with a CSV file and experiment type. + + Args: + filename (str): Path to the CSV file. + experiment (Any): Experiment class to be instantiated. + """ + self.experiment = experiment() + self.csv_path = os.path.abspath(filename) + self.csv_loader = CsvLoader(self.experiment, self.csv_path) + self.data_length = None + self.shape_splits = None + + +@pytest.fixture +def dna_test_data(): + """This stores the basic dna test csv""" + data = DataCsvLoader("tests/test_data/dna_experiment/test.csv", DnaToFloatExperiment) + data.data_length = 2 + return data + + +@pytest.fixture +def dna_test_data_with_split(): + """This stores the basic dna test csv with split""" + data = DataCsvLoader("tests/test_data/dna_experiment/test_with_split.csv", DnaToFloatExperiment) + data.data_length = 48 + data.shape_splits = {0: 16, 1: 16, 2: 16} + return data + + +@pytest.fixture +def prot_dna_test_data(): + """This stores the basic prot-dna test csv""" + data = DataCsvLoader("tests/test_data/prot_dna_experiment/test.csv", ProtDnaToFloatExperiment) + data.data_length = 2 + return data + + +@pytest.fixture +def prot_dna_test_data_with_split(): + """This stores the basic prot-dna test csv with split""" + data = DataCsvLoader("tests/test_data/prot_dna_experiment/test_with_split.csv", ProtDnaToFloatExperiment) + data.data_length = 3 + data.shape_splits = {0: 1, 1: 1, 2: 1} + return data + + +@pytest.mark.parametrize( + "fixture_name", + [ + ("dna_test_data"), + ("dna_test_data_with_split"), + ("prot_dna_test_data"), + ("prot_dna_test_data_with_split"), + ], +) +def test_data_length(request, fixture_name: str): + """Verify data is loaded with correct length. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + """ + data = request.getfixturevalue(fixture_name) + assert len(data.csv_loader) == data.data_length + + +@pytest.mark.parametrize( + "fixture_name", + [ + ("dna_test_data"), + ("prot_dna_test_data"), + ], +) +def test_parse_csv_to_input_label_meta(request, fixture_name: str): + """Test parsing of CSV to input, label, and meta. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + + Verifies: + - Input data is a dictionary + - Label data is a dictionary + - Meta data is a dictionary + """ + data = request.getfixturevalue(fixture_name) + assert isinstance(data.csv_loader.input, dict) + assert isinstance(data.csv_loader.label, dict) + assert isinstance(data.csv_loader.meta, dict) + + +@pytest.mark.parametrize( + "fixture_name", + [ + ("dna_test_data"), + ("prot_dna_test_data"), + ], +) +def test_get_all_items(request, fixture_name: str): + """Test retrieval of all items from the CSV loader. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + + Verifies: + - All returned data (input, label, meta) are dictionaries + """ + data = request.getfixturevalue(fixture_name) + input_data, label_data, meta_data = data.csv_loader.get_all_items() + assert isinstance(input_data, dict) + assert isinstance(label_data, dict) + assert isinstance(meta_data, dict) + + +@pytest.mark.parametrize( + "fixture_name,slice,expected_length", + [ + ("dna_test_data", 0, 1), + ("dna_test_data", slice(0, 2), 2), + ("prot_dna_test_data", 0, 1), + ("prot_dna_test_data", slice(0, 2), 2), + ], +) +def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_length: int): + """Test retrieval of encoded items through slicing. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + slice (int or slice): Index or slice object for data access. + expected_length (int): Expected length of the retrieved data. + + Verifies: + - Returns 3 dictionaries (input, label, meta) + - All items are encoded as numpy arrays + - Arrays have the expected length + """ + data = request.getfixturevalue(fixture_name) + encoded_items = data.csv_loader[slice] + + assert len(encoded_items) == 3 + for i in range(3): + assert isinstance(encoded_items[i], dict) + for item in encoded_items[i].values(): + assert isinstance(item, np.ndarray) + if expected_length > 1: + assert len(item) == expected_length + + +@pytest.mark.parametrize( + "fixture_name", + [ + ("dna_test_data_with_split"), + ("prot_dna_test_data_with_split"), + ], +) +def test_splitting(request, fixture_name): + """Test data splitting functionality. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + + Verifies: + - Data can be loaded with different split indices + - Splits have correct lengths + - Invalid split index raises ValueError + """ + data = request.getfixturevalue(fixture_name) + for i in [0, 1, 2]: + data_i = CsvLoader(data.experiment, data.csv_path, split=i) + assert len(data_i) == data.shape_splits[i] + with pytest.raises(ValueError): + CsvLoader(data.experiment, data.csv_path, split=3) diff --git a/tests/test_csv_processing.py b/tests/test_csv_processing.py index cd2e60d4..23f5fc3e 100644 --- a/tests/test_csv_processing.py +++ b/tests/test_csv_processing.py @@ -35,6 +35,7 @@ def __init__(self, filename: str, experiment: Any): self.expected_split = None self.expected_transformed_values = None + @pytest.fixture def dna_test_data(): """This stores the basic dna test csv""" @@ -49,6 +50,7 @@ def dna_test_data(): } return data + @pytest.fixture def dna_test_data_long(): """This stores the long dna test csv""" @@ -56,19 +58,24 @@ def dna_test_data_long(): data.data_length = 1000 return data + @pytest.fixture def dna_test_data_long_shuffled(): """This stores the shuffled long dna test csv""" - data = DataCsvProcessing("tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment) + data = DataCsvProcessing( + "tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment + ) data.data_length = 1000 return data + @pytest.fixture def dna_config(): """This is the config file for the dna experiment""" with open("tests/test_data/dna_experiment/test_config.json") as f: return json.load(f) - + + @pytest.fixture def prot_dna_test_data(): """This stores the basic prot-dna test csv""" @@ -84,21 +91,26 @@ def prot_dna_test_data(): } return data + @pytest.fixture def prot_dna_config(): """This is the config file for the prot experiment""" with open("tests/test_data/prot_dna_experiment/test_config.json") as f: return json.load(f) -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data"), ("dna_test_data_long"), ("dna_test_data_long_shuffled"), - ("prot_dna_test_data") - ]) + ("prot_dna_test_data"), + ], +) def test_data_length(request, fixture_name): """Test that data is loaded with the correct length. - + Args: request: Pytest fixture request object. fixture_name (str): Name of the fixture to test. @@ -108,13 +120,17 @@ def test_data_length(request, fixture_name): data = request.getfixturevalue(fixture_name) assert len(data.csv_processing.data) == data.data_length -@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [ + +@pytest.mark.parametrize( + "fixture_data_name,fixture_config_name", + [ ("dna_test_data", "dna_config"), - ("prot_dna_test_data", "prot_dna_config") - ]) + ("prot_dna_test_data", "prot_dna_config"), + ], +) def test_add_split(request, fixture_data_name, fixture_config_name): """Test that the add_split function properly adds the split column. - + Args: request: Pytest fixture request object. fixture_data_name (str): Name of the data fixture to test. @@ -128,13 +144,17 @@ def test_add_split(request, fixture_data_name, fixture_config_name): data.csv_processing.add_split(config["split"]) assert data.csv_processing.data["split:split:int"].to_list() == data.expected_split -@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [ + +@pytest.mark.parametrize( + "fixture_data_name,fixture_config_name", + [ ("dna_test_data", "dna_config"), - ("prot_dna_test_data", "prot_dna_config") - ]) + ("prot_dna_test_data", "prot_dna_config"), + ], +) def test_transform_data(request, fixture_data_name, fixture_config_name): """Test that transformation functionalities properly transform the data. - + Args: request: Pytest fixture request object. fixture_data_name (str): Name of the data fixture to test. @@ -144,7 +164,7 @@ def test_transform_data(request, fixture_data_name, fixture_config_name): """ data = request.getfixturevalue(fixture_data_name) config = request.getfixturevalue(fixture_config_name) - + data.csv_processing.add_split(config["split"]) data.csv_processing.transform(config["transform"]) @@ -153,13 +173,14 @@ def test_transform_data(request, fixture_data_name, fixture_config_name): observed_values = [round(v, 6) if isinstance(v, float) else v for v in observed_values] assert observed_values == expected_values + def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled): """Test that shuffling of labels works correctly. - + This test verifies that when labels are shuffled with a fixed seed, they match the expected shuffled values from a pre-computed dataset. Currently only tests the long DNA test data. - + Args: dna_test_data_long: Fixture containing the original unshuffled DNA test data. dna_test_data_long_shuffled: Fixture containing the expected shuffled DNA test data. @@ -169,4 +190,3 @@ def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled): dna_test_data_long.csv_processing.data["hola:label:float"], dna_test_data_long_shuffled.csv_processing.data["hola:label:float"], ) -