-
Notifications
You must be signed in to change notification settings - Fork 6
feat: Implement ingested_forecast_length utility and integrate with GFS (#412) #421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
6332ae1
6d5f762
e79f998
6f0270c
75f169c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,57 @@ | ||||||||||||||||||||||||||||||||||
| from collections.abc import Sequence | ||||||||||||||||||||||||||||||||||
| from typing import Protocol | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from reformatters.common.logging import get_logger | ||||||||||||||||||||||||||||||||||
| from reformatters.common.types import Timedelta, Timestamp | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| log = get_logger(__name__) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # This Protocol tells the type checker: "Trust me, these objects have time info" | ||||||||||||||||||||||||||||||||||
| class HasTimeInfo(Protocol): | ||||||||||||||||||||||||||||||||||
| init_time: Timestamp | ||||||||||||||||||||||||||||||||||
| lead_time: Timedelta | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
| # This Protocol tells the type checker: "Trust me, these objects have time info" | |
| class HasTimeInfo(Protocol): | |
| init_time: Timestamp | |
| lead_time: Timedelta | |
| class DeterministicForecastSourceFileCoord(Protocol): | |
| init_time: Timestamp | |
| lead_time: Timedelta |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's allow callers to pass in the process_results directly and handle taking the max across variable names (the str keys in this Mapping are variable names) within this function, rather than needing to make all callers do the same flattening into a Sequence[DeterministicForecastSourceFileCoord].
| results_coords: Sequence[HasTimeInfo], | |
| results_coords: Mapping[str, Sequence[DeterministicForecastSourceFileCoord]] , |
Then also add to the docstring the note that "The maximum processed lead time across all variables is set as the ingested_forecast_length. This can hide the nuance of a specific variable having fewer lead times processed than others."
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| ) -> None: | |
| ) -> xr.Dataset: |
Let's have this return the modified dataset, so callers would do ds = update_ingested_forecast_length(...)
ArkVex marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if "ingested_forecast_length" not in template_ds.coords: | |
| log.warning( | |
| "ingested_forecast_length coordinate not found in template dataset." | |
| ) | |
| return | |
| assert "ingested_forecast_length" in template_ds.coords |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't want to look at existing values because of the way we update datasets by overwriting everything in a shard, so overwriting with whatever we processed this run is correct. In practice, we make sure we're only adding to a dataset, but that happens outside of here.
| if init_time in template_ds.coords["init_time"]: | |
| current_val = template_ds["ingested_forecast_length"].loc[ | |
| {"init_time": init_time} | |
| ] | |
| # Use .values and pd.isnull to safely check for NaT (Not a Time) | |
| if pd.isnull(current_val.values) or max_lead > current_val: | |
| log.info( | |
| f"Updating ingested_forecast_length for {init_time} to {max_lead}" | |
| ) | |
| template_ds["ingested_forecast_length"].loc[ | |
| {"init_time": init_time} | |
| ] = max_lead | |
| template_ds["ingested_forecast_length"].loc[ | |
| {"init_time": init_time} | |
| ] = max_lead |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||||||||||||
| from reformatters.common.download import ( | ||||||||||||||||||||||||||||||||||||
| http_download_to_disk, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| from reformatters.common.ingest_stats import update_ingested_forecast_length | ||||||||||||||||||||||||||||||||||||
| from reformatters.common.iterating import digest, group_by | ||||||||||||||||||||||||||||||||||||
| from reformatters.common.logging import get_logger | ||||||||||||||||||||||||||||||||||||
| from reformatters.common.region_job import ( | ||||||||||||||||||||||||||||||||||||
|
|
@@ -150,6 +151,26 @@ def apply_data_transformations( | |||||||||||||||||||||||||||||||||||
| if isinstance(keep_mantissa_bits, int): | ||||||||||||||||||||||||||||||||||||
| round_float32_inplace(data_array.values, keep_mantissa_bits) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def update_template_with_results( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| process_results: Mapping[str, Sequence[NoaaGfsSourceFileCoord]], | ||||||||||||||||||||||||||||||||||||
| ) -> xr.Dataset: | ||||||||||||||||||||||||||||||||||||
| # 1. Run the standard update logic from the parent class | ||||||||||||||||||||||||||||||||||||
| # This returns the updated dataset | ||||||||||||||||||||||||||||||||||||
| ds = super().update_template_with_results(process_results) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # 2. Extract the coordinates from the dictionary | ||||||||||||||||||||||||||||||||||||
| # process_results is { "filename": [coord1, coord2], ... } | ||||||||||||||||||||||||||||||||||||
| all_coords = [] | ||||||||||||||||||||||||||||||||||||
| for coord_list in process_results.values(): | ||||||||||||||||||||||||||||||||||||
| all_coords.extend(coord_list) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # 3. Run our new logic | ||||||||||||||||||||||||||||||||||||
| update_ingested_forecast_length(ds, all_coords) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # 4. Return the modified dataset (Crucial!) | ||||||||||||||||||||||||||||||||||||
| return ds | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| # 1. Run the standard update logic from the parent class | |
| # This returns the updated dataset | |
| ds = super().update_template_with_results(process_results) | |
| # 2. Extract the coordinates from the dictionary | |
| # process_results is { "filename": [coord1, coord2], ... } | |
| all_coords = [] | |
| for coord_list in process_results.values(): | |
| all_coords.extend(coord_list) | |
| # 3. Run our new logic | |
| update_ingested_forecast_length(ds, all_coords) | |
| # 4. Return the modified dataset (Crucial!) | |
| return ds | |
| ds = super().update_template_with_results(process_results) | |
| return update_ingested_forecast_length(ds, process_results) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| from collections.abc import Mapping | ||
| from typing import Any, cast | ||
|
Check failure on line 2 in tests/common/test_ingest_stats.py
|
||
|
||
|
|
||
| import pandas as pd | ||
| import xarray as xr | ||
|
|
||
| from reformatters.common.ingest_stats import update_ingested_forecast_length | ||
| from reformatters.common.region_job import CoordinateValueOrRange, SourceFileCoord | ||
| from reformatters.common.types import Dim, Timedelta, Timestamp | ||
|
|
||
|
|
||
| # --- Mock Class --- | ||
| class MockSourceFileCoord(SourceFileCoord): | ||
| init_time: Timestamp | ||
| lead_time: Timedelta | ||
|
|
||
| def out_loc(self) -> Mapping[Dim, CoordinateValueOrRange]: | ||
| return {} | ||
|
|
||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We are missing a test that checks that the existing values in the array are not modified. |
||
| # ------------------ | ||
|
|
||
|
|
||
| def test_update_ingested_forecast_length_simple() -> None: | ||
| # 1. Setup a dummy dataset | ||
| init_times = [ | ||
| pd.Timestamp("2025-01-01 12:00"), | ||
| pd.Timestamp("2025-01-01 18:00"), | ||
| ] | ||
|
|
||
| # We use 'cast' to silence the strict type checker here | ||
|
||
| empty_deltas = pd.to_timedelta([pd.NaT, pd.NaT]).values | ||
|
|
||
| ds = xr.Dataset( | ||
| coords={ | ||
| "init_time": init_times, | ||
| "ingested_forecast_length": (("init_time",), empty_deltas), | ||
| } | ||
| ) | ||
|
|
||
| # 2. Setup the Results | ||
| coord1 = MockSourceFileCoord( | ||
| init_time=pd.Timestamp("2025-01-01 12:00"), | ||
| lead_time=pd.Timedelta(hours=6), | ||
| ) | ||
| coord2 = MockSourceFileCoord( | ||
| init_time=pd.Timestamp("2025-01-01 18:00"), | ||
| lead_time=pd.Timedelta(hours=48), | ||
| ) | ||
|
|
||
| results = [coord1, coord2] | ||
|
|
||
| # 3. Run the function | ||
| update_ingested_forecast_length(ds, results) | ||
|
|
||
| # 4. Check the answers | ||
| assert ds["ingested_forecast_length"].sel( | ||
| init_time="2025-01-01 12:00" | ||
| ).values == pd.Timedelta(hours=6) | ||
| assert ds["ingested_forecast_length"].sel( | ||
| init_time="2025-01-01 18:00" | ||
| ).values == pd.Timedelta(hours=48) | ||
|
|
||
|
|
||
| def test_update_ingested_forecast_length_update_existing() -> None: | ||
| init_time = pd.Timestamp("2025-01-01 12:00") | ||
|
|
||
| # Start with 6 hours already recorded | ||
|
||
| ds = xr.Dataset( | ||
| coords={ | ||
| "init_time": [init_time], | ||
| "ingested_forecast_length": (("init_time",), [pd.Timedelta(hours=6)]), | ||
| } | ||
| ) | ||
|
|
||
| new_coord = MockSourceFileCoord( | ||
| init_time=init_time, | ||
| lead_time=pd.Timedelta(hours=12), | ||
| ) | ||
|
|
||
| update_ingested_forecast_length(ds, [new_coord]) | ||
|
|
||
| assert ds["ingested_forecast_length"].sel( | ||
| init_time=init_time | ||
| ).values == pd.Timedelta(hours=12) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove