From 342be760dee9242608a5533f3d871bacb73be56d Mon Sep 17 00:00:00 2001 From: Suzanne Jin Date: Fri, 8 Nov 2024 14:54:51 +0100 Subject: [PATCH 1/3] add some test in test_csv_loader.py --- tests/test_csv_loader.py | 125 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 tests/test_csv_loader.py diff --git a/tests/test_csv_loader.py b/tests/test_csv_loader.py new file mode 100644 index 00000000..e7a469a2 --- /dev/null +++ b/tests/test_csv_loader.py @@ -0,0 +1,125 @@ +import json +import os +from typing import Any + +import numpy as np +import numpy.testing as npt +import pytest + +from src.stimulus.data.csv import CsvLoader +from src.stimulus.data.experiments import DnaToFloatExperiment, ProtDnaToFloatExperiment + + +class DataCsvLoader: + """It stores the CsvLoader objects initialized on a given csv data and the expected values. + + One can use this class to create the data fixtures. + """ + + def __init__(self, filename: str, experiment: Any): + self.experiment = experiment() + self.csv_path = os.path.abspath(filename) + self.csv_loader = CsvLoader(self.experiment, self.csv_path) + self.data_length = None + + +@pytest.fixture +def dna_test_data(): + """This stores the basic dna test csv""" + data = DataCsvLoader("tests/test_data/dna_experiment/test.csv", DnaToFloatExperiment) + data.data_length = 2 + return data + +@pytest.fixture +def dna_test_data_with_split(): + """This stores the basic dna test csv with split""" + data = DataCsvLoader("tests/test_data/dna_experiment/test_with_split.csv", DnaToFloatExperiment) + data.data_length = 48 + data.shape_splits = {0: 16, 1: 16, 2: 16} + return data + +@pytest.fixture +def prot_dna_test_data(): + """This stores the basic prot-dna test csv""" + data = DataCsvLoader("tests/test_data/prot_dna_experiment/test.csv", ProtDnaToFloatExperiment) + data.data_length = 2 + return data + +@pytest.fixture +def prot_dna_test_data_with_split(): + """This stores the basic prot-dna test csv with split""" + data = DataCsvLoader("tests/test_data/prot_dna_experiment/test_with_split.csv", ProtDnaToFloatExperiment) + data.data_length = 3 + data.shape_splits = {0: 1, 1: 1, 2: 1} + return data + +@pytest.mark.parametrize("fixture_name", [ + ("dna_test_data"), + ("prot_dna_test_data") + ]) +def test_data_length(request, fixture_name): + """Verify data is loaded with correct length""" + data = request.getfixturevalue(fixture_name) + assert len(data.csv_loader) == data.data_length + +@pytest.mark.parametrize("fixture_name", [ + ("dna_test_data"), + ("prot_dna_test_data") + ]) +def test_parse_csv_to_input_label_meta(request, fixture_name): + """Test parsing of CSV to input, label, and meta.""" + data = request.getfixturevalue(fixture_name) + assert isinstance(data.csv_loader.input, dict) + assert isinstance(data.csv_loader.label, dict) + assert isinstance(data.csv_loader.meta, dict) + +@pytest.mark.parametrize("fixture_name", [ + ("dna_test_data"), + ("prot_dna_test_data") + ]) +def test_get_all_items(request, fixture_name): + """Test getting all items.""" + data = request.getfixturevalue(fixture_name) + input_data, label_data, meta_data = data.get_all_items() + assert isinstance(input_data, dict) + assert isinstance(label_data, dict) + assert isinstance(meta_data, dict) + +@pytest.mark.parametrize("fixture_name,slice,expected_length", [ + ("dna_test_data", 0, 1), + ("dna_test_data", slice(0,2), 2), + ("prot_dna_test_data", 0, 1), + ("prot_dna_test_data", slice(0,2), 2) + ]) +def test_get_encoded_item(request, fixture_name, slice, expected_length): + """Check that one can get the items properly through slicing.""" + # get encoded item + data = request.getfixturevalue(fixture_name) + encoded_items = data.csv_loader[slice] + + # it should have 3 dictionaries for input, label, and meta + assert len(encoded_items) == 3 + for i in range(3): + assert isinstance(encoded_items[i], dict) + + # for each dictionary, check that the sliced items are encoded as numpy arrays, and that match the expected length + for item in encoded_items[i].values(): + assert isinstance(item, np.ndarray) + if (expected_length > 1): # If the expected length is 0, this will fail as we are trying to find the length of an object size 0. + assert len(item) == expected_length + +@pytest.mark.parametrize("fixture_name", [ + ("dna_test_data_with_split"), + ("prot_dna_test_data_with_split") + ]) +def test_load_with_split(request, fixture_name): + """Test loading with split.""" + data = request.getfixturevalue(fixture_name) + assert len(data.csv_loader) == data.data_length + + for i in [0, 1, 2]: + data_i = CsvLoader(data.experiment, data.csv_path, split=i) + assert len(data_i) == data.shape_splits[i] + + # with self.assertRaises(ValueError): + # CsvLoader(self.experiment, self.csv_path_split, split=3) From 6b9aeca338cf91e0a48e8c628eb22aec13e49c15 Mon Sep 17 00:00:00 2001 From: suzannejin Date: Fri, 8 Nov 2024 15:51:54 +0100 Subject: [PATCH 2/3] finished test_csv_loader.py --- tests/test_csv_loader.py | 105 +++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 26 deletions(-) diff --git a/tests/test_csv_loader.py b/tests/test_csv_loader.py index e7a469a2..2c53100e 100644 --- a/tests/test_csv_loader.py +++ b/tests/test_csv_loader.py @@ -1,9 +1,7 @@ -import json import os from typing import Any import numpy as np -import numpy.testing as npt import pytest from src.stimulus.data.csv import CsvLoader @@ -11,16 +9,31 @@ class DataCsvLoader: - """It stores the CsvLoader objects initialized on a given csv data and the expected values. + """Helper class to store CsvLoader objects and expected values for testing. - One can use this class to create the data fixtures. + This class initializes CsvLoader objects with given csv data and stores expected + values for testing purposes. + + Attributes: + experiment: An experiment instance to process the data. + csv_path (str): Absolute path to the CSV file. + csv_loader (CsvLoader): Initialized CsvLoader object. + data_length (int, optional): Expected length of the data. + shape_splits (dict, optional): Expected split indices and their lengths. """ def __init__(self, filename: str, experiment: Any): + """Initialize DataCsvLoader with a CSV file and experiment type. + + Args: + filename (str): Path to the CSV file. + experiment (Any): Experiment class to be instantiated. + """ self.experiment = experiment() self.csv_path = os.path.abspath(filename) self.csv_loader = CsvLoader(self.experiment, self.csv_path) self.data_length = None + self.shape_splits = None @pytest.fixture @@ -55,10 +68,17 @@ def prot_dna_test_data_with_split(): @pytest.mark.parametrize("fixture_name", [ ("dna_test_data"), - ("prot_dna_test_data") + ("dna_test_data_with_split"), + ("prot_dna_test_data"), + ("prot_dna_test_data_with_split") ]) -def test_data_length(request, fixture_name): - """Verify data is loaded with correct length""" +def test_data_length(request, fixture_name: str): + """Verify data is loaded with correct length. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + """ data = request.getfixturevalue(fixture_name) assert len(data.csv_loader) == data.data_length @@ -66,8 +86,18 @@ def test_data_length(request, fixture_name): ("dna_test_data"), ("prot_dna_test_data") ]) -def test_parse_csv_to_input_label_meta(request, fixture_name): - """Test parsing of CSV to input, label, and meta.""" +def test_parse_csv_to_input_label_meta(request, fixture_name:str): + """Test parsing of CSV to input, label, and meta. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + + Verifies: + - Input data is a dictionary + - Label data is a dictionary + - Meta data is a dictionary + """ data = request.getfixturevalue(fixture_name) assert isinstance(data.csv_loader.input, dict) assert isinstance(data.csv_loader.label, dict) @@ -77,10 +107,18 @@ def test_parse_csv_to_input_label_meta(request, fixture_name): ("dna_test_data"), ("prot_dna_test_data") ]) -def test_get_all_items(request, fixture_name): - """Test getting all items.""" +def test_get_all_items(request, fixture_name: str): + """Test retrieval of all items from the CSV loader. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + + Verifies: + - All returned data (input, label, meta) are dictionaries + """ data = request.getfixturevalue(fixture_name) - input_data, label_data, meta_data = data.get_all_items() + input_data, label_data, meta_data = data.csv_loader.get_all_items() assert isinstance(input_data, dict) assert isinstance(label_data, dict) assert isinstance(meta_data, dict) @@ -91,35 +129,50 @@ def test_get_all_items(request, fixture_name): ("prot_dna_test_data", 0, 1), ("prot_dna_test_data", slice(0,2), 2) ]) -def test_get_encoded_item(request, fixture_name, slice, expected_length): - """Check that one can get the items properly through slicing.""" - # get encoded item +def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_length: int): + """Test retrieval of encoded items through slicing. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + slice (int or slice): Index or slice object for data access. + expected_length (int): Expected length of the retrieved data. + + Verifies: + - Returns 3 dictionaries (input, label, meta) + - All items are encoded as numpy arrays + - Arrays have the expected length + """ data = request.getfixturevalue(fixture_name) encoded_items = data.csv_loader[slice] - # it should have 3 dictionaries for input, label, and meta assert len(encoded_items) == 3 for i in range(3): assert isinstance(encoded_items[i], dict) - - # for each dictionary, check that the sliced items are encoded as numpy arrays, and that match the expected length for item in encoded_items[i].values(): assert isinstance(item, np.ndarray) - if (expected_length > 1): # If the expected length is 0, this will fail as we are trying to find the length of an object size 0. + if (expected_length > 1): assert len(item) == expected_length @pytest.mark.parametrize("fixture_name", [ ("dna_test_data_with_split"), ("prot_dna_test_data_with_split") ]) -def test_load_with_split(request, fixture_name): - """Test loading with split.""" - data = request.getfixturevalue(fixture_name) - assert len(data.csv_loader) == data.data_length +def test_splitting(request, fixture_name): + """Test data splitting functionality. + + Args: + request: Pytest fixture request object. + fixture_name (str): Name of the fixture to test. + Verifies: + - Data can be loaded with different split indices + - Splits have correct lengths + - Invalid split index raises ValueError + """ + data = request.getfixturevalue(fixture_name) for i in [0, 1, 2]: data_i = CsvLoader(data.experiment, data.csv_path, split=i) assert len(data_i) == data.shape_splits[i] - - # with self.assertRaises(ValueError): - # CsvLoader(self.experiment, self.csv_path_split, split=3) + with pytest.raises(ValueError): + CsvLoader(data.experiment, data.csv_path, split=3) From e01cd48115bb5f30f55f355ed24f557fca1dab03 Mon Sep 17 00:00:00 2001 From: Suzanne Jin Date: Fri, 8 Nov 2024 15:55:01 +0100 Subject: [PATCH 3/3] make format --- tests/test_csv_loader.py | 59 +++++++++++++++++++++++++----------- tests/test_csv_processing.py | 56 +++++++++++++++++++++++----------- 2 files changed, 79 insertions(+), 36 deletions(-) diff --git a/tests/test_csv_loader.py b/tests/test_csv_loader.py index 2c53100e..6d466309 100644 --- a/tests/test_csv_loader.py +++ b/tests/test_csv_loader.py @@ -43,6 +43,7 @@ def dna_test_data(): data.data_length = 2 return data + @pytest.fixture def dna_test_data_with_split(): """This stores the basic dna test csv with split""" @@ -51,6 +52,7 @@ def dna_test_data_with_split(): data.shape_splits = {0: 16, 1: 16, 2: 16} return data + @pytest.fixture def prot_dna_test_data(): """This stores the basic prot-dna test csv""" @@ -58,6 +60,7 @@ def prot_dna_test_data(): data.data_length = 2 return data + @pytest.fixture def prot_dna_test_data_with_split(): """This stores the basic prot-dna test csv with split""" @@ -66,12 +69,16 @@ def prot_dna_test_data_with_split(): data.shape_splits = {0: 1, 1: 1, 2: 1} return data -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data"), ("dna_test_data_with_split"), ("prot_dna_test_data"), - ("prot_dna_test_data_with_split") - ]) + ("prot_dna_test_data_with_split"), + ], +) def test_data_length(request, fixture_name: str): """Verify data is loaded with correct length. @@ -82,11 +89,15 @@ def test_data_length(request, fixture_name: str): data = request.getfixturevalue(fixture_name) assert len(data.csv_loader) == data.data_length -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data"), - ("prot_dna_test_data") - ]) -def test_parse_csv_to_input_label_meta(request, fixture_name:str): + ("prot_dna_test_data"), + ], +) +def test_parse_csv_to_input_label_meta(request, fixture_name: str): """Test parsing of CSV to input, label, and meta. Args: @@ -103,10 +114,14 @@ def test_parse_csv_to_input_label_meta(request, fixture_name:str): assert isinstance(data.csv_loader.label, dict) assert isinstance(data.csv_loader.meta, dict) -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data"), - ("prot_dna_test_data") - ]) + ("prot_dna_test_data"), + ], +) def test_get_all_items(request, fixture_name: str): """Test retrieval of all items from the CSV loader. @@ -123,12 +138,16 @@ def test_get_all_items(request, fixture_name: str): assert isinstance(label_data, dict) assert isinstance(meta_data, dict) -@pytest.mark.parametrize("fixture_name,slice,expected_length", [ + +@pytest.mark.parametrize( + "fixture_name,slice,expected_length", + [ ("dna_test_data", 0, 1), - ("dna_test_data", slice(0,2), 2), + ("dna_test_data", slice(0, 2), 2), ("prot_dna_test_data", 0, 1), - ("prot_dna_test_data", slice(0,2), 2) - ]) + ("prot_dna_test_data", slice(0, 2), 2), + ], +) def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_length: int): """Test retrieval of encoded items through slicing. @@ -151,13 +170,17 @@ def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_lengt assert isinstance(encoded_items[i], dict) for item in encoded_items[i].values(): assert isinstance(item, np.ndarray) - if (expected_length > 1): + if expected_length > 1: assert len(item) == expected_length -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data_with_split"), - ("prot_dna_test_data_with_split") - ]) + ("prot_dna_test_data_with_split"), + ], +) def test_splitting(request, fixture_name): """Test data splitting functionality. diff --git a/tests/test_csv_processing.py b/tests/test_csv_processing.py index cd2e60d4..23f5fc3e 100644 --- a/tests/test_csv_processing.py +++ b/tests/test_csv_processing.py @@ -35,6 +35,7 @@ def __init__(self, filename: str, experiment: Any): self.expected_split = None self.expected_transformed_values = None + @pytest.fixture def dna_test_data(): """This stores the basic dna test csv""" @@ -49,6 +50,7 @@ def dna_test_data(): } return data + @pytest.fixture def dna_test_data_long(): """This stores the long dna test csv""" @@ -56,19 +58,24 @@ def dna_test_data_long(): data.data_length = 1000 return data + @pytest.fixture def dna_test_data_long_shuffled(): """This stores the shuffled long dna test csv""" - data = DataCsvProcessing("tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment) + data = DataCsvProcessing( + "tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment + ) data.data_length = 1000 return data + @pytest.fixture def dna_config(): """This is the config file for the dna experiment""" with open("tests/test_data/dna_experiment/test_config.json") as f: return json.load(f) - + + @pytest.fixture def prot_dna_test_data(): """This stores the basic prot-dna test csv""" @@ -84,21 +91,26 @@ def prot_dna_test_data(): } return data + @pytest.fixture def prot_dna_config(): """This is the config file for the prot experiment""" with open("tests/test_data/prot_dna_experiment/test_config.json") as f: return json.load(f) -@pytest.mark.parametrize("fixture_name", [ + +@pytest.mark.parametrize( + "fixture_name", + [ ("dna_test_data"), ("dna_test_data_long"), ("dna_test_data_long_shuffled"), - ("prot_dna_test_data") - ]) + ("prot_dna_test_data"), + ], +) def test_data_length(request, fixture_name): """Test that data is loaded with the correct length. - + Args: request: Pytest fixture request object. fixture_name (str): Name of the fixture to test. @@ -108,13 +120,17 @@ def test_data_length(request, fixture_name): data = request.getfixturevalue(fixture_name) assert len(data.csv_processing.data) == data.data_length -@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [ + +@pytest.mark.parametrize( + "fixture_data_name,fixture_config_name", + [ ("dna_test_data", "dna_config"), - ("prot_dna_test_data", "prot_dna_config") - ]) + ("prot_dna_test_data", "prot_dna_config"), + ], +) def test_add_split(request, fixture_data_name, fixture_config_name): """Test that the add_split function properly adds the split column. - + Args: request: Pytest fixture request object. fixture_data_name (str): Name of the data fixture to test. @@ -128,13 +144,17 @@ def test_add_split(request, fixture_data_name, fixture_config_name): data.csv_processing.add_split(config["split"]) assert data.csv_processing.data["split:split:int"].to_list() == data.expected_split -@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [ + +@pytest.mark.parametrize( + "fixture_data_name,fixture_config_name", + [ ("dna_test_data", "dna_config"), - ("prot_dna_test_data", "prot_dna_config") - ]) + ("prot_dna_test_data", "prot_dna_config"), + ], +) def test_transform_data(request, fixture_data_name, fixture_config_name): """Test that transformation functionalities properly transform the data. - + Args: request: Pytest fixture request object. fixture_data_name (str): Name of the data fixture to test. @@ -144,7 +164,7 @@ def test_transform_data(request, fixture_data_name, fixture_config_name): """ data = request.getfixturevalue(fixture_data_name) config = request.getfixturevalue(fixture_config_name) - + data.csv_processing.add_split(config["split"]) data.csv_processing.transform(config["transform"]) @@ -153,13 +173,14 @@ def test_transform_data(request, fixture_data_name, fixture_config_name): observed_values = [round(v, 6) if isinstance(v, float) else v for v in observed_values] assert observed_values == expected_values + def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled): """Test that shuffling of labels works correctly. - + This test verifies that when labels are shuffled with a fixed seed, they match the expected shuffled values from a pre-computed dataset. Currently only tests the long DNA test data. - + Args: dna_test_data_long: Fixture containing the original unshuffled DNA test data. dna_test_data_long_shuffled: Fixture containing the expected shuffled DNA test data. @@ -169,4 +190,3 @@ def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled): dna_test_data_long.csv_processing.data["hola:label:float"], dna_test_data_long_shuffled.csv_processing.data["hola:label:float"], ) -