diff --git a/src/database/model/models_and_experiments/experiment.py b/src/database/model/models_and_experiments/experiment.py index 85395ad3..196526c7 100644 --- a/src/database/model/models_and_experiments/experiment.py +++ b/src/database/model/models_and_experiments/experiment.py @@ -1,14 +1,24 @@ +from typing import TYPE_CHECKING + from sqlmodel import Field, Relationship from database.model.ai_asset.ai_asset import AIAssetBase, AIAsset +from database.model.dataset.dataset import Dataset from database.model.field_length import SHORT, LONG from database.model.helper_functions import many_to_many_link_factory from database.model.models_and_experiments.badge import Badge from database.model.models_and_experiments.runnable_distribution import RunnableDistribution from database.model.relationships import ManyToMany, OneToMany -from database.model.serializers import AttributeSerializer, FindByNameDeserializerList +from database.model.serializers import ( + AttributeSerializer, + FindByIdentifierDeserializerList, + FindByNameDeserializerList, +) from versioning import Version, VersionedResource, VersionedResourceCollection +if TYPE_CHECKING: + from database.model.models_and_experiments.ml_model import MLModel + class ExperimentBase(AIAssetBase): pid: str | None = Field( @@ -51,6 +61,24 @@ class Experiment(ExperimentBase, AIAsset, table=True): # type: ignore [call-arg "experiment", Badge.__tablename__, from_identifier_type=str ), ) + uses_model: list["MLModel"] = Relationship( + link_model=many_to_many_link_factory( + "experiment", + "ml_model", + table_prefix="uses_model", + from_identifier_type=str, + to_identifier_type=str, + ), + ) + uses_dataset: list[Dataset] = Relationship( + link_model=many_to_many_link_factory( + "experiment", + Dataset.__tablename__, + table_prefix="uses_dataset", + from_identifier_type=str, + to_identifier_type=str, + ), + ) class RelationshipConfig(AIAsset.RelationshipConfig): badge: list[str] = ManyToMany( @@ -61,6 +89,19 @@ class RelationshipConfig(AIAsset.RelationshipConfig): example=["ACM Artifacts Evaluated - Reusable"], ) distribution: list[RunnableDistribution] = OneToMany(default_factory_pydantic=list) + uses_model: list[str] = ManyToMany( + description="ML models used during the execution of this experiment.", + _serializer=AttributeSerializer("identifier"), + default_factory_pydantic=list, + example=[], + ) + uses_dataset: list[str] = ManyToMany( + description="Datasets used for the execution of the experiment.", + _serializer=AttributeSerializer("identifier"), + deserializer=FindByIdentifierDeserializerList(Dataset), + default_factory_pydantic=list, + example=[], + ) experiment_versions = VersionedResourceCollection( diff --git a/src/database/model/models_and_experiments/ml_model.py b/src/database/model/models_and_experiments/ml_model.py index 9c61621d..746b180f 100644 --- a/src/database/model/models_and_experiments/ml_model.py +++ b/src/database/model/models_and_experiments/ml_model.py @@ -58,6 +58,10 @@ class RelationshipConfig(AIAsset.RelationshipConfig): ) +# Complete the uses_model deserializer for Experiment, which cannot be set in experiment.py +# due to a circular import (MLModel imports Experiment and vice versa). +Experiment.RelationshipConfig.uses_model.deserializer = FindByIdentifierDeserializerList(MLModel) + ml_model_versions = VersionedResourceCollection( { Version.V2: VersionedResource(MLModel), diff --git a/src/tests/routers/resource_routers/test_router_experiment.py b/src/tests/routers/resource_routers/test_router_experiment.py index e1ed1246..93ecadd5 100644 --- a/src/tests/routers/resource_routers/test_router_experiment.py +++ b/src/tests/routers/resource_routers/test_router_experiment.py @@ -3,18 +3,36 @@ from starlette.testclient import TestClient +from database.model.dataset.dataset import Dataset +from database.model.models_and_experiments.ml_model import MLModel +from database.session import DbSession + def test_happy_path( client: TestClient, mocked_privileged_token: Mock, body_asset: dict, + dataset: Dataset, + ml_model: MLModel, auto_publish: None, ): + with DbSession() as session: + session.add(dataset) + session.commit() + session.refresh(dataset) + + with DbSession() as session: + session.add(ml_model) + session.commit() + session.refresh(ml_model) + body = copy.copy(body_asset) body["pid"] = "https://doi.org/10.1000/182" body["experimental_workflow"] = "Example workflow." body["execution_settings"] = "Example execution settings." body["reproducibility_explanation"] = "Example reproducibility explanation." + body["uses_model"] = [ml_model.identifier] + body["uses_dataset"] = [dataset.identifier] distribution = { "checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", @@ -41,7 +59,7 @@ def test_happy_path( response = client.post("/experiments", json=body, headers={"Authorization": "Fake token"}) assert response.status_code == 200, response.json() - identifier = response.json()['identifier'] + identifier = response.json()["identifier"] response = client.get(f"/experiments/{identifier}") assert response.status_code == 200, response.json() @@ -52,3 +70,5 @@ def test_happy_path( assert response_json["execution_settings"] == "Example execution settings." assert response_json["reproducibility_explanation"] == "Example reproducibility explanation." assert response_json["distribution"] == [distribution] + assert response_json["uses_model"] == [ml_model.identifier] + assert response_json["uses_dataset"] == [dataset.identifier] diff --git a/src/tests/testutils/default_instances.py b/src/tests/testutils/default_instances.py index 350334a5..7cabe221 100644 --- a/src/tests/testutils/default_instances.py +++ b/src/tests/testutils/default_instances.py @@ -18,6 +18,7 @@ from database.model.dataset.dataset import Dataset from database.model.knowledge_asset.publication import Publication from database.model.models_and_experiments.experiment import Experiment +from database.model.models_and_experiments.ml_model import MLModel from database.model.platform.platform import Platform from database.model.resource_read_and_create import resource_create from database.model.serializers import deserialize_resource_relationships @@ -76,7 +77,9 @@ def body_agent(body_resource: dict, load_body_agent: dict) -> dict: return copy.deepcopy(body) -def make_publication(body_asset: dict, with_random_platform_identifier: bool = False) -> Publication: +def make_publication( + body_asset: dict, with_random_platform_identifier: bool = False +) -> Publication: body = copy.deepcopy(body_asset) body["permanent_identifier"] = "http://dx.doi.org/10.1093/ajae/aaq063" body["isbn"] = "9783161484100" @@ -140,6 +143,12 @@ def person(body_agent: dict) -> Person: def experiment(body_asset: dict) -> Experiment: return _create_class_with_body(Experiment, body_asset) + +@pytest.fixture +def ml_model(body_asset: dict) -> MLModel: + return _create_class_with_body(MLModel, body_asset) + + @pytest.fixture def project(body_asset: dict) -> Project: return _create_class_with_body(Project, body_asset)