Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/reformatters/ecmwf/ecmwf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,16 @@ def all_variables_available(
or data_var.internal_attrs.date_available <= init_time
for data_var in data_var_group
)


def has_hour_0_values(data_var: EcmwfDataVar) -> bool:
    """Report whether `data_var` carries a value at lead_time=0h.

    An explicit `internal_attrs.hour_0_values_override` (anything other than None)
    is returned as-is. Otherwise the answer is derived from `attrs.step_type`:
    "max" and "min" variables represent the extremum since the previous
    post-processing step, which doesn't exist at initialization time, so they are
    absent at lead_time=0h. Every other step type counts as present; ECMWF
    avg/accum variables (e.g. total precipitation, radiation) include a 0h
    accumulation of 0 in the GRIB.
    """
    override = data_var.internal_attrs.hour_0_values_override
    if override is None:
        # No explicit override: only the extremum step types lack an hour 0 value.
        return data_var.attrs.step_type not in ("max", "min")
    return override
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from reformatters.common.download import (
http_download_to_disk,
)
from reformatters.common.iterating import digest
from reformatters.common.iterating import digest, item
from reformatters.common.logging import get_logger
from reformatters.common.region_job import (
CoordinateValueOrRange,
Expand All @@ -32,7 +32,7 @@
)
from reformatters.ecmwf.ecmwf_config_models import EcmwfDataVar
from reformatters.ecmwf.ecmwf_grib_index import get_message_byte_ranges_from_index
from reformatters.ecmwf.ecmwf_utils import all_variables_available
from reformatters.ecmwf.ecmwf_utils import all_variables_available, has_hour_0_values

log = get_logger(__name__)

Expand Down Expand Up @@ -105,12 +105,13 @@ def source_groups(
data_vars: Sequence[EcmwfDataVar],
) -> Sequence[Sequence[EcmwfDataVar]]:
"""Return groups of variables, where all variables in a group can be retrieved from the same source file."""
vars_by_date_available = defaultdict(list)
vars_by_key: defaultdict[tuple[object, bool], list[EcmwfDataVar]] = defaultdict(
list
)
for data_var in data_vars:
vars_by_date_available[data_var.internal_attrs.date_available].append(
data_var
)
return list(vars_by_date_available.values())
key = (data_var.internal_attrs.date_available, has_hour_0_values(data_var))
vars_by_key[key].append(data_var)
return list(vars_by_key.values())

def generate_source_file_coords(
self,
Expand All @@ -126,6 +127,7 @@ def generate_source_file_coords(
download/read performance by treating them separately and parallelizing.
"""
coords = []
group_has_hour_0_values = item({has_hour_0_values(v) for v in data_var_group})
for init_time, lead_time, ensemble_member in itertools.product(
processing_region_ds["init_time"].values,
processing_region_ds["lead_time"].values,
Expand All @@ -140,6 +142,9 @@ def generate_source_file_coords(
)
continue

if not group_has_hour_0_values and lead_time == np.timedelta64(0):
continue

coord = EcmwfIfsEnsForecast15Day025DegreeSourceFileCoord(
init_time=init_time,
lead_time=lead_time,
Expand Down
87 changes: 55 additions & 32 deletions tests/ecmwf/ifs_ens/forecast_15_day_0_25_degree/region_job_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from unittest.mock import Mock
Expand All @@ -10,6 +11,7 @@
from reformatters.common.iterating import item
from reformatters.common.pydantic import replace
from reformatters.common.storage import DatasetFormat, StorageConfig, StoreFactory
from reformatters.ecmwf.ecmwf_config_models import EcmwfDataVar
from reformatters.ecmwf.ifs_ens.forecast_15_day_0_25_degree.region_job import (
EcmwfIfsEnsForecast15Day025DegreeRegionJob,
EcmwfIfsEnsForecast15Day025DegreeSourceFileCoord,
Expand Down Expand Up @@ -48,16 +50,15 @@ def test_region_job_source_groups() -> None:
groups = EcmwfIfsEnsForecast15Day025DegreeRegionJob.source_groups(
template_config.data_vars
)
assert len(groups) == 3
# Main group: vars with no date_available (available since dataset start)
assert len(groups) == 4
# Main group: vars with no date_available (available since dataset start), all with hour 0 values
assert len(groups[0]) == 16
# wind_gust_10m is available from 2024-11-13 (same as categorical_precipitation_type_surface)
assert {v.name for v in groups[1]} == {
"categorical_precipitation_type_surface",
"wind_gust_10m",
}
# categorical_precipitation_type_surface is instant (has hour 0) and available from 2024-11-13
assert item(groups[1]).name == "categorical_precipitation_type_surface"
# wind_gust_10m is max-window (no hour 0) and available from 2024-11-13
assert item(groups[2]).name == "wind_gust_10m"
# total_cloud_cover_atmosphere is available from 2025-11-21
assert item(groups[2]).name == "total_cloud_cover_atmosphere"
assert item(groups[3]).name == "total_cloud_cover_atmosphere"


def test_region_job_generate_source_file_coords() -> None:
Expand All @@ -81,7 +82,7 @@ def test_region_job_generate_source_file_coords() -> None:
groups = EcmwfIfsEnsForecast15Day025DegreeRegionJob.source_groups(
template_config.data_vars
)
# We are grouping by date_available, so we should get 3 groups
# We are grouping by date_available and has_hour_0_values, so we should get 4 groups
group_0_source_file_coords = region_job.generate_source_file_coords(
processing_region_ds, groups[0]
)
Expand All @@ -92,13 +93,20 @@ def test_region_job_generate_source_file_coords() -> None:
group_1_source_file_coords = region_job.generate_source_file_coords(
processing_region_ds, groups[1]
)
# group 1 has two vars (categorical_precipitation_type_surface and wind_gust_10m),
# both available from 2024-11-13. Nov 12 is skipped, so 2 init times x 2 members x 12 lead times = 48
# group 1 has categorical_precipitation_type_surface (instant, has hour 0) available from
# 2024-11-13. Nov 12 is skipped, so 2 init times x 2 members x 12 lead times = 48.
assert len(group_1_source_file_coords) == 2 * 2 * 12
assert {v.name for v in group_1_source_file_coords[0].data_var_group} == {
"categorical_precipitation_type_surface",
"wind_gust_10m",
}
assert item(group_1_source_file_coords[0].data_var_group).name == (
"categorical_precipitation_type_surface"
)

group_2_source_file_coords = region_job.generate_source_file_coords(
processing_region_ds, groups[2]
)
# group 2 has wind_gust_10m (max-window, no hour 0) available from 2024-11-13.
# Nov 12 is skipped, and lead_time=0h is excluded, so 2 * 2 * 11 = 44.
assert len(group_2_source_file_coords) == 2 * 2 * 11
assert item(group_2_source_file_coords[0].data_var_group).name == "wind_gust_10m"


def test_region_job_download_file(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down Expand Up @@ -259,31 +267,46 @@ def test_operational_update_jobs(

@pytest.mark.slow
def test_download_file_from_ecmwf_open_data() -> None:
"""Download a recent ECMWF IFS ENS init time and read all template variables."""
"""Download a recent ECMWF IFS ENS init time and read all template variables at lead_times where they are present."""
template_config = EcmwfIfsEnsForecast15Day025DegreeTemplateConfig()
init_time = pd.Timestamp("2026-01-01T00:00")
# Use a recent date so the test catches format changes in the current ECMWF data
init_time = (pd.Timestamp.now() - pd.Timedelta(days=5)).floor("D")

full_template = template_config.get_template(init_time + pd.Timedelta(days=1))
region_job = EcmwfIfsEnsForecast15Day025DegreeRegionJob.model_construct(
tmp_store=Mock(),
template_ds=template_config.get_template(init_time),
template_ds=full_template,
data_vars=template_config.data_vars,
append_dim=template_config.append_dim,
region=slice(0, 1),
reformat_job_name="test",
)

# lead_time=3h: all instant and max-window vars are present
lead_time = pd.Timedelta(hours=3)
for group in EcmwfIfsEnsForecast15Day025DegreeRegionJob.source_groups(
template_config.data_vars
):
for data_var in group:
coord = EcmwfIfsEnsForecast15Day025DegreeSourceFileCoord(
init_time=init_time,
lead_time=lead_time,
data_var_group=[data_var],
ensemble_member=0,
# Test over lead_times [0h, 3h] to catch bugs where variables are missing from the
# index at certain lead times (e.g. 10fg/wind_gust is absent at lead_time=0h since
# it is a max-window variable with no prior post-processing step at t=0).
test_ds = full_template.isel(
init_time=slice(-1, None),
lead_time=slice(0, 2), # 0h and 3h
ensemble_member=slice(0, 1),
)

def check_data_var(data_var: EcmwfDataVar) -> None:
for source_coord in region_job.generate_source_file_coords(test_ds, [data_var]):
downloaded_coord = replace(
source_coord, downloaded_path=region_job.download_file(source_coord)
)
data = region_job.read_data(downloaded_coord, data_var)
assert np.all(np.isfinite(data)), (
f"Non-finite values for {data_var.name} at lead_time={source_coord.lead_time}"
)
coord = replace(coord, downloaded_path=region_job.download_file(coord))
data = region_job.read_data(coord, data_var)
assert np.all(np.isfinite(data)), f"Non-finite values for {data_var.name}"

all_data_vars = [
data_var
for group in EcmwfIfsEnsForecast15Day025DegreeRegionJob.source_groups(
template_config.data_vars
)
for data_var in group
]
with ThreadPoolExecutor() as executor:
list(executor.map(check_data_var, all_data_vars))
63 changes: 63 additions & 0 deletions tests/ecmwf/test_ecmwf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Literal

import pytest

from reformatters.common.config_models import DataVarAttrs, Encoding
from reformatters.common.zarr import BLOSC_4BYTE_ZSTD_LEVEL3_SHUFFLE
from reformatters.ecmwf.ecmwf_config_models import EcmwfDataVar, EcmwfInternalAttrs
from reformatters.ecmwf.ecmwf_utils import has_hour_0_values

StepType = Literal["instant", "accum", "avg", "min", "max"]


def _make_data_var(
    step_type: StepType,
    hour_0_values_override: bool | None = None,
) -> EcmwfDataVar:
    """Build a placeholder `EcmwfDataVar` for exercising `has_hour_0_values`.

    Only `step_type` and `hour_0_values_override` vary between calls; every other
    field is a fixed dummy value.
    """
    encoding = Encoding(
        dtype="float32",
        fill_value=float("nan"),
        chunks=(1, 85, 51, 32, 32),
        shards=None,
        compressors=[BLOSC_4BYTE_ZSTD_LEVEL3_SHUFFLE],
    )
    attrs = DataVarAttrs(
        short_name="test",
        long_name="Test variable",
        units="1",
        step_type=step_type,
    )
    internal_attrs = EcmwfInternalAttrs(
        grib_comment="test [unit]",
        grib_description='0[-] SFC="Ground or water surface"',
        grib_element="TEST",
        grib_index_param="test",
        keep_mantissa_bits=7,
        hour_0_values_override=hour_0_values_override,
    )
    return EcmwfDataVar(
        name="test_var",
        encoding=encoding,
        attrs=attrs,
        internal_attrs=internal_attrs,
    )


@pytest.mark.parametrize("step_type", ["instant", "avg", "accum"])
def test_has_hour_0_values_true_for_non_extremum_step_types(
    step_type: StepType,
) -> None:
    """Instant, avg and accum variables report a value at lead_time=0h."""
    data_var = _make_data_var(step_type)
    result = has_hour_0_values(data_var)
    assert result is True


@pytest.mark.parametrize("step_type", ["max", "min"])
def test_has_hour_0_values_false_for_extremum_step_types(step_type: StepType) -> None:
    """Max and min variables report no value at lead_time=0h."""
    data_var = _make_data_var(step_type)
    result = has_hour_0_values(data_var)
    assert result is False


def test_has_hour_0_values_override_true_overrides_step_type() -> None:
    """An override of True wins even for a "max" variable."""
    data_var = _make_data_var("max", hour_0_values_override=True)
    result = has_hour_0_values(data_var)
    assert result is True


def test_has_hour_0_values_override_false_overrides_step_type() -> None:
    """An override of False wins even for an "instant" variable."""
    data_var = _make_data_var("instant", hour_0_values_override=False)
    result = has_hour_0_values(data_var)
    assert result is False