Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions docs/templates/template.pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from marimba.core.pipeline import BasePipeline
from marimba.core.schemas.base import BaseMetadata
from marimba.core.schemas.header.base import MetadataHeader
from marimba.core.schemas.ifdo import iFDOMetadata


Expand Down Expand Up @@ -111,7 +114,10 @@ def _package(
data_dir: Path,
config: dict[str, Any],
**kwargs: dict[str, Any],
) -> dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]]:
) -> tuple[
dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]],
dict[type[BaseMetadata], MetadataHeader[BaseModel]],
]:
"""
Package data from data_dir for distribution.

Expand All @@ -123,5 +129,11 @@ def _package(
Returns:
Tuple of (dictionary mapping source paths to tuples of (destination path, BaseMetadata list, metadata), dictionary mapping metadata types to their metadata headers).
"""
data_mapping: dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]] = {}
data_mapping: tuple[
dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]],
dict[type[BaseMetadata], MetadataHeader[BaseModel]],
] = (
{},
{},
)
return data_mapping
40 changes: 35 additions & 5 deletions marimba/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,33 @@

Classes:
- BasePipeline: Abstract base class for Marimba pipelines.
- PackageEntry: Package metadata for a single file.

"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, NamedTuple

from pydantic import BaseModel

from marimba.core.schemas.base import BaseMetadata
from marimba.core.schemas.header.base import MetadataHeader
from marimba.core.utils.log import LogMixin
from marimba.core.utils.paths import format_path_for_logging
from marimba.core.utils.rich import format_command, format_entity


class PackageEntry(NamedTuple):
    """
    Package metadata for a single file.

    Field order mirrors the positional ``(path, metadata list, extra dict)`` tuples
    that pipelines return per file, so an entry can be built by unpacking such a tuple.
    """

    # Path associated with the packaged file.
    # NOTE(review): presumably the destination path inside the package - confirm against callers.
    path: Path
    # Metadata objects attached to the file, or None when there are none.
    metadata: list[BaseMetadata] | None = None
    # Free-form additional metadata for the file, or None when there is none.
    extra: dict[str, Any] | None = None


class BasePipeline(ABC, LogMixin):
"""
Marimba pipeline abstract base class. All pipelines should inherit from this class.
Expand Down Expand Up @@ -184,7 +198,10 @@ def run_package(
data_dir: Path,
config: dict[str, Any],
**kwargs: dict[str, Any],
) -> dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]]:
) -> tuple[
dict[Path, PackageEntry],
dict[type[BaseMetadata], MetadataHeader[BaseModel]],
]:
"""
Package a dataset from the given data directories and their corresponding collection configurations.

Expand All @@ -204,13 +221,20 @@ def run_package(
f"data_dir={format_path_for_logging(data_dir, Path(self._root_path).parents[2])}, {config=}, {kwargs=}",
)

data_mapping = self._package(data_dir, config, **kwargs)
result = self._package(data_dir, config, **kwargs)

metadata_header: dict[type[BaseMetadata], MetadataHeader[BaseModel]]
if isinstance(result, tuple):
data_mapping, metadata_header = result
else:
data_mapping = result
metadata_header = {}

self.logger.info(
f"Completed {format_command('package')} command for pipeline {format_entity(self.class_name)}",
)

return data_mapping
return {key: PackageEntry(*entry) for key, entry in data_mapping.items()}, metadata_header

def run_post_package(
self,
Expand Down Expand Up @@ -273,7 +297,13 @@ def _package(
data_dir: Path,
config: dict[str, Any],
**kwargs: dict[str, Any],
) -> dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]]:
) -> (
dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]]
| tuple[
dict[Path, tuple[Path, list[BaseMetadata] | None, dict[str, Any] | None]],
dict[type[BaseMetadata], MetadataHeader[BaseModel]],
]
):
"""
`run_package` implementation; override this.
"""
Expand Down
14 changes: 13 additions & 1 deletion marimba/core/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from marimba.core.schemas.header.base import MetadataHeader


class BaseMetadata(ABC):
"""
Expand Down Expand Up @@ -79,6 +83,7 @@ def create_dataset_metadata(
dataset_name: str,
root_dir: Path,
items: dict[str, list["BaseMetadata"]],
metadata_header: MetadataHeader[BaseModel] | None = None,
metadata_name: str | None = None,
*,
dry_run: bool = False,
Expand All @@ -91,7 +96,14 @@ def create_dataset_metadata(
@abstractmethod
def process_files(
cls,
dataset_mapping: dict[Path, tuple[list["BaseMetadata"], dict[str, Any] | None]],
dataset_mapping: dict[
Path,
tuple[
list["BaseMetadata"],
dict[str, Any] | None,
MetadataHeader[BaseModel] | None,
],
],
max_workers: int | None = None,
logger: logging.Logger | None = None,
*,
Expand Down
13 changes: 12 additions & 1 deletion marimba/core/schemas/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from pathlib import Path
from typing import Any, Union, cast

from pydantic import BaseModel

from marimba.core.schemas.base import BaseMetadata
from marimba.core.schemas.header.base import MetadataHeader
from marimba.core.utils.metadata import yaml_saver


Expand Down Expand Up @@ -177,6 +180,7 @@ def create_dataset_metadata(
dataset_name: str,
root_dir: Path,
items: dict[str, list["BaseMetadata"]],
_metadata_header: MetadataHeader[BaseModel] | None = None,
metadata_name: str | None = None,
*,
dry_run: bool = False,
Expand Down Expand Up @@ -212,7 +216,14 @@ def create_dataset_metadata(
@classmethod
def process_files(
cls,
dataset_mapping: dict[Path, tuple[list["BaseMetadata"], dict[str, Any] | None]],
dataset_mapping: dict[
Path,
tuple[
list["BaseMetadata"],
dict[str, Any] | None,
MetadataHeader[BaseModel] | None,
],
],
max_workers: int | None = None,
logger: logging.Logger | None = None,
*,
Expand Down
6 changes: 6 additions & 0 deletions marimba/core/schemas/header/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
Metadata header package.

Re-exports the public names of ``marimba.core.schemas.header.base``.
"""

from marimba.core.schemas.header.base import (
    HeaderMergeConflictError,
    MetadataHeader,
)

__all__ = ["MetadataHeader", "HeaderMergeConflictError"]
101 changes: 101 additions & 0 deletions marimba/core/schemas/header/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Module containing the implementation of the metadata header.

Classes:
HeaderMergeConflictError: Custom error type signaling that two headers cannot be merged.
MetadataHeader: Metadata header class wrapping mergeable header data.
"""

from __future__ import annotations

import inspect
from copy import copy
from typing import Any, Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


class HeaderMergeConflictError(Exception):
    """
    Custom error type signaling that two headers cannot be merged.
    """

    def __init__(self, conflict_attr: str, *args: object) -> None:
        """
        Initializes a HeaderMergeConflictError instance.

        Args:
            conflict_attr: The name of the attribute responsible for the merge conflict.
            *args: Additional positional arguments forwarded to ``Exception``.
        """
        # Forward conflict_attr as well so that ``self.args`` mirrors the constructor
        # signature. pickle and copy rebuild exceptions by calling ``cls(*self.args)``;
        # without it, reconstruction fails because the required argument is missing.
        super().__init__(conflict_attr, *args)
        self._conflict_attr = conflict_attr

    @property
    def conflict_attr(self) -> str:
        """
        Returns the name of the attribute responsible for the merge conflict.
        """
        return self._conflict_attr

    def __str__(self) -> str:
        """
        Returns a human-readable message naming the conflicting field.
        """
        return f"Conflicting header information in field: {self._conflict_attr}"


class MetadataHeader(Generic[T]):
    """
    Metadata header class wrapping mergeable header data.

    The wrapped data must be dumpable to a Python dictionary (``model_dump``) and
    re-creatable from one (``model_validate``), because merging operates on the
    dictionary representation.
    """

    def __init__(self, header: T) -> None:
        """
        Initializes a MetadataHeader instance.

        Args:
            header: Header data.
        """
        self._header = header

    @property
    def header(self) -> T:
        """
        Returns inner header data.
        """
        return self._header

    def __add__(self, other: MetadataHeader[T]) -> MetadataHeader[T]:
        """
        Merge two headers field by field into a new header.

        For every field of this header: equal values are kept, a ``None`` on either
        side is filled from the other side, and two differing non-``None`` values
        are treated as a conflict.

        Args:
            other: Header to merge into this one.

        Returns:
            A new MetadataHeader wrapping the merged data, validated with this header's model type.

        Raises:
            HeaderMergeConflictError: If both headers hold different non-``None`` values for a field.
        """
        # Follow the binary-operator protocol: let Python try the reflected operation
        # (and finally raise TypeError) instead of failing with an AttributeError below.
        if not isinstance(other, MetadataHeader):
            return NotImplemented

        merged_data = self.header.model_dump(mode="python")
        other_data = other.header.model_dump(mode="python")

        # Only values of existing keys are reassigned, so mutating during iteration is safe.
        for attr_name, own_value in merged_data.items():
            other_value = other_data.get(attr_name)

            # Nothing to take over: values agree or the other side has no value.
            if own_value == other_value or other_value is None:
                continue

            # Both sides carry a different, non-empty value.
            if own_value is not None:
                raise HeaderMergeConflictError(attr_name)

            merged_data[attr_name] = other_value

        return MetadataHeader(type(self.header).model_validate(merged_data))

    def merge(self, other: MetadataHeader[T] | None) -> MetadataHeader[T]:
        """
        Merge a metadata header with this header.

        Args:
            other: Other metadata header, or None if there is nothing to merge.

        Returns:
            Merged header of this and the other header. When ``other`` is None, a shallow
            copy of this header is returned (the wrapped header data object is shared).

        Raises:
            HeaderMergeConflictError: If both headers hold different non-``None`` values for a field.
        """
        if other is None:
            return copy(self)
        return self + other

    @staticmethod
    def _get_attributes(value: object) -> list[tuple[str, Any]]:
        """
        Returns the members of ``value`` as (name, member) pairs, skipping names that
        start with an underscore and members that are bound methods.
        """
        return [
            (name, member)
            for name, member in inspect.getmembers(value)
            if not name.startswith("_") and not inspect.ismethod(member)
        ]
Loading