17 changes: 12 additions & 5 deletions ufs2arco/datamover.py
@@ -115,6 +115,7 @@ def _next_data(self):
open_static_vars=self.target.always_open_static_vars,
cache_dir=cache_dir,
)

if len(fds) > 0:
fds = self.transformer(fds)
fds = self.target.apply_transforms_to_sample(fds)
@@ -225,7 +226,7 @@ def create_container(self) -> None:
return nds


def find_my_region(self, xds):
def find_my_region(self, xds, tds=None):
"""Given a dataset, that's assumed to be a subset of the initial dataset,
find the logical index values where this should be stored in the final zarr store

@@ -236,10 +237,16 @@ def find_my_region(self, xds):
region (dict): indicating the zarr region to store in, based on the initial condition indices
"""
region = {k: slice(None, None) for k in xds.dims}
for key in self.target.renamed_sample_dims:
full_array = getattr(self.target, key) # e.g. all of the initial conditions
batch_indices = [list(full_array).index(value) for value in xds[key].values]
region[key] = slice(batch_indices[0], batch_indices[-1]+1)
if tds is None:
for key in self.target.renamed_sample_dims:
full_array = getattr(self.target, key) # e.g. all of the initial conditions
batch_indices = [list(full_array).index(value) for value in xds[key].values]
region[key] = slice(batch_indices[0], batch_indices[-1]+1)
else:
for key in tds.dims:
full_array = tds[key].values
batch_indices = [list(full_array).index(value) for value in xds[key].values]
region[key] = slice(batch_indices[0], batch_indices[-1]+1)
return region


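A minimal sketch of what the new tds branch in find_my_region does: when an explicit target dataset is passed, the batch's position in the zarr store is looked up against the container's own coordinate values instead of the target object's sample dims. The "t0" dimension name and the coordinate values below are hypothetical, chosen only to illustrate the index arithmetic.

import numpy as np
import xarray as xr

# Hypothetical container coordinates ("t0" is an illustrative sample dimension).
tds = xr.Dataset(coords={"t0": np.arange("2024-01-01", "2024-01-05", dtype="datetime64[D]")})
xds = tds.isel(t0=slice(1, 3))  # a batch covering the middle two samples

# Same lookup as the tds branch above: map the batch's coordinate values
# to integer positions in the container, then store as a contiguous slice.
region = {k: slice(None, None) for k in xds.dims}
for key in tds.dims:
    full_array = tds[key].values
    batch_indices = [list(full_array).index(value) for value in xds[key].values]
    region[key] = slice(batch_indices[0], batch_indices[-1] + 1)

print(region)  # {'t0': slice(1, 3, None)}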
36 changes: 30 additions & 6 deletions ufs2arco/multidriver.py
@@ -60,6 +60,7 @@ class MultiDriver(Driver):
"directories",
"multisource",
"transforms",
"merged_transforms",
"target",
"attrs",
)
@@ -131,22 +132,32 @@ def _init_transformer(self):
self.transformers.append(transformer)


def _init_merged_transformer(self):
"""Define operations to be performed on the merged dataset, for example temporal aggregation"""
self.merged_transformer = None
if "merged_transforms" in self.config:
self.merged_transformer = Transformer(options=self.config.get("merged_transforms"))


def _init_target(self):
"""Make a copy of the same target format for each source

Note that we only compute any forcings with the first data source
"""

name = self.config["target"].get("name", "base").lower()
try:
assert name == "anemoi"
except:
raise NotImplementedError("Driver._init_target: multisource workflows currently only work with Anemoi Targets.")
if name in ("forecast", "analysis", "base"):
TargetDataset = ufs2arco.targets.Target
elif name == "anemoi":
TargetDataset = ufs2arco.targets.Anemoi
else:
raise NotImplementedError(f"MultiDriver._init_targets: only 'base' and 'anemoi' are implemented")

self.targets = list()
kwargs = self.target_kwargs.copy()
for source in self.sources:
self.targets.append(
ufs2arco.targets.Anemoi(
TargetDataset(
source=source,
**kwargs,
)
@@ -169,13 +180,20 @@ def _init_mover(self):
for source, target, transformer in zip(self.sources, self.targets, self.transformers)
]

def setup(self, runtype: str):
super().setup(runtype=runtype)
self._init_merged_transformer()


def write_container(self, overwrite):
"""Write empty zarr store, to be filled with data"""

if self.topo.is_root:
dslist = [mover.create_container() for mover in self.movers]
cds = self.target.merge_multisource(dslist)
if self.merged_transformer is not None:
with xr.set_options(keep_attrs=True):
cds = self.merged_transformer(cds)

kwargs = {"mode": "w"} if overwrite else {}
logger.info(f"Driver.write_container: storing container at {self.store_path}\n{cds}\n")
@@ -201,6 +219,9 @@ def run(self, overwrite: bool = False):
if self.mover.start == 0:
self.write_container(overwrite=overwrite)

# is this bad? open target dataset to get the dimensions
tds = xr.open_zarr(self.store_path, decode_timedelta=True).coords.to_dataset()

# loop through batches
n_batches = len(self.mover)
missing_dims = []
@@ -230,8 +251,11 @@
if all(foundit) and len(foundit) == len(self.movers):

mds = self.target.merge_multisource(dslist)
if self.merged_transformer is not None:
with xr.set_options(keep_attrs=True):
mds = self.merged_transformer(mds)

region = self.mover.find_my_region(mds)
region = self.mover.find_my_region(mds, tds)
mds.to_zarr(self.target.store_path, region=region)

self.mover.clear_cache(batch_idx)
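Sketched below with plain xarray is the flow this change adds: the per-source datasets are merged first, and the merged transformer is then applied once, both when building the container and for every batch, so written batches match the container's post-transform dimensions (which is why find_my_region is now handed the container coordinates tds). The variable and dimension names, and the stand-in coarsening transform, are illustrative only.

import numpy as np
import xarray as xr

# Two hypothetical single-variable sources on a shared grid.
lat = np.arange(4.0)
lon = np.arange(8.0)
ds_a = xr.Dataset({"t2m": (("latitude", "longitude"), np.random.rand(4, 8))},
                  coords={"latitude": lat, "longitude": lon})
ds_b = xr.Dataset({"sst": (("latitude", "longitude"), np.random.rand(4, 8))},
                  coords={"latitude": lat, "longitude": lon})

def merged_transformer(xds: xr.Dataset) -> xr.Dataset:
    # Stand-in for Transformer(options=config["merged_transforms"]),
    # here a 2x horizontal coarsening.
    return xds.coarsen(latitude=2, longitude=2).mean()

mds = xr.merge([ds_a, ds_b])            # target.merge_multisource
with xr.set_options(keep_attrs=True):   # keep attributes through the transform
    mds = merged_transformer(mds)

print(dict(mds.sizes))                  # {'latitude': 2, 'longitude': 4}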
6 changes: 6 additions & 0 deletions ufs2arco/targets/base.py
@@ -268,3 +268,9 @@ def handle_missing_data(self, missing_data: list[dict]) -> None:
zds = zarr.open(self.store_path, mode="a")
zds.attrs["missing_data"] = missing_data
zarr.consolidate_metadata(self.store_path)


def merge_multisource(self, dslist: list[xr.Dataset]) -> xr.Dataset:
"""Take a list of datasets, each from their own source, and merge them"""
result = xr.merge(dslist)
return result
23 changes: 22 additions & 1 deletion ufs2arco/transforms/transformer.py
@@ -16,10 +16,11 @@ def implemented(self) -> tuple:
return (
"multiply",
"divide",
"rotate_vectors",
"xarray_coarsen",
"fv_vertical_regrid",
"horizontal_regrid",
"mappings",
"rotate_vectors",
)

def __init__(self, options):
@@ -85,6 +86,9 @@ def __call__(self, xds: xr.Dataset):
if "rotate_vectors" in self.names:
xds = rotate_vectors(xds, **self.options["rotate_vectors"])

if "xarray_coarsen" in self.names:
xds = xarray_coarsen(xds, **self.options["xarray_coarsen"])

if "fv_vertical_regrid" in self.names:
xds = fv_vertical_regrid(xds, **self.options["fv_vertical_regrid"])

@@ -135,3 +139,20 @@ def divide(xds, config):
if varname in xds:
xds[varname] = xds[varname] / scalar
return xds


def xarray_coarsen(xds: xr.Dataset, reduction: str, **kwargs) -> xr.Dataset:
"""
Apply ``xarray.coarsen`` to the dataset

Args:
xds (xr.Dataset):

"""
cds = xds.coarsen(**kwargs)
try:
reduce_method = getattr(cds, reduction)
except AttributeError:
raise AttributeError(f"ufs2arco.transforms.xarray_coarsen: '{reduction}' is not a valid method on the coarsened dataset.")
cds = reduce_method()
return cds
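A hedged usage example for the new transform, assuming it can be imported from ufs2arco.transforms.transformer once this change lands; the dimension names, window sizes, and grid are illustrative. Keyword arguments other than reduction are forwarded untouched to xarray.Dataset.coarsen.

import numpy as np
import xarray as xr

from ufs2arco.transforms.transformer import xarray_coarsen  # assumed import path

# Illustrative 8x16 grid, coarsened by a factor of 4 in each horizontal
# dimension, reducing each window with its mean.
xds = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.random.rand(8, 16))},
    coords={
        "latitude": np.linspace(-90.0, 90.0, 8),
        "longitude": np.arange(0.0, 360.0, 22.5),
    },
)

cds = xarray_coarsen(xds, reduction="mean", latitude=4, longitude=4, boundary="exact")
print(dict(cds.sizes))  # {'latitude': 2, 'longitude': 4}

In a recipe this would presumably sit under transforms (or the new merged_transforms) as an xarray_coarsen entry whose options map one-to-one onto the keyword arguments above.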