diff --git a/packages/openstef-core/src/openstef_core/types.py b/packages/openstef-core/src/openstef_core/types.py index 1d536f9cb..2bd452bdd 100644 --- a/packages/openstef-core/src/openstef_core/types.py +++ b/packages/openstef-core/src/openstef_core/types.py @@ -338,7 +338,11 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa Returns: Core schema for Pydantic validation. """ - return core_schema.no_info_after_validator_function(cls, handler(float)) + return core_schema.no_info_after_validator_function( + cls, + handler(float), + serialization=core_schema.plain_serializer_function_ser_schema(float), + ) def format(self) -> str: """Instance method to format the quantile as a string. diff --git a/packages/openstef-core/tests/unit/test_types.py b/packages/openstef-core/tests/unit/test_types.py index 189c02733..1eb166929 100644 --- a/packages/openstef-core/tests/unit/test_types.py +++ b/packages/openstef-core/tests/unit/test_types.py @@ -2,13 +2,15 @@ # # SPDX-License-Identifier: MPL-2.0 +import warnings from datetime import UTC, datetime, time, timedelta, timezone import pandas as pd import pytest import pytz +from pydantic import BaseModel -from openstef_core.types import AvailableAt, LeadTime +from openstef_core.types import AvailableAt, LeadTime, Quantile, QuantileOrGlobal @pytest.mark.parametrize( @@ -342,3 +344,20 @@ def test_available_at_apply_index_matches_apply(): scalar = pd.DatetimeIndex([at.apply(ts.to_pydatetime()) for ts in index]) pd.testing.assert_index_equal(vectorized, scalar) + + +def test_quantile_serialization_no_warnings(): + """Quantile used as dict key in a QuantileOrGlobal union must not trigger pydantic serialization warnings.""" + + class Model(BaseModel): + metrics: dict[QuantileOrGlobal, dict[str, float]] + + m = Model(metrics={Quantile(0.05): {"mae": 1.0}, Quantile(0.5): {"mae": 2.0}, "global": {"mae": 1.5}}) + + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=UserWarning, message="Pydantic serializer") + data = m.model_dump_json() + + assert '"0.05"' in data + assert '"0.5"' in data + assert '"global"' in data