Skip to content

Commit

Permalink
make format
Browse files Browse the repository at this point in the history
  • Loading branch information
suzannejin committed Nov 8, 2024
1 parent 6b9aeca commit e01cd48
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 36 deletions.
59 changes: 41 additions & 18 deletions tests/test_csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def dna_test_data():
data.data_length = 2
return data


@pytest.fixture
def dna_test_data_with_split():
"""This stores the basic dna test csv with split"""
Expand All @@ -51,13 +52,15 @@ def dna_test_data_with_split():
data.shape_splits = {0: 16, 1: 16, 2: 16}
return data


@pytest.fixture
def prot_dna_test_data():
    """Provide the basic prot-dna test CSV wrapped in a ``DataCsvLoader``.

    The returned object carries ``data_length = 2``, the length the loaded
    CSV is compared against in ``test_data_length``.
    """
    loader_data = DataCsvLoader(
        "tests/test_data/prot_dna_experiment/test.csv",
        ProtDnaToFloatExperiment,
    )
    loader_data.data_length = 2
    return loader_data


@pytest.fixture
def prot_dna_test_data_with_split():
"""This stores the basic prot-dna test csv with split"""
Expand All @@ -66,12 +69,16 @@ def prot_dna_test_data_with_split():
data.shape_splits = {0: 1, 1: 1, 2: 1}
return data

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data"),
("dna_test_data_with_split"),
("prot_dna_test_data"),
("prot_dna_test_data_with_split")
])
("prot_dna_test_data_with_split"),
],
)
def test_data_length(request, fixture_name: str):
"""Verify data is loaded with correct length.
Expand All @@ -82,11 +89,15 @@ def test_data_length(request, fixture_name: str):
data = request.getfixturevalue(fixture_name)
assert len(data.csv_loader) == data.data_length

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data"),
("prot_dna_test_data")
])
def test_parse_csv_to_input_label_meta(request, fixture_name:str):
("prot_dna_test_data"),
],
)
def test_parse_csv_to_input_label_meta(request, fixture_name: str):
"""Test parsing of CSV to input, label, and meta.
Args:
Expand All @@ -103,10 +114,14 @@ def test_parse_csv_to_input_label_meta(request, fixture_name:str):
assert isinstance(data.csv_loader.label, dict)
assert isinstance(data.csv_loader.meta, dict)

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data"),
("prot_dna_test_data")
])
("prot_dna_test_data"),
],
)
def test_get_all_items(request, fixture_name: str):
"""Test retrieval of all items from the CSV loader.
Expand All @@ -123,12 +138,16 @@ def test_get_all_items(request, fixture_name: str):
assert isinstance(label_data, dict)
assert isinstance(meta_data, dict)

@pytest.mark.parametrize("fixture_name,slice,expected_length", [

@pytest.mark.parametrize(
"fixture_name,slice,expected_length",
[
("dna_test_data", 0, 1),
("dna_test_data", slice(0,2), 2),
("dna_test_data", slice(0, 2), 2),
("prot_dna_test_data", 0, 1),
("prot_dna_test_data", slice(0,2), 2)
])
("prot_dna_test_data", slice(0, 2), 2),
],
)
def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_length: int):
"""Test retrieval of encoded items through slicing.
Expand All @@ -151,13 +170,17 @@ def test_get_encoded_item(request, fixture_name: str, slice: Any, expected_lengt
assert isinstance(encoded_items[i], dict)
for item in encoded_items[i].values():
assert isinstance(item, np.ndarray)
if (expected_length > 1):
if expected_length > 1:
assert len(item) == expected_length

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data_with_split"),
("prot_dna_test_data_with_split")
])
("prot_dna_test_data_with_split"),
],
)
def test_splitting(request, fixture_name):
"""Test data splitting functionality.
Expand Down
56 changes: 38 additions & 18 deletions tests/test_csv_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, filename: str, experiment: Any):
self.expected_split = None
self.expected_transformed_values = None


@pytest.fixture
def dna_test_data():
"""This stores the basic dna test csv"""
Expand All @@ -49,26 +50,32 @@ def dna_test_data():
}
return data


@pytest.fixture
def dna_test_data_long():
    """Provide the long DNA test CSV wrapped in a ``DataCsvProcessing``.

    The returned object carries ``data_length = 1000``, the length the loaded
    data is compared against in ``test_data_length``.
    """
    long_data = DataCsvProcessing(
        "tests/test_data/dna_experiment/test_shuffling_long.csv",
        DnaToFloatExperiment,
    )
    long_data.data_length = 1000
    return long_data


@pytest.fixture
def dna_test_data_long_shuffled():
    """Provide the pre-shuffled long DNA test CSV wrapped in a ``DataCsvProcessing``.

    The returned object carries ``data_length = 1000``; ``test_shuffle_labels``
    uses it as the expected (already shuffled) reference data.
    """
    # Build the helper exactly once; a second identical construction would only
    # repeat the CSV load.
    # NOTE(review): this passes ProtDnaToFloatExperiment for a file under
    # dna_experiment/, while the unshuffled fixture uses DnaToFloatExperiment.
    # Confirm this is intentional before changing it.
    data = DataCsvProcessing(
        "tests/test_data/dna_experiment/test_shuffling_long_shuffled.csv", ProtDnaToFloatExperiment
    )
    data.data_length = 1000
    return data


@pytest.fixture
def dna_config():
    """Load the config file for the dna experiment.

    Returns:
        The parsed JSON content of
        ``tests/test_data/dna_experiment/test_config.json``.
    """
    # JSON text is UTF-8; pin the encoding so parsing does not depend on the
    # platform's locale default.
    with open("tests/test_data/dna_experiment/test_config.json", encoding="utf-8") as f:
        return json.load(f)



@pytest.fixture
def prot_dna_test_data():
"""This stores the basic prot-dna test csv"""
Expand All @@ -84,21 +91,26 @@ def prot_dna_test_data():
}
return data


@pytest.fixture
def prot_dna_config():
    """Load the config file for the prot-dna experiment.

    Returns:
        The parsed JSON content of
        ``tests/test_data/prot_dna_experiment/test_config.json``.
    """
    # JSON text is UTF-8; pin the encoding so parsing does not depend on the
    # platform's locale default.
    with open("tests/test_data/prot_dna_experiment/test_config.json", encoding="utf-8") as f:
        return json.load(f)

@pytest.mark.parametrize("fixture_name", [

@pytest.mark.parametrize(
"fixture_name",
[
("dna_test_data"),
("dna_test_data_long"),
("dna_test_data_long_shuffled"),
("prot_dna_test_data")
])
("prot_dna_test_data"),
],
)
def test_data_length(request, fixture_name):
"""Test that data is loaded with the correct length.
Args:
request: Pytest fixture request object.
fixture_name (str): Name of the fixture to test.
Expand All @@ -108,13 +120,17 @@ def test_data_length(request, fixture_name):
data = request.getfixturevalue(fixture_name)
assert len(data.csv_processing.data) == data.data_length

@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [

@pytest.mark.parametrize(
"fixture_data_name,fixture_config_name",
[
("dna_test_data", "dna_config"),
("prot_dna_test_data", "prot_dna_config")
])
("prot_dna_test_data", "prot_dna_config"),
],
)
def test_add_split(request, fixture_data_name, fixture_config_name):
"""Test that the add_split function properly adds the split column.
Args:
request: Pytest fixture request object.
fixture_data_name (str): Name of the data fixture to test.
Expand All @@ -128,13 +144,17 @@ def test_add_split(request, fixture_data_name, fixture_config_name):
data.csv_processing.add_split(config["split"])
assert data.csv_processing.data["split:split:int"].to_list() == data.expected_split

@pytest.mark.parametrize("fixture_data_name,fixture_config_name", [

@pytest.mark.parametrize(
"fixture_data_name,fixture_config_name",
[
("dna_test_data", "dna_config"),
("prot_dna_test_data", "prot_dna_config")
])
("prot_dna_test_data", "prot_dna_config"),
],
)
def test_transform_data(request, fixture_data_name, fixture_config_name):
"""Test that transformation functionalities properly transform the data.
Args:
request: Pytest fixture request object.
fixture_data_name (str): Name of the data fixture to test.
Expand All @@ -144,7 +164,7 @@ def test_transform_data(request, fixture_data_name, fixture_config_name):
"""
data = request.getfixturevalue(fixture_data_name)
config = request.getfixturevalue(fixture_config_name)

data.csv_processing.add_split(config["split"])
data.csv_processing.transform(config["transform"])

Expand All @@ -153,13 +173,14 @@ def test_transform_data(request, fixture_data_name, fixture_config_name):
observed_values = [round(v, 6) if isinstance(v, float) else v for v in observed_values]
assert observed_values == expected_values


def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled):
"""Test that shuffling of labels works correctly.
This test verifies that when labels are shuffled with a fixed seed,
they match the expected shuffled values from a pre-computed dataset.
Currently only tests the long DNA test data.
Args:
dna_test_data_long: Fixture containing the original unshuffled DNA test data.
dna_test_data_long_shuffled: Fixture containing the expected shuffled DNA test data.
Expand All @@ -169,4 +190,3 @@ def test_shuffle_labels(dna_test_data_long, dna_test_data_long_shuffled):
dna_test_data_long.csv_processing.data["hola:label:float"],
dna_test_data_long_shuffled.csv_processing.data["hola:label:float"],
)

0 comments on commit e01cd48

Please sign in to comment.