Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ ARG PKG_NAME="v6-session-basics"

# install federated algorithm
COPY . /app
RUN uv pip install --system -e /app
# TODO v5+ should remove --prerelease=allow when official release is made
RUN uv pip install --system -e /app --prerelease=allow

# Set environment variable to make name of the package available within the
# docker image.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requires-python = ">=3.13"
dependencies = [
"pandas",
"pytest",
"vantage6-algorithm-tools==5.0.0a36",
"vantage6-algorithm-tools==5.0.0a41",
]
authors = [
{ name = "Bart van Beusekom", email = "[email protected]" },
Expand Down
59 changes: 0 additions & 59 deletions test/test.py

This file was deleted.

8 changes: 5 additions & 3 deletions test/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from importlib.resources import files

from vantage6.mock.mock_network import MockNetwork, MockUserClient
import pytest
from vantage6.mock.network import MockNetwork, MockUserClient


@pytest.fixture
Expand All @@ -21,11 +21,12 @@ def mock_client() -> MockUserClient:
"database": test_data,
"db_type": "csv",
},
}
]
},
],
)
return MockUserClient(mock_network)


def test_metadata_function(mock_client: MockUserClient):
"""Test the metadata function"""
# Get organizations
Expand All @@ -39,6 +40,7 @@ def test_metadata_function(mock_client: MockUserClient):

# Wait for results
results = mock_client.wait_for_results(task.get("id"))
print(results)

# Assertions
assert results is not None
Expand Down
63 changes: 63 additions & 0 deletions test/test_pre_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from importlib.resources import files

import pandas as pd
import pytest
from vantage6.mock.network import MockNetwork, MockUserClient


@pytest.fixture
def mock_client() -> MockUserClient:
test_data = files("v6-session-basics").joinpath("data/test_data.csv")
mock_network = MockNetwork(
module_name="v6-session-basics",
datasets=[
{
"test_data_1": {
"database": pd.read_csv(test_data),
"db_type": "csv",
},
},
{
"test_data_1": {
"database": pd.read_csv(test_data),
"db_type": "csv",
},
},
],
)
return mock_network.user_client


def test_pre_process_function(mock_client: MockUserClient):
"""Test the pre_process function"""
# Get organizations
# orgs = mock_client.organization.list()
# org_ids = [org["id"] for org in orgs]
DATAFRAME_ID = 1

# Check what dtype the dataframe has
old_response = mock_client.network.server.get_dataframe(DATAFRAME_ID)
old_dtypes = [
column["dtype"] for column in old_response["columns"] if column["name"] == "Age"
]
assert [str(dtype) for dtype in old_dtypes] == [
"int64" for _ in range(len(old_dtypes))
]

# Note that the tasks here are run in sequence, thus sleeping for 1 seconds will
# be multiplied by the number of organizations.
mock_client.dataframe.preprocess(
id_=DATAFRAME_ID,
image="mock-image",
method="pre_process",
arguments={"column": "Age", "dtype": "int32"},
)

df_response = mock_client.network.server.get_dataframe(DATAFRAME_ID)
new_dtypes = [
column["dtype"] for column in df_response["columns"] if column["name"] == "Age"
]
assert old_dtypes != new_dtypes
assert [str(dtype) for dtype in new_dtypes] == [
"int32" for _ in range(len(new_dtypes))
]
12 changes: 6 additions & 6 deletions test/test_read_csv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from importlib.resources import files

import pandas as pd
import pytest

from vantage6.mock.mock_network import MockNetwork, MockUserClient
from vantage6.mock.network import MockNetwork, MockUserClient


@pytest.fixture
Expand All @@ -22,8 +22,8 @@ def mock_client() -> MockUserClient:
"database": test_data,
"db_type": "csv",
},
}
]
},
],
)
return MockUserClient(mock_network)

Expand All @@ -41,7 +41,7 @@ def test_read_csv_function(mock_client: MockUserClient):
arguments={},
action="data_extraction",
label="test_data_1",
name="my_dataframe_by_frank"
name="my_dataframe_by_frank",
)

# A data extraction job should create a dataframe on each node, lets check if this
Expand All @@ -51,4 +51,4 @@ def test_read_csv_function(mock_client: MockUserClient):
for node in mock_client.network.nodes:
assert len(node.dataframes) == 1
assert "my_dataframe_by_frank" in node.dataframes
assert isinstance(node.dataframes["my_dataframe_by_frank"], pd.DataFrame)
assert isinstance(node.dataframes["my_dataframe_by_frank"], pd.DataFrame)
9 changes: 5 additions & 4 deletions test/test_sleep.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from importlib.resources import files

from vantage6.mock.mock_network import MockNetwork, MockUserClient
import pytest
from vantage6.mock.network import MockNetwork, MockUserClient


@pytest.fixture
Expand All @@ -21,11 +21,12 @@ def mock_client() -> MockUserClient:
"database": test_data,
"db_type": "csv",
},
}
]
},
],
)
return mock_network.user_client


def test_sleep_function(mock_client: MockUserClient):
"""Test the metadata function"""
# Get organizations
Expand All @@ -46,4 +47,4 @@ def test_sleep_function(mock_client: MockUserClient):
assert len(results) == 2 # Two organizations
for result in results:
assert "sleep" in result
assert result["sleep"] == "done"
assert result["sleep"] == "done"
54 changes: 54 additions & 0 deletions test/test_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from importlib.resources import files

import pandas as pd
import pytest
from vantage6.mock.network import MockNetwork, MockUserClient

TEST_DATAFRAME = pd.read_csv(files("v6-session-basics").joinpath("data/test_data.csv"))
DATAFRAME_LABEL = "test_data_1"


@pytest.fixture
def mock_client() -> MockUserClient:
mock_network = MockNetwork(
module_name="v6-session-basics",
datasets=[
{
DATAFRAME_LABEL: {
"database": TEST_DATAFRAME,
"db_type": "csv",
},
},
{
DATAFRAME_LABEL: {
"database": TEST_DATAFRAME,
"db_type": "csv",
},
},
],
)
return mock_network.user_client


def test_sum_function(mock_client: MockUserClient):
"""Test the sum function"""
# Get organizations
orgs = mock_client.organization.list()
org_ids = [org["id"] for org in orgs]

# Note that the tasks here are run in sequence, thus sleeping for 1 seconds will
# be multiplied by the number of organizations.
column_to_sum = "Age"
task = mock_client.task.create(
method="sum",
organizations=org_ids,
arguments={"column": column_to_sum},
databases=[{"label": DATAFRAME_LABEL}],
)

print(task)
# Wait for results
results = mock_client.wait_for_results(task.get("id"))

print(results)
assert results[0]["sum"] == TEST_DATAFRAME[column_to_sum].sum()
82 changes: 82 additions & 0 deletions test/test_sum_many.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from importlib.resources import files

import pandas as pd
import pytest
from vantage6.mock.network import MockNetwork, MockUserClient

# Read full dataframe and split into 4 random subsets
full_df = pd.read_csv(files("v6-session-basics").joinpath("data/test_data.csv"))
TEST_DATAFRAME_1 = full_df.sample(n=10, random_state=42)
TEST_DATAFRAME_2 = full_df.sample(n=15, random_state=43)
TEST_DATAFRAME_3 = full_df.sample(n=11, random_state=44)
TEST_DATAFRAME_4 = full_df.sample(n=15, random_state=45)
DATAFRAME_LABEL_1 = "test_data_1"
DATAFRAME_LABEL_2 = "test_data_2"


@pytest.fixture
def mock_client() -> MockUserClient:
mock_network = MockNetwork(
module_name="v6-session-basics",
datasets=[
{
DATAFRAME_LABEL_1: {
"database": TEST_DATAFRAME_1,
"db_type": "csv",
},
DATAFRAME_LABEL_2: {
"database": TEST_DATAFRAME_2,
"db_type": "csv",
},
},
{
DATAFRAME_LABEL_1: {
"database": TEST_DATAFRAME_3,
"db_type": "csv",
},
DATAFRAME_LABEL_2: {
"database": TEST_DATAFRAME_4,
"db_type": "csv",
},
},
],
)
return mock_network.user_client


def test_sum_many_function(mock_client: MockUserClient):
"""Test the sum_many function"""
# Get organizations
orgs = mock_client.organization.list()
org_ids = [org["id"] for org in orgs]

column_to_sum = "Age"
task = mock_client.task.create(
method="sum_many",
organizations=org_ids,
arguments={"column": column_to_sum},
databases=[[{"label": DATAFRAME_LABEL_1}, {"label": DATAFRAME_LABEL_2}]],
)

# Wait for results
results = mock_client.wait_for_results(task.get("id"))
print(results)

# # Verify results
assert results is not None
assert len(results) == len(org_ids) # Two organizations

assert "sums" in results[0]
assert DATAFRAME_LABEL_1 in results[0]["sums"]
assert (
results[0]["sums"][DATAFRAME_LABEL_1] == TEST_DATAFRAME_1[column_to_sum].sum()
)
assert (
results[0]["sums"][DATAFRAME_LABEL_2] == TEST_DATAFRAME_2[column_to_sum].sum()
)
assert (
results[1]["sums"][DATAFRAME_LABEL_1] == TEST_DATAFRAME_3[column_to_sum].sum()
)
assert (
results[1]["sums"][DATAFRAME_LABEL_2] == TEST_DATAFRAME_4[column_to_sum].sum()
)
Loading