Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/openstef-core/src/openstef_core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,11 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa
Returns:
Core schema for Pydantic validation.
"""
return core_schema.no_info_after_validator_function(cls, handler(float))
return core_schema.no_info_after_validator_function(
cls,
handler(float),
serialization=core_schema.plain_serializer_function_ser_schema(float),
)

def format(self) -> str:
"""Instance method to format the quantile as a string.
Expand Down
21 changes: 20 additions & 1 deletion packages/openstef-core/tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#
# SPDX-License-Identifier: MPL-2.0

import warnings
from datetime import UTC, datetime, time, timedelta, timezone

import pandas as pd
import pytest
import pytz
from pydantic import BaseModel

from openstef_core.types import AvailableAt, LeadTime
from openstef_core.types import AvailableAt, LeadTime, Quantile, QuantileOrGlobal


@pytest.mark.parametrize(
Expand Down Expand Up @@ -342,3 +344,20 @@ def test_available_at_apply_index_matches_apply():
scalar = pd.DatetimeIndex([at.apply(ts.to_pydatetime()) for ts in index])

pd.testing.assert_index_equal(vectorized, scalar)


def test_quantile_serialization_no_warnings():
"""Quantile used as dict key in a QuantileOrGlobal union must not trigger pydantic serialization warnings."""

class Model(BaseModel):
metrics: dict[QuantileOrGlobal, dict[str, float]]

m = Model(metrics={Quantile(0.05): {"mae": 1.0}, Quantile(0.5): {"mae": 2.0}, "global": {"mae": 1.5}})

with warnings.catch_warnings():
warnings.filterwarnings("error", category=UserWarning, message="Pydantic serializer")
data = m.model_dump_json()

assert '"0.05"' in data
assert '"0.5"' in data
assert '"global"' in data
Loading