-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First Pass Zonal Average #82
base: main
Are you sure you want to change the base?
Changes from 1 commit
e2a6d50
c09eadf
ff4bc68
6fcd961
3e6b24c
8a0bd6d
4732c90
15488ca
f950017
acf689c
db5c0d2
886790c
d4bb79c
31698c1
4fceca6
2a9a0a9
6e14014
b6623f1
284997e
8a831a3
9726975
eba88b3
424c05f
69e4193
2d017cc
68a2a10
bbeeecb
4ff8a80
d415190
8d54ba7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import os | ||
import warnings | ||
|
||
import numpy as np | ||
import xarray as xr | ||
import xesmf as xe | ||
|
||
from .grid import get_grid | ||
|
||
|
||
def _generate_dest_grid(dy=None, dx=None, method_gen_grid='regular_lat_lon'): | ||
""" | ||
Generates the destination grid | ||
|
||
Parameters | ||
---------- | ||
dy: float | ||
Horizontal grid spacing in y-direction (latitudinal) | ||
|
||
dy: float | ||
Horizontal grid spcaing in x-direction (longitudinal) | ||
""" | ||
|
||
# Use regular lat/lon with regular spacing | ||
if method_gen_grid == 'regular_lat_lon': | ||
if dy is None: | ||
dy = 0.25 | ||
|
||
if dx is None: | ||
dx = dy | ||
|
||
# Able to add other options at a later point | ||
|
||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Check to see if there is a weights file already existing | ||
# Use xESMF to generate the destination grid | ||
return xe.util.grid_global(dx, dy) | ||
|
||
|
||
def _get_default_filename(src_grid, dst_grid, method): | ||
|
||
# Get the source grid shape | ||
src_shape = src_grid.lat.shape | ||
|
||
# Get the destination grid shape | ||
dst_shape = dst_grid.lat.shape | ||
|
||
filename = '{0}_{1}x{2}_{3}x{4}.nc'.format( | ||
method, src_shape[0], src_shape[1], dst_shape[0], dst_shape[1] | ||
) | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return filename | ||
|
||
|
||
def _convert_to_xesmf(data_ds, grid_ds): | ||
""" | ||
Format xarray datasets to be read in easily to xESMF | ||
|
||
Parameters | ||
---------- | ||
data_ds : `xarray.Dataset` | ||
Dataset which includes fields to regrid | ||
|
||
grid_ds : `xarray.Dataset` | ||
Dataset including the POP grid | ||
|
||
Returns | ||
------- | ||
|
||
out_ds : `xarray.Dataset` | ||
Clipped dataset including fields to regrid with grid | ||
|
||
""" | ||
|
||
# Merge datasets into single dataset | ||
data_ds = xr.merge( | ||
[grid_ds.reset_coords(), data_ds.reset_coords()], compat='override', join='right' | ||
).rename({'TLAT': 'lat', 'TLONG': 'lon'}) | ||
|
||
# Inlcude only points that will have surrounding corners | ||
data_ds = data_ds.isel({'nlon': data_ds.nlon[1:], 'nlat': data_ds.nlat[1:]}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This won't work for datasets returned by I think the solution here is to use cf_xarray
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have a suggestion for adding a new X/Y coordinate? I noticed that the only cf axis within the POP grids is
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ugh yeah I forgot. We need to make ds["nlon"] = ("nlon", np.arange(ds.sizes["nlon"]), {"axis": "X"})
ds["nlat"] = ("nlon", np.arange(ds.sizes["nlon"]), {"axis": "Y"}) and similarly for |
||
|
||
# Use ulat and ulong values as grid corners, rename variables to match xESMF syntax | ||
grid_corners = grid_ds[['ULAT', 'ULONG']].rename( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is wrong for EDIT: I see you've fixed this :) but a good solution would be to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alternately this kind of inference requires knowledge of the underlying grid. so We can hardcode this for POP now and move on, but solving |
||
{'nlat': 'nlat_b', 'nlon': 'nlon_b', 'ULAT': 'lat_b', 'ULONG': 'lon_b'} | ||
) | ||
|
||
# Merge datasets with data and grid corner information | ||
out_ds = xr.merge([data_ds, grid_corners]) | ||
|
||
return out_ds | ||
|
||
|
||
def _generate_weights(src_grid, dst_grid, method, weight_file=None, clobber=False): | ||
""" | ||
Generate regridding weights by calling xESMF | ||
""" | ||
|
||
# Allow user to input weights file, if there is not one, use default check | ||
if weight_file is None: | ||
weight_file = _get_default_filename(src_grid, dst_grid, method) | ||
|
||
# Check to see if the weights file already exists - if not, generate weights | ||
if not os.path.exists(weight_file) or clobber: | ||
xe.Regridder(src_grid, dst_grid, method).to_netcdf(weight_file) | ||
|
||
regridder = xe.Regridder(src_grid, dst_grid, method, weights=weight_file) | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return regridder | ||
|
||
|
||
class regridder(object): | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
grid_name, | ||
grid=None, | ||
dx=None, | ||
dy=None, | ||
mask=None, | ||
regrid_method='conservative', | ||
method_gen_grid='regular_lat_lon', | ||
): | ||
""" | ||
A regridding class which uses xESMF and Xarray tools to both regrid and | ||
calculate a zonal averge. | ||
|
||
Parameters | ||
---------- | ||
grid_name | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
if grid_name is not None: | ||
self.grid_name = grid_name | ||
|
||
# Use pop-tools to retrieve the grid | ||
self.grid = get_grid(grid_name) | ||
|
||
elif grid is not None: | ||
self.grid = grid | ||
|
||
else: | ||
raise ValueError('Failed to input grid name or grid dataset') | ||
|
||
# Set the dx/dy parameters for generating the grid | ||
self.dx = dx | ||
self.dy = dy | ||
|
||
# Set the regridding method | ||
self.regrid_method = regrid_method | ||
|
||
# Set the grid generation method | ||
self.method_gen_grid = method_gen_grid | ||
|
||
# If the user does not input a mask, use default mask | ||
if not mask: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. applying a mask when Should the default be |
||
self.mask = self.grid['REGION_MASK'] | ||
self.mask_labels = self.grid['region_name'] | ||
|
||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we never need no mask? I guess a global average across all basins is somewhat useless but would be required for the atmosphere. |
||
self.mask = mask | ||
|
||
# Setup method for regridding a dataarray | ||
def _regrid_dataarray(self, da_in, regrid_mask=False, regrid_method=None): | ||
|
||
src_grid = _convert_to_xesmf(da_in, self.grid) | ||
dst_grid = _generate_dest_grid(self.dy, self.dx, self.method_gen_grid) | ||
|
||
# If the user does not specify a regridding method, use default conservative | ||
if regrid_method is None: | ||
regridder = _generate_weights(src_grid, dst_grid, self.regrid_method) | ||
|
||
else: | ||
regridder = _generate_weights(src_grid, dst_grid, regrid_method) | ||
|
||
# Regrid the input data array, assigning the original attributes | ||
da_out = regridder(src_grid[da_in.name]) | ||
da_out.attrs = da_in.attrs | ||
|
||
return da_out | ||
|
||
def regrid(self, obj, **kwargs): | ||
"""generic interface for regridding DataArray or Dataset""" | ||
if isinstance(obj, xr.Dataset): | ||
return obj.map(self._regrid_dataarray, keep_attrs=True, **kwargs) | ||
elif isinstance(obj, xr.DataArray): | ||
return self._regrid_dataarray(obj, **kwargs) | ||
else: | ||
raise ValueError('unknown type') | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def za(self, obj, vertical_average=False, **kwargs): | ||
andersy005 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
data = self.regrid(obj, **kwargs) | ||
mask = self.regrid(self.mask, regrid_method='nearest_s2d', **kwargs) | ||
|
||
# Store the various datasets seperated by basin in this list | ||
ds_list = [] | ||
for region in np.unique(mask): | ||
|
||
if region != 0: | ||
ds_list.append(data.where(mask == region).groupby('lat').mean()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why groupby? Just a simple instead of this loop, I would recommend taking a 2D mask variable and expanding out to 3D with a new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Matt and were discussing, making use of the
where the dimensions of data_masked would be |
||
|
||
# Merge the datasets | ||
out = xr.concat(ds_list, dim='nreg') | ||
|
||
# Check to see if a weighted vertical average is needed | ||
if vertical_average: | ||
|
||
# Run the vertical, weighted average | ||
out = out.weighted(out['z_t'].fillna(0)).mean(dim=['z_t']) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why weight by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be calculated by using the following? DZT = out.z_t.diff(dim='z_t') Or am I misunderstanding what DZT is? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
>>> import pop_tools
>>> gx1v7_grid = pop_tools.get_grid("POP_gx1v7")
>>> gx1v7_grid.z_t.diff(dim='z_t')
<xarray.DataArray 'z_t' (z_t: 59)>
array([ 1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1009.8404 , 1038.0646 , 1081.22175, 1136.90105, 1205.11015,
1286.69055, 1383.0544 , 1496.13345, 1628.40275, 1782.946 ,
1963.55735, 2174.8772 , 2422.5496 , 2713.41105, 3055.7061 ,
3459.30485, 3935.90165, 4499.1272 , 5164.489 , 5948.97315,
6870.06835, 7943.9187 , 9182.2754 , 10588.062 , 12149.8325 ,
13836.2715 , 15593.0264 , 17344.9219 , 19005.6846 , 20494.0957 ,
21751.2334 , 22751.50195, 23502.97165, 24038.333 , 24401.99805,
24638.8965 , 24787.6709 , 24878.15135, 24931.6328 , 24962.4424 ,
24979.77735, 24989.31735, 24994.45895, 24997.17675])
>>> You want the cell thickness - the difference between layer interfaces, which is >>> gx1v7_grid.z_w_bot.swap_dims(z_w_bot="z_t") - gx1v7_grid.z_w.swap_dims(z_w="z_t")
<xarray.DataArray (z_t: 60)>
array([ 1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1019.6808, 1056.4484, 1105.9951, 1167.807 ,
1242.4133, 1330.9678, 1435.141 , 1557.1259, 1699.6796,
1866.2124, 2060.9023, 2288.8521, 2556.2471, 2870.575 ,
3240.8372, 3677.7725, 4194.0308, 4804.2236, 5524.7544,
6373.1919, 7366.9448, 8520.8926, 9843.6582, 11332.4658,
12967.1992, 14705.3438, 16480.709 , 18209.1348, 19802.2344,
21185.957 , 22316.5098, 23186.4941, 23819.4492, 24257.2168,
24546.7793, 24731.0137, 24844.3281, 24911.9746, 24951.291 ,
24973.5938, 24985.9609, 24992.6738, 24996.2441, 24998.1094])
Coordinates:
z_w_bot (z_t) float64 1e+03 2e+03 3e+03 4e+03 ... 5e+05 5.25e+05 5.5e+05
z_w (z_t) float64 0.0 1e+03 2e+03 3e+03 ... 4.75e+05 5e+05 5.25e+05
Dimensions without coordinates: z_t
>>> I think >>> gx1v7_grid["dz"]
<xarray.DataArray 'dz' (z_t: 60)>
array([ 1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1000. , 1000. , 1000. , 1000. ,
1000. , 1019.6808, 1056.4484, 1105.9951, 1167.807 ,
1242.4133, 1330.9678, 1435.141 , 1557.1259, 1699.6796,
1866.2124, 2060.9023, 2288.8521, 2556.2471, 2870.575 ,
3240.8372, 3677.7725, 4194.0308, 4804.2236, 5524.7544,
6373.1919, 7366.9448, 8520.8926, 9843.6582, 11332.4658,
12967.1992, 14705.3438, 16480.709 , 18209.1348, 19802.2344,
21185.957 , 22316.5098, 23186.4941, 23819.4492, 24257.2168,
24546.7793, 24731.0137, 24844.3281, 24911.9746, 24951.291 ,
24973.5938, 24985.9609, 24992.6738, 24996.2441, 24998.1094])
Coordinates:
* z_t (z_t) float64 500.0 1.5e+03 2.5e+03 ... 5.125e+05 5.375e+05
Attributes:
units: cm
long_name: thickness of layer k
>>> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I adjusted this to utilize the out = out.weighted(self.grid.dz).mean(dim='z_t') Which seems to work well. |
||
|
||
# Add in the region name | ||
out['region_name'] = data.region_name | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make
xesmf
a soft dependency i.e.??