diff --git a/ufs2arco/datamover.py b/ufs2arco/datamover.py index b63af96..45f8280 100644 --- a/ufs2arco/datamover.py +++ b/ufs2arco/datamover.py @@ -115,6 +115,7 @@ def _next_data(self): open_static_vars=self.target.always_open_static_vars, cache_dir=cache_dir, ) + if len(fds) > 0: fds = self.transformer(fds) fds = self.target.apply_transforms_to_sample(fds) @@ -225,7 +226,7 @@ def create_container(self) -> None: return nds - def find_my_region(self, xds): + def find_my_region(self, xds, tds=None): """Given a dataset, that's assumed to be a subset of the initial dataset, find the logical index values where this should be stored in the final zarr store @@ -236,10 +237,16 @@ def find_my_region(self, xds): region (dict): indicating the zarr region to store in, based on the initial condition indices """ region = {k: slice(None, None) for k in xds.dims} - for key in self.target.renamed_sample_dims: - full_array = getattr(self.target, key) # e.g. all of the initial conditions - batch_indices = [list(full_array).index(value) for value in xds[key].values] - region[key] = slice(batch_indices[0], batch_indices[-1]+1) + if tds is None: + for key in self.target.renamed_sample_dims: + full_array = getattr(self.target, key) # e.g. 
all of the initial conditions + batch_indices = [list(full_array).index(value) for value in xds[key].values] + region[key] = slice(batch_indices[0], batch_indices[-1]+1) + else: + for key in tds.dims: + full_array = tds[key].values + batch_indices = [list(full_array).index(value) for value in xds[key].values] + region[key] = slice(batch_indices[0], batch_indices[-1]+1) return region diff --git a/ufs2arco/multidriver.py b/ufs2arco/multidriver.py index 74a86ec..ee8dc69 100644 --- a/ufs2arco/multidriver.py +++ b/ufs2arco/multidriver.py @@ -60,6 +60,7 @@ class MultiDriver(Driver): "directories", "multisource", "transforms", + "merged_transforms", "target", "attrs", ) @@ -131,6 +132,13 @@ def _init_transformer(self): self.transformers.append(transformer) + def _init_merged_transformer(self): + """Define operations to be performed on the merged dataset, for example temporal aggregation""" + self.merged_transformer = None + if "merged_transforms" in self.config: + self.merged_transformer = Transformer(options=self.config.get("merged_transforms")) + + def _init_target(self): """Make a copy of the same target format for each source @@ -138,15 +146,18 @@ """ name = self.config["target"].get("name", "base").lower() - try: - assert name == "anemoi" - except: - raise NotImplementedError("Driver._init_target: multisource workflows currently only work with Anemoi Targets.") + if name in ("forecast", "analysis", "base"): + TargetDataset = ufs2arco.targets.Target + elif name == "anemoi": + TargetDataset = ufs2arco.targets.Anemoi + else: + raise NotImplementedError(f"MultiDriver._init_target: target name '{name}' is not implemented; expected one of 'base', 'forecast', 'analysis', or 'anemoi'") + self.targets = list() kwargs = self.target_kwargs.copy() for source in self.sources: self.targets.append( - ufs2arco.targets.Anemoi( + TargetDataset( source=source, **kwargs, ) @@ -169,6 +180,10 @@ def _init_mover(self): for source, target, transformer in zip(self.sources, self.targets, self.transformers) ] + def setup(self, 
runtype: str): + super().setup(runtype=runtype) + self._init_merged_transformer() + def write_container(self, overwrite): """Write empty zarr store, to be filled with data""" @@ -176,6 +191,9 @@ if self.topo.is_root: dslist = [mover.create_container() for mover in self.movers] cds = self.target.merge_multisource(dslist) + if self.merged_transformer is not None: + with xr.set_options(keep_attrs=True): + cds = self.merged_transformer(cds) kwargs = {"mode": "w"} if overwrite else {} logger.info(f"Driver.write_container: storing container at {self.store_path}\n{cds}\n") @@ -201,6 +219,9 @@ def run(self, overwrite: bool = False): if self.mover.start == 0: self.write_container(overwrite=overwrite) + # open the coordinates of the (already written) target store, so find_my_region can map each merged batch onto the final output dimensions + tds = xr.open_zarr(self.store_path, decode_timedelta=True).coords.to_dataset() + # loop through batches n_batches = len(self.mover) missing_dims = [] @@ -230,8 +251,11 @@ if all(foundit) and len(foundit) == len(self.movers): mds = self.target.merge_multisource(dslist) + if self.merged_transformer is not None: + with xr.set_options(keep_attrs=True): + mds = self.merged_transformer(mds) - region = self.mover.find_my_region(mds) + region = self.mover.find_my_region(mds, tds) mds.to_zarr(self.target.store_path, region=region) self.mover.clear_cache(batch_idx) diff --git a/ufs2arco/targets/base.py b/ufs2arco/targets/base.py index e3da03e..3d45706 100644 --- a/ufs2arco/targets/base.py +++ b/ufs2arco/targets/base.py @@ -268,3 +268,9 @@ def handle_missing_data(self, missing_data: list[dict]) -> None: zds = zarr.open(self.store_path, mode="a") zds.attrs["missing_data"] = missing_data zarr.consolidate_metadata(self.store_path) + + + def merge_multisource(self, dslist: list[xr.Dataset]) -> xr.Dataset: + """Take a list of datasets, each from their own source, and merge them""" + result = xr.merge(dslist) + return result diff --git 
a/ufs2arco/transforms/transformer.py b/ufs2arco/transforms/transformer.py index 7c61c85..b92085e 100644 --- a/ufs2arco/transforms/transformer.py +++ b/ufs2arco/transforms/transformer.py @@ -16,10 +16,11 @@ def implemented(self) -> tuple: return ( "multiply", "divide", + "rotate_vectors", + "xarray_coarsen", "fv_vertical_regrid", "horizontal_regrid", "mappings", - "rotate_vectors", ) def __init__(self, options): @@ -85,6 +86,9 @@ def __call__(self, xds: xr.Dataset): if "rotate_vectors" in self.names: xds = rotate_vectors(xds, **self.options["rotate_vectors"]) + if "xarray_coarsen" in self.names: + xds = xarray_coarsen(xds, **self.options["xarray_coarsen"]) + if "fv_vertical_regrid" in self.names: xds = fv_vertical_regrid(xds, **self.options["fv_vertical_regrid"]) @@ -135,3 +139,24 @@ def divide(xds, config): if varname in xds: xds[varname] = xds[varname] / scalar return xds + + +def xarray_coarsen(xds: xr.Dataset, reduction: str, **kwargs) -> xr.Dataset: + """ + Apply ``xarray.coarsen`` to the dataset + + Args: + xds (xr.Dataset): the dataset to coarsen + reduction (str): name of the reduction applied to each window, e.g. "mean" or "sum" + **kwargs: passed directly to ``xarray.Dataset.coarsen``, e.g. per-dimension window sizes and ``boundary`` + + Returns: + xr.Dataset: the coarsened, reduced dataset + """ + cds = xds.coarsen(**kwargs) + try: + reduce_method = getattr(cds, reduction) + except AttributeError: + raise AttributeError(f"ufs2arco.transforms.xarray_coarsen: '{reduction}' is not a valid method on the coarsened dataset.") + cds = reduce_method() + return cds