Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/database/model/models_and_experiments/experiment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from typing import TYPE_CHECKING

from sqlmodel import Field, Relationship

from database.model.ai_asset.ai_asset import AIAssetBase, AIAsset
from database.model.dataset.dataset import Dataset
from database.model.field_length import SHORT, LONG
from database.model.helper_functions import many_to_many_link_factory
from database.model.models_and_experiments.badge import Badge
from database.model.models_and_experiments.runnable_distribution import RunnableDistribution
from database.model.relationships import ManyToMany, OneToMany
from database.model.serializers import AttributeSerializer, FindByNameDeserializerList
from database.model.serializers import (
AttributeSerializer,
FindByIdentifierDeserializerList,
FindByNameDeserializerList,
)
from versioning import Version, VersionedResource, VersionedResourceCollection

if TYPE_CHECKING:
from database.model.models_and_experiments.ml_model import MLModel


class ExperimentBase(AIAssetBase):
pid: str | None = Field(
Expand Down Expand Up @@ -51,6 +61,24 @@ class Experiment(ExperimentBase, AIAsset, table=True): # type: ignore [call-arg
"experiment", Badge.__tablename__, from_identifier_type=str
),
)
uses_model: list["MLModel"] = Relationship(
link_model=many_to_many_link_factory(
"experiment",
"ml_model",
table_prefix="uses_model",
from_identifier_type=str,
to_identifier_type=str,
),
)
uses_dataset: list[Dataset] = Relationship(
link_model=many_to_many_link_factory(
"experiment",
Dataset.__tablename__,
table_prefix="uses_dataset",
from_identifier_type=str,
to_identifier_type=str,
),
)

class RelationshipConfig(AIAsset.RelationshipConfig):
badge: list[str] = ManyToMany(
Expand All @@ -61,6 +89,19 @@ class RelationshipConfig(AIAsset.RelationshipConfig):
example=["ACM Artifacts Evaluated - Reusable"],
)
distribution: list[RunnableDistribution] = OneToMany(default_factory_pydantic=list)
uses_model: list[str] = ManyToMany(
description="ML models used during the execution of this experiment.",
_serializer=AttributeSerializer("identifier"),
default_factory_pydantic=list,
example=[],
)
uses_dataset: list[str] = ManyToMany(
description="Datasets used for the execution of the experiment.",
_serializer=AttributeSerializer("identifier"),
deserializer=FindByIdentifierDeserializerList(Dataset),
default_factory_pydantic=list,
example=[],
)


experiment_versions = VersionedResourceCollection(
Expand Down
4 changes: 4 additions & 0 deletions src/database/model/models_and_experiments/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class RelationshipConfig(AIAsset.RelationshipConfig):
)


# Complete the uses_model deserializer for Experiment, which cannot be set in experiment.py
# due to a circular import (MLModel imports Experiment and vice versa).
Experiment.RelationshipConfig.uses_model.deserializer = FindByIdentifierDeserializerList(MLModel)

ml_model_versions = VersionedResourceCollection(
{
Version.V2: VersionedResource(MLModel),
Expand Down
22 changes: 21 additions & 1 deletion src/tests/routers/resource_routers/test_router_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,36 @@

from starlette.testclient import TestClient

from database.model.dataset.dataset import Dataset
from database.model.models_and_experiments.ml_model import MLModel
from database.session import DbSession


def test_happy_path(
client: TestClient,
mocked_privileged_token: Mock,
body_asset: dict,
dataset: Dataset,
ml_model: MLModel,
auto_publish: None,
):
with DbSession() as session:
session.add(dataset)
session.commit()
session.refresh(dataset)

with DbSession() as session:
session.add(ml_model)
session.commit()
session.refresh(ml_model)

body = copy.copy(body_asset)
body["pid"] = "https://doi.org/10.1000/182"
body["experimental_workflow"] = "Example workflow."
body["execution_settings"] = "Example execution settings."
body["reproducibility_explanation"] = "Example reproducibility explanation."
body["uses_model"] = [ml_model.identifier]
body["uses_dataset"] = [dataset.identifier]

distribution = {
"checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
Expand All @@ -41,7 +59,7 @@ def test_happy_path(

response = client.post("/experiments", json=body, headers={"Authorization": "Fake token"})
assert response.status_code == 200, response.json()
identifier = response.json()['identifier']
identifier = response.json()["identifier"]

response = client.get(f"/experiments/{identifier}")
assert response.status_code == 200, response.json()
Expand All @@ -52,3 +70,5 @@ def test_happy_path(
assert response_json["execution_settings"] == "Example execution settings."
assert response_json["reproducibility_explanation"] == "Example reproducibility explanation."
assert response_json["distribution"] == [distribution]
assert response_json["uses_model"] == [ml_model.identifier]
assert response_json["uses_dataset"] == [dataset.identifier]
11 changes: 10 additions & 1 deletion src/tests/testutils/default_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from database.model.dataset.dataset import Dataset
from database.model.knowledge_asset.publication import Publication
from database.model.models_and_experiments.experiment import Experiment
from database.model.models_and_experiments.ml_model import MLModel
from database.model.platform.platform import Platform
from database.model.resource_read_and_create import resource_create
from database.model.serializers import deserialize_resource_relationships
Expand Down Expand Up @@ -76,7 +77,9 @@ def body_agent(body_resource: dict, load_body_agent: dict) -> dict:
return copy.deepcopy(body)


def make_publication(body_asset: dict, with_random_platform_identifier: bool = False) -> Publication:
def make_publication(
body_asset: dict, with_random_platform_identifier: bool = False
) -> Publication:
body = copy.deepcopy(body_asset)
body["permanent_identifier"] = "http://dx.doi.org/10.1093/ajae/aaq063"
body["isbn"] = "9783161484100"
Expand Down Expand Up @@ -140,6 +143,12 @@ def person(body_agent: dict) -> Person:
def experiment(body_asset: dict) -> Experiment:
return _create_class_with_body(Experiment, body_asset)


@pytest.fixture
def ml_model(body_asset: dict) -> MLModel:
return _create_class_with_body(MLModel, body_asset)


@pytest.fixture
def project(body_asset: dict) -> Project:
return _create_class_with_body(Project, body_asset)
Expand Down