diff --git a/.gitignore b/.gitignore index 9e4682d3..68c44aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,4 @@ _version.py *.gif *.zarr/ /*-plots/ +/definitions*/ diff --git a/src/anemoi/inference/commands/retrieve.py b/src/anemoi/inference/commands/retrieve.py index 86b85426..9f12b47f 100644 --- a/src/anemoi/inference/commands/retrieve.py +++ b/src/anemoi/inference/commands/retrieve.py @@ -31,6 +31,7 @@ def add_arguments(self, command_parser): command_parser.add_argument("--output", type=str, help="Output file") command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates") command_parser.add_argument("--extra", action="append", help="Additional request values. Can be repeated") + command_parser.add_argument("--retrieve-fields-type", type=str, help="Type of fields to retrieve") command_parser.add_argument("overrides", nargs="*", help="Overrides.") def run(self, args): @@ -45,6 +46,18 @@ def run(self, args): area = runner.checkpoint.area grid = runner.checkpoint.grid + if args.retrieve_fields_type is not None: + selected = set() + + for name, kinds in runner.checkpoint.variable_categories().items(): + if "computed" in kinds: + continue + for kind in kinds: + if args.retrieve_fields_type.startswith(kind): # PrepML adds an 's' to the type + selected.add(name) + + variables = sorted(selected) + extra = postproc(grid, area) for r in args.extra or []: diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index f62a9c8a..b34e9938 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -88,6 +88,9 @@ class Config: development_hacks: dict = {} """A dictionary of development hacks to apply to the runner. This is used to test new features or to work around""" + debugging_info: dict = {} + """A dictionary to store debug information. 
class HindcastOutput:
    """GRIB output modifier that rewrites the encoding keys for hindcasts.

    Instances are callables taking and returning the ``(values, template, keys)``
    triple, so that several modifiers can be chained before a message is written.
    """

    def __init__(self, reference_year):
        """Remember the year used to build ``referenceDate``.

        Parameters
        ----------
        reference_year : int
            Year substituted for the year of the field date when computing
            the ``referenceDate`` key.
        """
        self.reference_year = reference_year

    def __repr__(self):
        return f"{type(self).__name__}(reference_year={self.reference_year!r})"

    def __call__(self, values, template, keys):
        """Replace ``date``/``hdate`` by ``dataDate``/``referenceDate`` in ``keys``.

        Parameters
        ----------
        values : Any
            Field values, returned untouched.
        template : Any
            GRIB template; only its ``metadata()`` method is used, and only
            when ``keys`` carries no ``date``.
        keys : dict
            Encoding keys. Updated in place: ``date`` and ``hdate`` are removed,
            ``edition``, ``localDefinitionNumber``, ``dataDate`` and
            ``referenceDate`` are set.

        Returns
        -------
        tuple
            ``(values, template, keys)``.

        Raises
        ------
        ValueError
            If ``keys`` has no ``date`` and the template already has an ``hdate``.
        """
        if "date" in keys:
            date = keys.pop("date")
        else:
            # Falling back on the template date is only done when the template
            # does not itself carry an `hdate`.
            if template.metadata("hdate", default=None) is not None:
                raise ValueError(f"Template already has an `hdate`: {template}")
            date = template.metadata("date")

        # `hdate` is superseded by the dataDate/referenceDate pair set below
        keys.pop("hdate", None)

        date = to_datetime(date)

        keys["edition"] = 1
        # NOTE(review): presumably the local definition used for hindcasts - TODO confirm
        keys["localDefinitionNumber"] = 30
        keys["dataDate"] = int(date.strftime("%Y%m%d"))
        # NOTE(review): `replace` raises ValueError for a 29 February date when
        # `reference_year` is not a leap year; the desired fallback must be confirmed.
        keys["referenceDate"] = int(date.replace(year=self.reference_year).strftime("%Y%m%d"))

        return values, template, keys


# Registry of the modifiers that can be requested by name in the configuration
MODIFIERS = dict(hindcast=HindcastOutput)


def modifier_factory(modifiers):
    """Instantiate the output modifiers described by a configuration.

    Parameters
    ----------
    modifiers : None, dict, list of dict or tuple of dict
        Each entry is a single-item mapping ``{name: options}`` where ``name``
        is a key of ``MODIFIERS`` and ``options`` holds the keyword arguments
        of the corresponding class (``None`` meaning no options). A single
        mapping is accepted in place of a one-element list.

    Returns
    -------
    list
        One modifier instance per entry, in the order given. Empty when
        ``modifiers`` is ``None``.

    Raises
    ------
    TypeError
        If an entry is not a mapping.
    ValueError
        If an entry does not hold exactly one item.
    KeyError
        If the modifier name is not registered in ``MODIFIERS``.
    """
    if modifiers is None:
        return []

    if not isinstance(modifiers, (list, tuple)):
        modifiers = [modifiers]

    result = []
    for modifier in modifiers:
        # Explicit exceptions rather than `assert`, which is stripped under `python -O`
        if not isinstance(modifier, dict):
            raise TypeError(f"Modifier must be a mapping of the form {{name: options}}, got {modifier!r}")
        if len(modifier) != 1:
            raise ValueError(f"Modifier must have exactly one entry, got {modifier!r}")

        ((name, options),) = modifier.items()

        try:
            klass = MODIFIERS[name]
        except KeyError:
            raise KeyError(f"Unknown modifier {name!r}, expected one of {sorted(MODIFIERS)}") from None

        # A bare `name:` entry in a configuration file yields `None` options
        result.append(klass(**(options or {})))

    return result
# There is a bug with hindcasts, where these keys are not added to the 'mars' namespace
MARS_MAYBE_MISSING_KEYS = (
    "number",
    "step",
    "time",
    "date",
    "hdate",
    "type",
    "stream",
    "expver",
    "class",
    "levtype",
    "levelist",
    "param",
)


def _is_valid(mars, keys):
    """Tell whether the `mars` namespace is consistent with the encoding keys.

    Parameters
    ----------
    mars : Mapping
        Content of the `mars` namespace read back from the message.
    keys : Mapping
        Keys that were used to encode the message.

    Returns
    -------
    bool
        ``False`` (after logging a warning) when `keys` has ``number`` but
        `mars` does not, or when `keys` has ``referenceDate`` but `mars` has
        no ``hdate``; ``True`` otherwise.
    """
    # (encoding key, mars key it implies, warning emitted when the latter is absent)
    checks = (
        ("number", "number", "`number` is missing from mars namespace"),
        ("referenceDate", "hdate", "`hdate` is missing from mars namespace"),
    )

    for encoded, expected, warning in checks:
        if encoded in keys and expected not in mars:
            LOG.warning(warning)
            return False

    return True
+ class Dummy: + def __init__(self, template): + self.template = template + self.handle = template.handle + + def __repr__(self): + return f"Dummy({self.template})" + + template = Dummy(template) + + # LOG.info("Writing message to %s %s", template, keys) try: self.collect_archive_requests( self.output.write( @@ -90,6 +141,7 @@ def write_message(self, message, template, **keys): LOG.error("Error writing message to %s", self.path) LOG.error("eccodes: %s", eccodes.__version__) + LOG.error("Template: %s, Keys: %s", template, keys) LOG.error("Exception: %s", e) if message is not None and np.isnan(message.data).any(): LOG.error("Message contains NaNs (%s, %s) (allow_nans=%s)", keys, template, self.context.allow_nans) @@ -102,7 +154,25 @@ def collect_archive_requests(self, written, template, **keys): handle, path = written - mars = handle.as_namespace("mars") + while True: + + if self._namespace_bug_fix: + import eccodes + from earthkit.data.readers.grib.codes import GribCodesHandle + + handle = GribCodesHandle(eccodes.codes_clone(handle._handle), None, None) + + mars = {k: v for k, v in handle.items("mars")} + + if _is_valid(mars, keys): + break + + if self._namespace_bug_fix: + raise ValueError("Namespace bug: %s" % mars) + + # Try again with the namespace bug + LOG.warning("Namespace bug detected, trying again") + self._namespace_bug_fix = True if self.check_encoding: check_encoding(handle, keys)