Skip to content

Commit

Permalink
Metadata object (#52)
Browse files Browse the repository at this point in the history
## Purpose

Store field metadata in a standalone metadata object instead of the byte
string of a sample GRIB message. This change is required because of a
breaking change in the earthkit-data `write` function, which was
previously used to regenerate the `message` attribute.

## Code changes:

- `DataArray` objects representing a field no longer have the `message`
attribute but now have a `metadata` attribute that is an instance of
`StandAloneGribMetadata` from earthkit-data.

## Requirements changes:

- earthkit-data version range extended to `>=0.5.6,<1`
  • Loading branch information
cfkanesan authored Nov 29, 2024
1 parent 239df34 commit 51ad97b
Show file tree
Hide file tree
Showing 24 changed files with 1,691 additions and 1,098 deletions.
2,538 changes: 1,614 additions & 924 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ authors = [
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
click = "^8.1.7"
earthkit-data = "^0.5.6"
earthkit-data = ">=0.5.6,<1"
eccodes = "^1.5.0"
numpy = "^1.26.4"
polytope-client = "^0.7.4"
Expand All @@ -37,6 +37,7 @@ pyproj = "^3.6.1"
pyyaml = "^6.0.1"
rasterio = "^1.3.10"
scipy = "^1.13"
setuptools = "*"
xarray = ">=2024"

[tool.poetry.group.dev.dependencies]
Expand Down
16 changes: 8 additions & 8 deletions src/meteodatalab/grib_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ def load(self, field: GribField) -> None:
self.values[key] = field.to_numpy(dtype=np.float32)

if not self.metadata:
md = field.metadata().override()
self.metadata = {
"message": field.message(), # try field.metadata.override()
**metadata.extract(field.metadata()),
"metadata": md,
**metadata.extract(md),
}

if not self.hcoords:
Expand Down Expand Up @@ -359,7 +360,7 @@ def load(
if extract_pv not in requests:
msg = f"{extract_pv=} was not a key of the given requests."
raise RuntimeError(msg)
return result | metadata.extract_pv(result[extract_pv].message)
return result | metadata.extract_pv(result[extract_pv].metadata)
return result

def load_fieldnames(
Expand Down Expand Up @@ -390,15 +391,14 @@ def save(
Raises
------
ValueError
If the field does not have a message attribute.
If the field does not have a metadata attribute.
"""
if not hasattr(field, "message"):
msg = "The message attribute is required to write to the GRIB format."
if not hasattr(field, "metadata"):
msg = "The metadata attribute is required to write to the GRIB format."
raise ValueError(msg)

stream = io.BytesIO(field.message)
[md] = (f.metadata() for f in ekd.from_source("stream", stream))
md = field.metadata

idx = {
dim: field.coords[key]
Expand Down
115 changes: 30 additions & 85 deletions src/meteodatalab/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

# Standard library
import dataclasses as dc
import io
import typing

# Third-party
import earthkit.data as ekd # type: ignore
import numpy as np
import xarray as xr
from earthkit.data.writers import write # type: ignore
from earthkit.data.core.metadata import Metadata # type: ignore

# Local
from . import grib_decoder
Expand All @@ -21,7 +19,7 @@
}


def extract(metadata):
def extract(metadata: Metadata) -> dict[str, typing.Any]:
if metadata.get("gridType") == "unstructured_grid":
vref_flag = False
else:
Expand All @@ -41,41 +39,36 @@ def extract(metadata):
}


def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata contained in message.
def override(metadata: Metadata, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata.
Note that no special consideration is made for maintaining consistency when
overriding template definition keys such as productDefinitionTemplateNumber.
Note that the origin components in x and y are left untouched.
Parameters
----------
message : bytes
Byte string of the input GRIB message
metadata : Metadata
Metadata of the input GRIB message
kwargs : Any
Keyword arguments forwarded to earthkit-data GribMetadata override method
Returns
-------
dict[str, Any]
Updated message byte string along with the geography and parameter namespaces
Updated metadata along with the geography and parameter namespaces
"""
stream = io.BytesIO(message)
[grib_field] = ekd.from_source("stream", stream)

if grib_field.metadata("editionNumber") == 1:
if metadata["editionNumber"] == 1:
return {
"message": message,
**extract(grib_field.metadata()),
"metadata": metadata,
**extract(metadata),
}

out = io.BytesIO()
md = grib_field.metadata().override(**kwargs)
write(out, grib_field.values, md)
md = metadata.override(**kwargs)

return {
"message": out.getvalue(),
"metadata": md,
**extract(md),
}

Expand All @@ -97,28 +90,23 @@ class Grid:
lat_first_grid_point: float


def load_grid_reference(message: bytes) -> Grid:
def load_grid_reference(metadata: Metadata) -> Grid:
"""Construct a grid from a reference parameter.
Parameters
----------
message : bytes
GRIB message defining the reference grid.
metadata : Metadata
GRIB metadata defining the reference grid.
Returns
-------
Grid
reference grid
"""
stream = io.BytesIO(message)
[grib_field] = ekd.from_source("stream", stream)

return Grid(
*grib_field.metadata(
"longitudeOfFirstGridPointInDegrees",
"latitudeOfFirstGridPointInDegrees",
),
metadata["longitudeOfFirstGridPointInDegrees"],
metadata["latitudeOfFirstGridPointInDegrees"],
)


Expand Down Expand Up @@ -171,29 +159,26 @@ def set_origin_xy(ds: dict[str, xr.DataArray], ref_param: str) -> None:
if ref_param not in ds:
raise KeyError(f"ref_param {ref_param} not present in dataset.")

ref_grid = load_grid_reference(ds[ref_param].message)
ref_grid = load_grid_reference(ds[ref_param].metadata)
for field in ds.values():
field.attrs |= compute_origin(ref_grid, field)


def extract_pv(message: bytes) -> dict[str, xr.DataArray]:
def extract_pv(metadata: Metadata) -> dict[str, xr.DataArray]:
"""Extract hybrid level coefficients.
Parameters
----------
message : bytes
GRIB message containing the pv metadata.
metadata : Metadata
GRIB metadata containing the `pv` key.
Returns
-------
dict[str, xarray.DataArray]
Hybrid level coefficients.
"""
stream = io.BytesIO(message)
[grib_field] = ekd.from_source("stream", stream)

pv = grib_field.metadata("pv")
pv = metadata.get("pv")

if pv is None:
return {}
Expand All @@ -205,64 +190,24 @@ def extract_pv(message: bytes) -> dict[str, xr.DataArray]:
}


def extract_hcoords(message: bytes) -> dict[str, xr.DataArray]:
def extract_hcoords(metadata: Metadata) -> dict[str, xr.DataArray]:
"""Extract horizontal coordinates.
Parameters
----------
message : bytes
GRIB message containing the grid definition.
metadata : Metadata
GRIB metadata containing the grid definition.
Returns
-------
dict[str, xarray.DataArray]
Horizontal coordinates in geolatlon.
"""
stream = io.BytesIO(message)
[grib_field] = ekd.from_source("stream", stream)

geo = metadata.geography
return {
dim: xr.DataArray(dims=("y", "x"), data=values)
for dim, values in grib_field.to_latlon().items()
"lat": xr.DataArray(dims=("y", "x"), data=geo.latitudes().reshape(geo.shape())),
"lon": xr.DataArray(
dims=("y", "x"), data=geo.longitudes().reshape(geo.shape())
),
}


def extract_keys(message: bytes, keys: typing.Any, single: bool = True) -> typing.Any:
"""Extract keys from the GRIB message.
Parameters
----------
message : bytes
The GRIB message.
keys : Any
Keys for which to extract values from the message.
single : bool, optional
Whether a single GRIB message should be expected.
Raises
------
ValueError
if keys is None because the resulting metadata would point
to an eccodes handle that no longer exists resulting in a
possible segmentation fault
Returns
-------
Any
Single value if keys is a single value, tuple of values if
keys is a tuple, list of values if keys is a list. The type of
the value depends on the default type for the given key in eccodes.
If single is false, the above is returned within a list.
"""
if keys is None:
raise ValueError("keys must be specified.")
stream = io.BytesIO(message)
source = ekd.from_source("stream", stream)

if single:
[grib_field] = source
return grib_field.metadata(keys)

return [grib_field.metadata(keys) for grib_field in source]
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/brn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def fbrn(

return xr.DataArray(
data=brn,
attrs=metadata.override(p.message, shortName="BRN"),
attrs=metadata.override(p.metadata, shortName="BRN"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def crop(field: xr.DataArray, bounds: Bounds) -> xr.DataArray:
return xr.DataArray(
field.isel(x=slice(xmin, xmax + 1), y=slice(ymin, ymax + 1)),
attrs=metadata.override(
field.message,
field.metadata,
longitudeOfFirstGridPoint=lon_min,
longitudeOfLastGridPoint=lon_max,
Ni=ni,
Expand Down
8 changes: 4 additions & 4 deletions src/meteodatalab/operators/destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
lon_max = np.round(geo["longitudeOfLastGridPointInDegrees"] * 1e6)
dx = np.round(geo["iDirectionIncrementInDegrees"] * 1e6)
return metadata.override(
field.message,
field.metadata,
longitudeOfFirstGridPoint=lon_min - dx / 2,
longitudeOfLastGridPoint=lon_max - dx / 2,
)
Expand All @@ -90,7 +90,7 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
lat_max = np.round(geo["latitudeOfLastGridPointInDegrees"] * 1e6)
dy = np.round(geo["jDirectionIncrementInDegrees"] * 1e6)
return metadata.override(
field.message,
field.metadata,
latitudeOfFirstGridPoint=lat_min - dy / 2,
latitudeOfLastGridPoint=lat_max - dy / 2,
)
Expand All @@ -100,7 +100,7 @@ def _update_vertical(field) -> dict[str, Any]:
if field.vcoord_type != "model_level":
raise ValueError("typeOfLevel must equal generalVertical")
return metadata.override(
field.message,
field.metadata,
typeOfLevel="generalVerticalLayer",
)

Expand Down Expand Up @@ -151,7 +151,7 @@ def destagger(
)
.transpose(*dims)
.assign_attrs({f"origin_{dim}": 0.0}, **attrs)
.assign_coords(metadata.extract_hcoords(attrs["message"]))
.assign_coords(metadata.extract_hcoords(attrs["metadata"]))
)
elif dim == "z":
if field.origin_z != -0.5:
Expand Down
4 changes: 2 additions & 2 deletions src/meteodatalab/operators/gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,13 @@ def vref_rot2geolatlon(
xr.DataArray(
u_g,
attrs=metadata.override(
u.message, resolutionAndComponentFlags=resolution_components_flags
u.metadata, resolutionAndComponentFlags=resolution_components_flags
),
),
xr.DataArray(
v_g,
attrs=metadata.override(
v.message, resolutionAndComponentFlags=resolution_components_flags
v.metadata, resolutionAndComponentFlags=resolution_components_flags
),
),
)
Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/hzerocl.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ def fhzerocl(

return xr.DataArray(
data=hzerocl.where(hzerocl > 0),
attrs=metadata.override(t.message, shortName="HZEROCL"),
attrs=metadata.override(t.metadata, shortName="HZEROCL"),
)
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/omega_slope.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def omega_slope(
data=res,
attrs=metadata.override(
# Eta-coordinate vertical velocity
etadot.message,
etadot.metadata,
discipline=0,
parameterCategory=2,
parameterNumber=32,
Expand Down
2 changes: 1 addition & 1 deletion src/meteodatalab/operators/pot_vortic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def compute_pot_vortic(
dt_dlam * curl1 + dt_dphi * (curl2 + cor2) - dt_dzeta * (curl3 + cor3)
) / rho_tot

out.attrs = metadata.override(theta.message, shortName="POT_VORTIC")
out.attrs = metadata.override(theta.metadata, shortName="POT_VORTIC")

return out
4 changes: 2 additions & 2 deletions src/meteodatalab/operators/radiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def compute_athd_s(athb_s: xr.DataArray, tsurf: xr.DataArray) -> xr.DataArray:
"""
return xr.DataArray(
data=athb_s / pc.emissivity_surface + pc.boltzman_cst * tsurf**4,
attrs=metadata.override(athb_s.message, shortName="ATHD_S"),
attrs=metadata.override(athb_s.metadata, shortName="ATHD_S"),
)


Expand All @@ -48,5 +48,5 @@ def compute_swdown(diffuse: xr.DataArray, direct: xr.DataArray) -> xr.DataArray:
"""
return xr.DataArray(
data=(diffuse + direct).clip(min=0),
attrs=metadata.override(diffuse.message, shortName="ASOD_S"),
attrs=metadata.override(diffuse.metadata, shortName="ASOD_S"),
)
Loading

0 comments on commit 51ad97b

Please sign in to comment.