Skip to content

Commit

Permalink
Merge pull request #15 from suzannejin/docstring-add-test-csv-loader
Browse files: browse the repository at this point in the history
add test csv loader
  • Loading branch information
mathysgrapotte authored Nov 8, 2024
2 parents d4af833 + e01cd48 commit 5910c12
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 18 deletions.
201 changes: 201 additions & 0 deletions tests/test_csv_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import os
from typing import Any

import numpy as np
import pytest

from src.stimulus.data.csv import CsvLoader
from src.stimulus.data.experiments import DnaToFloatExperiment, ProtDnaToFloatExperiment


class DataCsvLoader:
    """Bundle a CsvLoader with the values a test expects from it.

    Each instance builds a CsvLoader for one CSV file and carries the expected
    results that the tests compare against.

    Attributes:
        experiment: Instance obtained by calling the ``experiment`` argument.
        csv_path (str): Absolute form of the ``filename`` argument.
        csv_loader (CsvLoader): Loader built from ``experiment`` and ``csv_path``.
        data_length (int, optional): Expected length of the loaded data;
            ``None`` until a fixture sets it.
        shape_splits (dict, optional): Expected length per split index;
            ``None`` until a fixture sets it.
    """

    def __init__(self, filename: str, experiment: Any):
        """Create the experiment instance and the loader for ``filename``.

        Args:
            filename (str): Path to the CSV file; stored as an absolute path.
            experiment (Any): Experiment class, called with no arguments to
                create the experiment instance.
        """
        experiment_instance = experiment()
        absolute_csv_path = os.path.abspath(filename)

        self.experiment = experiment_instance
        self.csv_path = absolute_csv_path
        self.csv_loader = CsvLoader(experiment_instance, absolute_csv_path)
        # Expected values start unset; each fixture fills in what it needs.
        self.data_length = self.shape_splits = None


@pytest.fixture
def dna_test_data():
    """Provide the basic DNA test CSV wrapped with its expected length."""
    loader_data = DataCsvLoader(
        "tests/test_data/dna_experiment/test.csv",
        DnaToFloatExperiment,
    )
    loader_data.data_length = 2
    return loader_data


@pytest.fixture
def dna_test_data_with_split():
    """Provide the DNA test CSV that carries a split column.

    The expected total length and the expected length of each split index
    are attached to the returned helper.
    """
    loader_data = DataCsvLoader(
        "tests/test_data/dna_experiment/test_with_split.csv",
        DnaToFloatExperiment,
    )
    loader_data.shape_splits = {0: 16, 1: 16, 2: 16}
    loader_data.data_length = 48
    return loader_data


@pytest.fixture
def prot_dna_test_data():
    """Provide the basic protein-DNA test CSV wrapped with its expected length."""
    loader_data = DataCsvLoader(
        "tests/test_data/prot_dna_experiment/test.csv",
        ProtDnaToFloatExperiment,
    )
    loader_data.data_length = 2
    return loader_data


@pytest.fixture
def prot_dna_test_data_with_split():
    """Provide the protein-DNA test CSV that carries a split column.

    The expected total length and the expected length of each split index
    are attached to the returned helper.
    """
    loader_data = DataCsvLoader(
        "tests/test_data/prot_dna_experiment/test_with_split.csv",
        ProtDnaToFloatExperiment,
    )
    loader_data.shape_splits = {0: 1, 1: 1, 2: 1}
    loader_data.data_length = 3
    return loader_data


@pytest.mark.parametrize(
    "fixture_name",
    [
        "dna_test_data",
        "dna_test_data_with_split",
        "prot_dna_test_data",
        "prot_dna_test_data_with_split",
    ],
)
def test_data_length(request, fixture_name: str):
    """Check that the loader reports the expected number of entries.

    Args:
        request: Pytest fixture request object.
        fixture_name (str): Name of the fixture to resolve and test.
    """
    loader_data = request.getfixturevalue(fixture_name)
    observed_length = len(loader_data.csv_loader)
    assert observed_length == loader_data.data_length


@pytest.mark.parametrize(
    "fixture_name",
    [
        "dna_test_data",
        "prot_dna_test_data",
    ],
)
def test_parse_csv_to_input_label_meta(request, fixture_name: str):
    """Check that the loader exposes input, label and meta as dictionaries.

    Args:
        request: Pytest fixture request object.
        fixture_name (str): Name of the fixture to resolve and test.

    Verifies:
        - ``input`` is a dictionary
        - ``label`` is a dictionary
        - ``meta`` is a dictionary
    """
    loader = request.getfixturevalue(fixture_name).csv_loader
    for attribute in ("input", "label", "meta"):
        assert isinstance(getattr(loader, attribute), dict)


@pytest.mark.parametrize(
    "fixture_name",
    [
        "dna_test_data",
        "prot_dna_test_data",
    ],
)
def test_get_all_items(request, fixture_name: str):
    """Check that ``get_all_items`` yields three dictionaries.

    Args:
        request: Pytest fixture request object.
        fixture_name (str): Name of the fixture to resolve and test.

    Verifies:
        - The result unpacks into exactly input, label and meta
        - Each of the three is a dictionary
    """
    loader = request.getfixturevalue(fixture_name).csv_loader
    # Unpacking into three names also fails the test if the count is wrong.
    input_data, label_data, meta_data = loader.get_all_items()
    for part in (input_data, label_data, meta_data):
        assert isinstance(part, dict)


@pytest.mark.parametrize(
    "fixture_name,index,expected_length",
    [
        ("dna_test_data", 0, 1),
        ("dna_test_data", slice(0, 2), 2),
        ("prot_dna_test_data", 0, 1),
        ("prot_dna_test_data", slice(0, 2), 2),
    ],
)
def test_get_encoded_item(request, fixture_name: str, index: Any, expected_length: int):
    """Test retrieval of encoded items through indexing and slicing.

    The parameter is named ``index`` rather than ``slice`` so that it does not
    shadow the builtin ``slice`` used in the parametrization above.

    Args:
        request: Pytest fixture request object.
        fixture_name (str): Name of the fixture to test.
        index (int or slice): Index or slice object for data access.
        expected_length (int): Expected length of the retrieved data.

    Verifies:
        - Returns 3 dictionaries (input, label, meta)
        - All items are encoded as numpy arrays
        - Arrays have the expected length (multi-row slices only)
    """
    data = request.getfixturevalue(fixture_name)
    encoded_items = data.csv_loader[index]

    assert len(encoded_items) == 3
    for encoded_group in encoded_items:
        assert isinstance(encoded_group, dict)
        for item in encoded_group.values():
            assert isinstance(item, np.ndarray)
            # The length is only compared for multi-row requests; a single
            # integer index gets no length check here.
            if expected_length > 1:
                assert len(item) == expected_length


@pytest.mark.parametrize(
    "fixture_name",
    [
        "dna_test_data_with_split",
        "prot_dna_test_data_with_split",
    ],
)
def test_splitting(request, fixture_name):
    """Check that loading a single split returns the expected amount of data.

    Args:
        request: Pytest fixture request object.
        fixture_name (str): Name of the fixture to resolve and test.

    Verifies:
        - A loader can be built for split indices 0, 1 and 2
        - Each split has the length recorded in ``shape_splits``
        - Split index 3 raises ValueError
    """
    split_data = request.getfixturevalue(fixture_name)
    for split_index in range(3):
        split_loader = CsvLoader(split_data.experiment, split_data.csv_path, split=split_index)
        expected_size = split_data.shape_splits[split_index]
        assert len(split_loader) == expected_size
    # An index outside the three known splits must be rejected.
    with pytest.raises(ValueError):
        CsvLoader(split_data.experiment, split_data.csv_path, split=3)
56 changes: 38 additions & 18 deletions tests/test_csv_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, filename: str, experiment: Any):
self.expected_split = None
self.expected_transformed_values = None


@pytest.fixture
def dna_test_data():
"""This stores the basic dna test csv"""
Expand All @@ -49,26 +50,32 @@ def dna_test_data():
}
return data


@pytest.fixture
def dna_test_data_long():
    """Provide the long DNA test CSV wrapped with its expected length."""
    long_data = DataCsvProcessing(
        "tests/test_data/dna_experiment/test_shuffling_long.csv",
        DnaToFloatExperiment,
    )
    long_data.data_length = 1000
    return long_data


@pytest.fixture
def dna_test_data_long_shuffled():
    """Provide the pre-shuffled long DNA test CSV with its expected length.

    Returns:
        DataCsvProcessing: Helper built from ``test_shuffling_long_shuffled.csv``
            with ``data_length`` set to 1000.
    """
    # The helper is constructed exactly once; building it twice would only
    # discard the first instance.
    # NOTE(review): this DNA dataset is paired with ProtDnaToFloatExperiment,
    # whereas the unshuffled long fixture uses DnaToFloatExperiment -- confirm
    # that this is intended.
    data = DataCsvProcessing(
        "tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment
    )
    data.data_length = 1000
    return data


@pytest.fixture
def dna_config():
    """Load the JSON config file for the DNA experiment.

    Returns:
        The object decoded by ``json.load`` from
        ``tests/test_data/dna_experiment/test_config.json``.
    """
    # JSON is UTF-8 by specification; an explicit encoding avoids depending on
    # the platform's locale default when reading the file.
    with open("tests/test_data/dna_experiment/test_config.json", encoding="utf-8") as f:
        return json.load(f)



@pytest.fixture
def prot_dna_test_data():
"""This stores the basic prot-dna test csv"""
Expand All @@ -84,21 +91,26 @@ def prot_dna_test_data():
}
return data


@pytest.fixture
def prot_dna_config():
    """Load the JSON config file for the protein-DNA experiment.

    Returns:
        The object decoded by ``json.load`` from
        ``tests/test_data/prot_dna_experiment/test_config.json``.
    """
    # JSON is UTF-8 by specification; an explicit encoding avoids depending on
    # the platform's locale default when reading the file.
    with open("tests/test_data/prot_dna_experiment/test_config.json", encoding="utf-8") as f:
        return json.load(f)

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data"),
("dna_test_data_long"),
("dna_test_data_long_shuffled"),
("prot_dna_test_data")
])
("prot_dna_test_data"),
],
)
def test_data_length(request, fixture_name):
"""Test that data is loaded with the correct length.
Args:
request: Pytest fixture request object.
fixture_name (str): Name of the fixture to test.
Expand All @@ -108,13 +120,17 @@ def test_data_length(request, fixture_name):
data = request.getfixturevalue(fixture_name)
assert len(data.csv_processing.data) == data.data_length

@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [

@pytest.mark.parametrize(
"fixture_data_name,fixture_config_name",
[
("dna_test_data", "dna_config"),
("prot_dna_test_data", "prot_dna_config")
])
("prot_dna_test_data", "prot_dna_config"),
],
)
def test_add_split(request, fixture_data_name, fixture_config_name):
"""Test that the add_split function properly adds the split column.
Args:
request: Pytest fixture request object.
fixture_data_name (str): Name of the data fixture to test.
Expand All @@ -128,13 +144,17 @@ def test_add_split(request, fixture_data_name, fixture_config_name):
data.csv_processing.add_split(config["split"])
assert data.csv_processing.data["split:split:int"].to_list() == data.expected_split

@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [

@pytest.mark.parametrize(
"fixture_data_name,fixture_config_name",
[
("dna_test_data", "dna_config"),
("prot_dna_test_data", "prot_dna_config")
])
("prot_dna_test_data", "prot_dna_config"),
],
)
def test_transform_data(request, fixture_data_name, fixture_config_name):
"""Test that transformation functionalities properly transform the data.
Args:
request: Pytest fixture request object.
fixture_data_name (str): Name of the data fixture to test.
Expand All @@ -144,7 +164,7 @@ def test_transform_data(request, fixture_data_name, fixture_config_name):
"""
data = request.getfixturevalue(fixture_data_name)
config = request.getfixturevalue(fixture_config_name)

data.csv_processing.add_split(config["split"])
data.csv_processing.transform(config["transform"])

Expand All @@ -153,13 +173,14 @@ def test_transform_data(request, fixture_data_name, fixture_config_name):
observed_values = [round(v, 6) if isinstance(v, float) else v for v in observed_values]
assert observed_values == expected_values


def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled):
"""Test that shuffling of labels works correctly.
This test verifies that when labels are shuffled with a fixed seed,
they match the expected shuffled values from a pre-computed dataset.
Currently only tests the long DNA test data.
Args:
dna_test_data_long: Fixture containing the original unshuffled DNA test data.
dna_test_data_long_shuffled: Fixture containing the expected shuffled DNA test data.
Expand All @@ -169,4 +190,3 @@ def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled):
dna_test_data_long.csv_processing.data["hola:label:float"],
dna_test_data_long_shuffled.csv_processing.data["hola:label:float"],
)

0 comments on commit 5910c12

Please sign in to comment.