Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
e2fe6e8
fix(STEF-2854): handle InsufficientlyCompleteError during backtest tr…
egordm Mar 13, 2026
9f830e9
test(STEF-2854): replace mock with real NaN data in insufficient-data…
egordm Mar 13, 2026
c3a9fda
fix(STEF-2854): return None from predict() when no model fitted
egordm Mar 13, 2026
1ab0eef
fix(STEF-2854): make WindowedMetricVisualization robust to missing data
egordm Mar 13, 2026
dc170b9
Merge branch 'release/v4.0.0' into fix/STEF-2854-handle-insufficient-…
egordm Mar 13, 2026
ce6274e
fix(STEF-2854): fix combiner label/weight shape mismatch and Quantile…
egordm Mar 18, 2026
be93fb0
fix(STEF-2854): add Pydantic serializer to Quantile to suppress warnings
egordm Mar 18, 2026
a62459a
fix(STEF-2854): raise InsufficientlyCompleteError for empty datasets …
egordm Mar 18, 2026
625abfe
fix(STEF-2854): raise InsufficientlyCompleteError on empty combiner d…
egordm Mar 18, 2026
4d00735
feat(STEF-2854): add strict parameter to BenchmarkComparisonPipeline.…
egordm Mar 18, 2026
7f1a06d
fix(STEF-2854): use 'global' subdirectory for RUN_AND_GROUP analysis …
egordm Mar 18, 2026
d2f64f4
fix(STEF-2854): renormalize ensemble weights when base model predicti…
egordm Mar 18, 2026
368cb85
refactor(STEF-2854): extract nan_aware_weighted_mean helper
egordm Mar 18, 2026
aacbdde
feat(STEF-2854): add skip_analysis param to BenchmarkPipeline.run()
egordm Mar 18, 2026
9777e0e
feat(STEF-2854): add filterings override to AnalysisConfig
egordm Mar 19, 2026
96e7015
fix(STEF-2854): resolve ruff lint warnings
egordm Mar 19, 2026
6618b4d
fix(STEF-2854): resolve pyright type errors in modified files
egordm Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ class AnalysisConfig(BaseConfig):
visualization_providers: list[VisualizationProvider] = Field(
default=[], description="List of visualization providers to use for generating analysis outputs"
)
filterings: list[Filtering] | None = Field(
default=None,
description="When set, only include these filterings (e.g. LeadTime, AvailableAt) in analysis. "
"None means use all filterings found in the evaluation data.",
)


class AnalysisPipeline:
Expand All @@ -61,8 +66,8 @@ def __init__(
super().__init__()
self.config = config

@staticmethod
def _group_by_filtering(
self,
reports: Sequence[tuple[TargetMetadata, EvaluationReport]],
) -> dict[Filtering, list[ReportTuple]]:
"""Group reports by their lead time filtering conditions.
Expand All @@ -71,13 +76,17 @@ def _group_by_filtering(
1-hour ahead vs 24-hour ahead forecasts), enabling comparison of model
performance across different forecasting horizons.

When ``config.filterings`` is set, only subsets matching those filterings are included.

Returns:
Dictionary mapping lead time filtering conditions to lists of report tuples.
"""
allowed = set(self.config.filterings) if self.config.filterings is not None else None
return groupby(
(subset.filtering, (base_metadata.with_filtering(subset.filtering), subset))
for base_metadata, report in reports
for subset in report.subset_reports
if allowed is None or subset.filtering in allowed
)

def run_for_subsets(
Expand All @@ -103,10 +112,7 @@ def run_for_subsets(
no providers support the requested aggregation level.
"""
return [
provider.create(
reports=reports,
aggregation=aggregation,
)
provider.create(reports=reports, aggregation=aggregation)
for provider in self.config.visualization_providers
if aggregation in provider.supported_aggregations
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
how performance metrics evolve across different time windows.
"""

import logging
import operator
from collections import defaultdict
from datetime import datetime
Expand All @@ -23,6 +24,8 @@
from openstef_beam.evaluation import EvaluationSubsetReport, Window
from openstef_core.types import Quantile

_logger = logging.getLogger(__name__)


class WindowedMetricVisualization(VisualizationProvider):
"""Creates time series plots showing metric evolution across evaluation windows.
Expand Down Expand Up @@ -180,7 +183,8 @@
time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)

if not time_value_pairs:
raise ValueError("No windowed metrics found for the specified window and metric.")
_logger.warning("No windowed metrics for %s (%s) — skipping visualization.", metadata.name, self.name)
return self._empty_output(f"No windowed metrics available for {metadata.name}")

# Unpack the sorted pairs
timestamps = [pair[0] for pair in time_value_pairs]
Expand All @@ -198,19 +202,23 @@

return VisualizationOutput(name=self.name, figure=figure)

def _empty_output(self, message: str) -> VisualizationOutput:
return VisualizationOutput(name=self.name, html=f"<p>{message}</p>")

@override
def create_by_run_and_none(self, reports: dict[RunName, list[ReportTuple]]) -> VisualizationOutput:
metric_name, quantile_or_global = self._get_metric_info()
plotter = WindowedMetricPlotter()
has_data = False

# Collect data for each run
for run_name, report_pairs in reports.items():
for _metadata, report in report_pairs:
time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)

# Skip if no data points found for this run
if not time_value_pairs:
raise ValueError("No windowed metrics found for the specified window, metric and run.")
_logger.warning("No windowed metrics for run '%s' (%s) — skipping.", run_name, self.name)
continue

# Unpack the sorted pairs
timestamps = [pair[0] for pair in time_value_pairs]
Expand All @@ -221,6 +229,10 @@
timestamps=timestamps,
metric_values=metric_values,
)
has_data = True

if not has_data:
return self._empty_output("No windowed metrics available for any run")

Check failure on line 235 in packages/openstef-beam/src/openstef_beam/analysis/visualizations/windowed_metric_visualization.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "No windowed metrics available for any run" 3 times.

See more on https://sonarcloud.io/project/issues?id=OpenSTEF_openstef&issues=AZzntK27HeSBMzjUfgmh&open=AZzntK27HeSBMzjUfgmh&pullRequest=837

title = self._create_plot_title(metric_name, quantile_or_global, "by Run")
figure = plotter.plot(title=title)
Expand All @@ -238,13 +250,14 @@
# Get the run name from the first target metadata for the title
run_name = reports[0][0].run_name if reports else ""

has_data = False
# Process each target's report
for metadata, report in reports:
time_value_pairs = self._extract_windowed_metric_values(report, metric_name, quantile_or_global)

# Skip if no data points found for this target
if not time_value_pairs:
raise ValueError("No windowed metrics found for the specified window, metric and target.")
_logger.warning("No windowed metrics for target '%s' (%s) — skipping.", metadata.name, self.name)
continue

# Unpack the sorted pairs
timestamps = [pair[0] for pair in time_value_pairs]
Expand All @@ -256,6 +269,10 @@
timestamps=timestamps,
metric_values=metric_values,
)
has_data = True

if not has_data:
return self._empty_output("No windowed metrics available for any target")

title_suffix = "by Target"
if run_name:
Expand All @@ -274,11 +291,13 @@
) -> VisualizationOutput:
metric_name, quantile_or_global = self._get_metric_info()
plotter = WindowedMetricPlotter()
has_data = False

# Process each run and calculate averaged metrics across its targets
for run_name, target_reports in reports.items():
if not target_reports:
raise ValueError("No windowed metrics found for the specified window, metric and run.")
_logger.warning("No reports for run '%s' (%s) — skipping.", run_name, self.name)
continue

# Average windowed metrics across all targets for this run
averaged_pairs = self._average_time_series_across_targets(
Expand All @@ -287,9 +306,9 @@
quantile_or_global=quantile_or_global,
)

# Skip if no averaged data points found for this run
if not averaged_pairs:
raise ValueError("No windowed averaged metrics found for the specified window, metric and run.")
_logger.warning("No windowed averaged metrics for run '%s' (%s) — skipping.", run_name, self.name)
continue

# Unpack the averaged pairs
timestamps = [pair[0] for pair in averaged_pairs]
Expand All @@ -301,6 +320,10 @@
timestamps=timestamps,
metric_values=metric_values,
)
has_data = True

if not has_data:
return self._empty_output("No windowed metrics available for any run")

title = self._create_plot_title(metric_name, quantile_or_global, "by run (averaged over targets in group)")
figure = plotter.plot(title=title, metric_name=metric_name)
Expand All @@ -321,10 +344,12 @@
for (run_name, _group_name), target_reports in reports.items():
run_to_targets.setdefault(run_name, []).extend(target_reports)

has_data = False
# Average metrics over all targets for each run
for run_name, all_target_reports in run_to_targets.items():
if not all_target_reports:
raise ValueError("No windowed metrics found for the specified window, metric and run.")
_logger.warning("No reports for run '%s' (%s) — skipping.", run_name, self.name)
continue

# Average windowed metrics across all targets for this run
averaged_pairs = self._average_time_series_across_targets(
Expand All @@ -334,7 +359,8 @@
)

if not averaged_pairs:
raise ValueError("No windowed averaged metrics found for the specified window, metric and run.")
_logger.warning("No windowed averaged metrics for run '%s' (%s) — skipping.", run_name, self.name)
continue

timestamps = [pair[0] for pair in averaged_pairs]
metric_values = [pair[1] for pair in averaged_pairs]
Expand All @@ -345,6 +371,10 @@
timestamps=timestamps,
metric_values=metric_values,
)
has_data = True

if not has_data:
return self._empty_output("No windowed metrics available for any run")

title = self._create_plot_title(metric_name, quantile_or_global, "by run (averaged over all targets)")
figure = plotter.plot(title=title, metric_name=metric_name)
Expand Down Expand Up @@ -373,9 +403,7 @@
)

if not averaged_pairs:
raise ValueError(
"No windowed averaged metrics found for the specified window, metric and run across all groups."
)
return self._empty_output("No windowed metrics available across all groups")

timestamps = [pair[0] for pair in averaged_pairs]
metric_values = [pair[1] for pair in averaged_pairs]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
)
from openstef_core.base_model import BaseModel
from openstef_core.datasets import TimeSeriesDataset
from openstef_core.exceptions import FlatlinerDetectedError, MissingExtraError, NotFittedError
from openstef_core.exceptions import (
FlatlinerDetectedError,
InsufficientlyCompleteError,
MissingExtraError,
)
from openstef_core.types import Q
from openstef_models.presets import ForecastingWorkflowConfig, create_forecasting_workflow
from openstef_models.presets.forecasting_workflow import LocationConfig
Expand Down Expand Up @@ -118,6 +122,9 @@ def fit(self, data: RestrictedHorizonVersionedTimeSeries) -> None:
self._logger.warning("Flatliner detected during training")
self._is_flatliner_detected = True
return # Skip setting the workflow on flatliner detection
except InsufficientlyCompleteError:
self._logger.warning("Insufficient training data at %s, retaining previous model", data.horizon)
return # Retain previous model state; predictions will use the last successful fit

self._workflow = workflow

Expand All @@ -128,7 +135,8 @@ def predict(self, data: RestrictedHorizonVersionedTimeSeries) -> TimeSeriesDatas
return None

if self._workflow is None:
raise NotFittedError("Must call fit() before predict()")
self._logger.info("No fitted model available, skipping prediction")
return None

# Extract the dataset including both historical context and forecast period
predict_data = data.get_window(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def run(
self,
run_data: dict[RunName, BenchmarkStorage],
filter_args: F | None = None,
*,
strict: bool = True,
):
"""Execute comparison analysis across multiple benchmark runs.

Expand All @@ -132,6 +134,8 @@ def run(
Each storage backend should contain evaluation results for the run.
filter_args: Optional criteria for filtering targets. Only targets
matching these criteria will be included in the comparison.
strict: If True, raise an error when evaluation is missing for a target.
If False, skip missing targets.
"""
targets = self.target_provider.get_targets(filter_args)

Expand All @@ -142,7 +146,7 @@ def run(
targets=targets,
storage=run_storage,
run_name=run_name,
strict=True,
strict=strict,
)
reports.extend(run_reports)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def run(
run_name: str = "default",
filter_args: F | None = None,
n_processes: int | None = None,
*,
skip_analysis: bool = False,
) -> None:
"""Runs the benchmark for all targets, optionally filtered and in parallel.

Expand All @@ -174,6 +176,8 @@ def run(
matching these criteria will be processed.
n_processes: Number of processes to use for parallel execution. If None or 1,
targets are processed sequentially.
skip_analysis: When True, skips per-target and global analysis steps.
Useful when analysis will be run separately later.
"""
context = BenchmarkContext(run_name=run_name)

Expand All @@ -184,13 +188,13 @@ def run(

_logger.info("Running benchmark in parallel with %d processes", n_processes)
run_parallel(
process_fn=partial(self._run_for_target, context, forecaster_factory),
process_fn=partial(self._run_for_target, context, forecaster_factory, skip_analysis=skip_analysis),
items=targets,
n_processes=n_processes,
mode="loky",
)

if not self.storage.has_analysis_output(
if not skip_analysis and not self.storage.has_analysis_output(
AnalysisScope(
aggregation=AnalysisAggregation.GROUP,
run_name=context.run_name,
Expand All @@ -203,7 +207,14 @@ def run(

self.callback_manager.on_benchmark_complete(runner=self, targets=cast(list[BenchmarkTarget], targets))

def _run_for_target(self, context: BenchmarkContext, model_factory: ForecasterFactory[T], target: T) -> None:
def _run_for_target(
self,
context: BenchmarkContext,
model_factory: ForecasterFactory[T],
target: T,
*,
skip_analysis: bool = False,
) -> None:
"""Run benchmark for a single target."""
if not self.callback_manager.on_target_start(runner=self, target=target):
_logger.info("Skipping target")
Expand All @@ -221,7 +232,7 @@ def _run_for_target(self, context: BenchmarkContext, model_factory: ForecasterFa
predictions = self.storage.load_backtest_output(target)
self.run_evaluation_for_target(target=target, predictions=predictions, quantiles=forecaster.quantiles)

if not self.storage.has_analysis_output(
if not skip_analysis and not self.storage.has_analysis_output(
scope=AnalysisScope(
aggregation=AnalysisAggregation.TARGET,
target_name=target.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def get_analysis_path(self, scope: AnalysisScope) -> Path:
elif scope.aggregation == AnalysisAggregation.RUN_AND_NONE:
output_dir = base_dir / str(scope.group_name) / str(scope.target_name)
elif scope.aggregation == AnalysisAggregation.RUN_AND_GROUP:
output_dir = base_dir
output_dir = base_dir / "global"
elif scope.aggregation == AnalysisAggregation.RUN_AND_TARGET:
output_dir = base_dir / str(scope.group_name) / "global"
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,17 @@ def test_create_by_none_creates_time_series_visualization(
assert result.figure == mock_plotly_figure


def test_create_by_none_raises_error_when_no_windowed_data(
def test_create_by_none_returns_empty_output_when_no_windowed_data(
empty_evaluation_report: EvaluationSubsetReport, simple_target_metadata: TargetMetadata, sample_window: Window
):
"""Test that create_by_none raises appropriate error when no windowed metrics are found."""
"""Test that create_by_none returns an HTML placeholder when no windowed metrics are found."""
viz = WindowedMetricVisualization(name="test_viz", metric="mae", window=sample_window)

with pytest.raises(ValueError, match="No windowed metrics found for the specified window and metric"):
viz.create_by_none(empty_evaluation_report, simple_target_metadata)
result = viz.create_by_none(empty_evaluation_report, simple_target_metadata)

assert result.figure is None
assert result.html is not None
assert "No windowed metrics" in result.html


def test_create_by_target_adds_time_series_per_target(
Expand Down
Loading
Loading