Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/reformatters/common/ingest_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections.abc import Sequence
from typing import Protocol

import pandas as pd
import xarray as xr

from reformatters.common.logging import get_logger
from reformatters.common.types import Timedelta, Timestamp

log = get_logger(__name__)


# Structural type: any object with init_time/lead_time attributes satisfies
# this Protocol, so source-file coord classes don't need to inherit from it.
class HasTimeInfo(Protocol):
    # Forecast initialization time of the processed source file.
    init_time: Timestamp
    # Forecast lead time (offset from init_time) of the processed source file.
    lead_time: Timedelta
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# This Protocol tells the type checker: "Trust me, these objects have time info"
class HasTimeInfo(Protocol):
init_time: Timestamp
lead_time: Timedelta
class DeterministicForecastSourceFileCoord(Protocol):
init_time: Timestamp
lead_time: Timedelta



def update_ingested_forecast_length(
    template_ds: xr.Dataset,
    results_coords: Sequence[HasTimeInfo],
) -> xr.Dataset:
    """
    Set the ``ingested_forecast_length`` coordinate from processed results.

    For each ``init_time`` observed in ``results_coords``, the maximum
    processed ``lead_time`` is written to ``ingested_forecast_length``.
    Existing values are overwritten unconditionally: datasets are updated by
    rewriting everything in a shard, so the values processed in this run are
    authoritative. Ensuring updates only ever extend a dataset happens
    outside this function.

    Parameters
    ----------
    template_ds : xr.Dataset
        Dataset with an ``init_time`` dimension coordinate and an
        ``ingested_forecast_length`` coordinate along it.
    results_coords : Sequence[HasTimeInfo]
        Processed source-file coordinates carrying init_time and lead_time.

    Returns
    -------
    xr.Dataset
        ``template_ds``, modified in place and returned so callers can write
        ``ds = update_ingested_forecast_length(...)``.
    """
    # Template datasets are expected to always define this coordinate;
    # a missing coordinate is a programming error, not a runtime condition.
    assert "ingested_forecast_length" in template_ds.coords

    # Longest lead time processed per init time in this run.
    max_lead_per_init: dict[Timestamp, Timedelta] = {}
    for coord in results_coords:
        known = max_lead_per_init.get(coord.init_time)
        if known is None or coord.lead_time > known:
            max_lead_per_init[coord.init_time] = coord.lead_time

    for init_time, max_lead in max_lead_per_init.items():
        if init_time in template_ds.coords["init_time"]:
            log.info(
                f"Updating ingested_forecast_length for {init_time} to {max_lead}"
            )
            # Overwrite unconditionally -- see docstring; existing values
            # (including NaT) are intentionally ignored.
            template_ds["ingested_forecast_length"].loc[
                {"init_time": init_time}
            ] = max_lead

    return template_ds
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to look at existing values because of the way we update datasets by overwriting everything in a shard, so overwriting with whatever we processed this run is correct. In practice, we make sure we're only adding to a dataset, but that happens outside of here.

Suggested change
if init_time in template_ds.coords["init_time"]:
current_val = template_ds["ingested_forecast_length"].loc[
{"init_time": init_time}
]
# Use .values and pd.isnull to safely check for NaT (Not a Time)
if pd.isnull(current_val.values) or max_lead > current_val:
log.info(
f"Updating ingested_forecast_length for {init_time} to {max_lead}"
)
template_ds["ingested_forecast_length"].loc[
{"init_time": init_time}
] = max_lead
template_ds["ingested_forecast_length"].loc[
{"init_time": init_time}
] = max_lead

21 changes: 21 additions & 0 deletions src/reformatters/noaa/gfs/region_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from reformatters.common.download import (
http_download_to_disk,
)
from reformatters.common.ingest_stats import update_ingested_forecast_length
from reformatters.common.iterating import digest, group_by
from reformatters.common.logging import get_logger
from reformatters.common.region_job import (
Expand Down Expand Up @@ -150,6 +151,26 @@ def apply_data_transformations(
if isinstance(keep_mantissa_bits, int):
round_float32_inplace(data_array.values, keep_mantissa_bits)

def update_template_with_results(
    self,
    process_results: Mapping[str, Sequence[NoaaGfsSourceFileCoord]],
) -> xr.Dataset:
    """
    Extend the base template update by recording the ingested forecast
    length from the coords processed this run.

    process_results maps variable name -> processed source file coords.
    """
    ds = super().update_template_with_results(process_results)
    flattened = [
        coord for coords in process_results.values() for coord in coords
    ]
    update_ingested_forecast_length(ds, flattened)
    return ds
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

less code is easier to understand code!

Suggested change
# 1. Run the standard update logic from the parent class
# This returns the updated dataset
ds = super().update_template_with_results(process_results)
# 2. Extract the coordinates from the dictionary
# process_results is { "filename": [coord1, coord2], ... }
all_coords = []
for coord_list in process_results.values():
all_coords.extend(coord_list)
# 3. Run our new logic
update_ingested_forecast_length(ds, all_coords)
# 4. Return the modified dataset (Crucial!)
return ds
ds = super().update_template_with_results(process_results)
return update_ingested_forecast_length(ds, process_results)


@classmethod
def operational_update_jobs(
cls,
Expand Down
85 changes: 85 additions & 0 deletions tests/common/test_ingest_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from collections.abc import Mapping
from typing import Any, cast

Check failure on line 2 in tests/common/test_ingest_stats.py

View workflow job for this annotation

GitHub Actions / Code Quality (amd64)

Ruff (F401)

tests/common/test_ingest_stats.py:2:25: F401 `typing.cast` imported but unused

Check failure on line 2 in tests/common/test_ingest_stats.py

View workflow job for this annotation

GitHub Actions / Code Quality (amd64)

Ruff (F401)

tests/common/test_ingest_stats.py:2:20: F401 `typing.Any` imported but unused
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove, unused


import pandas as pd
import xarray as xr

from reformatters.common.ingest_stats import update_ingested_forecast_length
from reformatters.common.region_job import CoordinateValueOrRange, SourceFileCoord
from reformatters.common.types import Dim, Timedelta, Timestamp


# Minimal SourceFileCoord exposing the init_time/lead_time attributes that
# update_ingested_forecast_length reads.
class MockSourceFileCoord(SourceFileCoord):
    init_time: Timestamp
    lead_time: Timedelta

    def out_loc(self) -> Mapping[Dim, CoordinateValueOrRange]:
        # Not used by these tests; present only to satisfy the interface.
        return {}


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are missing a test that checks that the existing values in the array are not modified.

# ------------------


def test_update_ingested_forecast_length_simple() -> None:
    """Each init_time gets the lead_time of its processed coord."""
    init_times = [
        pd.Timestamp("2025-01-01 12:00"),
        pd.Timestamp("2025-01-01 18:00"),
    ]
    empty_deltas = pd.to_timedelta([pd.NaT, pd.NaT]).values

    ds = xr.Dataset(
        coords={
            "init_time": init_times,
            "ingested_forecast_length": (("init_time",), empty_deltas),
        }
    )

    coord1 = MockSourceFileCoord(
        init_time=pd.Timestamp("2025-01-01 12:00"),
        lead_time=pd.Timedelta(hours=6),
    )
    coord2 = MockSourceFileCoord(
        init_time=pd.Timestamp("2025-01-01 18:00"),
        lead_time=pd.Timedelta(hours=48),
    )

    update_ingested_forecast_length(ds, [coord1, coord2])

    assert ds["ingested_forecast_length"].sel(
        init_time="2025-01-01 12:00"
    ).values == pd.Timedelta(hours=6)
    assert ds["ingested_forecast_length"].sel(
        init_time="2025-01-01 18:00"
    ).values == pd.Timedelta(hours=48)


def test_update_ingested_forecast_length_update_existing() -> None:
    """A newly processed lead time replaces a previously recorded value."""
    init_time = pd.Timestamp("2025-01-01 12:00")

    # Start with 6 hours already recorded
    ds = xr.Dataset(
        coords={
            "init_time": [init_time],
            "ingested_forecast_length": (("init_time",), [pd.Timedelta(hours=6)]),
        }
    )

    update_ingested_forecast_length(
        ds,
        [MockSourceFileCoord(init_time=init_time, lead_time=pd.Timedelta(hours=12))],
    )

    assert ds["ingested_forecast_length"].sel(
        init_time=init_time
    ).values == pd.Timedelta(hours=12)
Loading