diff --git a/workflows/argo/offline-diags.yaml b/workflows/argo/offline-diags.yaml index 325c1e4510..41c7546f90 100644 --- a/workflows/argo/offline-diags.yaml +++ b/workflows/argo/offline-diags.yaml @@ -23,6 +23,7 @@ spec: - name: test_data_config - name: offline-diags-output - name: report-output + - {name: offline-diags-flags, value: " "} - {name: no-wandb, value: "false"} - {name: wandb-project, value: "argo-default"} - {name: wandb-tags, value: ""} @@ -69,7 +70,8 @@ spec: python -m fv3net.diagnostics.offline.compute \ {{inputs.parameters.ml-model}} \ test_data.yaml \ - {{inputs.parameters.offline-diags-output}} + {{inputs.parameters.offline-diags-output}} \ + {{inputs.parameters.offline-diags-flags}} cat << EOF > training.yaml {{inputs.parameters.training_config}} diff --git a/workflows/argo/train-diags-prog.yaml b/workflows/argo/train-diags-prog.yaml index 592011ad5f..0d53898ac6 100644 --- a/workflows/argo/train-diags-prog.yaml +++ b/workflows/argo/train-diags-prog.yaml @@ -34,6 +34,7 @@ spec: - {name: memory-training, value: 6Gi} - {name: memory-offline-diags, value: 10Gi} - {name: training-flags, value: " "} + - {name: offline-diags-flags, value: " "} - {name: online-diags-flags, value: " "} - {name: do-prognostic-run, value: "true"} - {name: no-wandb, value: "false"} @@ -115,6 +116,8 @@ spec: value: "{{inputs.parameters.tag}},{{inputs.parameters.wandb-tags}}" - name: wandb-group value: "{{inputs.parameters.wandb-group}}" + - name: offline-diags-flags + value: "{{inputs.parameters.offline-diags-flags}}" - name: insert-model-urls when: "{{inputs.parameters.do-prognostic-run}} == true" dependencies: [resolve-output-url] diff --git a/workflows/diagnostics/fv3net/diagnostics/offline/compute.py b/workflows/diagnostics/fv3net/diagnostics/offline/compute.py index a1862d14d7..6565961787 100644 --- a/workflows/diagnostics/fv3net/diagnostics/offline/compute.py +++ b/workflows/diagnostics/fv3net/diagnostics/offline/compute.py @@ -1,6 +1,7 @@ import argparse import json 
import logging +import numpy as np import os import sys from tempfile import NamedTemporaryFile @@ -103,6 +104,14 @@ def _get_parser() -> argparse.ArgumentParser: default=-1, help=("Optional n_jobs parameter for joblib.parallel when computing metrics."), ) + parser.add_argument( + "--outputs-2d-only", + action="store_true", + help=( + "Use flag if all model outputs are 2D, in which case pressure thickness " + "does not have to be in test dataset." + ), + ) return parser @@ -144,6 +153,7 @@ def _compute_diagnostics( target = safe.get_variables( ds.sel({DERIVATION_DIM_NAME: TARGET_COORD}), full_predicted_vars ) + ds_summary = compute_diagnostics(prediction, target, grid, ds[DELP], n_jobs=n_jobs) timesteps.append(ds["time"]) @@ -214,6 +224,15 @@ def transform(ds): return transform +def _fill_delp(): + # filler for delp array that is required as an arg but not used in diagnostics + def transform(ds): + ds[DELP] = xr.DataArray([np.nan, np.nan]) + return ds + + return transform + + def _get_data_mapper_if_exists(config): if isinstance(config, loaders.BatchesFromMapperConfig): return config.load_mapper() @@ -221,10 +240,10 @@ def _get_data_mapper_if_exists(config): return None -def _variables_to_load(model): - vars = list( - set(list(model.input_variables) + list(model.output_variables) + [DELP]) - ) +def _variables_to_load(model, outputs_2d_only=False): + vars = list(set(list(model.input_variables) + list(model.output_variables))) + if outputs_2d_only is False: + vars.append(DELP) if "Q2" in model.output_variables: vars.append("water_vapor_path") return vars @@ -258,12 +277,14 @@ def get_prediction( config: loaders.BatchesFromMapperConfig, model: fv3fit.Predictor, evaluation_resolution: int, + outputs_2d_only: bool, ) -> xr.Dataset: - model_variables = _variables_to_load(model) + model_variables = _variables_to_load(model, outputs_2d_only) + if config.timesteps: config.timesteps = sorted(config.timesteps) - batches = config.load_batches(model_variables) + batches = 
config.load_batches(model_variables) transforms = [_get_predict_function(model, model_variables)] prediction_resolution = res_from_string(config.res) @@ -274,11 +295,25 @@ def get_prediction( prediction_resolution=prediction_resolution, ) ) + if outputs_2d_only: + transforms.append(_fill_delp()) mapping_function = compose_left(*transforms) batches = loaders.Map(mapping_function, batches) - - concatted_batches = _daskify_sequence(batches) + try: + concatted_batches = _daskify_sequence(batches) + except KeyError as e: + key_err = str(e) + if "pressure_thickness_of_atmospheric_layer" in key_err: + raise KeyError( + "Variable 'pressure_thickness_of_atmospheric_layer' " + "not in dataset. If outputs are 2D and this variable " + "is not needed for diagnostics, include the CLI flag " + "--outputs-2d-only. If outputs are 3D, make sure this " + "variable is present in the test dataset." + ) + else: + raise KeyError(key_err) del batches return concatted_batches @@ -286,7 +321,7 @@ def get_prediction( def _daskify_sequence(batches): temp_data_dir = temporary_directory() for i, batch in enumerate(batches): - logger.info(f"Locally caching batch {i+1}/{len(batches)+1}") + logger.info(f"Locally caching batch {i+1}/{len(batches)}") batch.to_netcdf(os.path.join(temp_data_dir.name, f"{i}.nc")) dask_ds = xr.open_mfdataset(os.path.join(temp_data_dir.name, "*.nc")) return dask_ds @@ -314,7 +349,10 @@ def main(args): model = fv3fit.DerivedModel(model, derived_output_variables=["Q2"]) ds_predicted = get_prediction( - config=config, model=model, evaluation_resolution=evaluation_grid.sizes["x"] + config=config, + model=model, + evaluation_resolution=evaluation_grid.sizes["x"], + outputs_2d_only=args.outputs_2d_only, ) output_data_yaml = os.path.join(args.output_path, "data_config.yaml")