Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def test_to_raster(gridpath):
assert isinstance(raster, np.ndarray)


def test_to_raster_with_extra_dims(gridpath):
fig, ax = plt.subplots(
subplot_kw={'projection': ccrs.Robinson()},
constrained_layout=True,
figsize=(10, 5),
)

mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
uxds = ux.open_dataset(mesh_path, mesh_path)

da = uxds['bottomDepth'].expand_dims(time=[0])
raster = da.isel(time=slice(0, 1)).to_raster(ax=ax)

assert isinstance(raster, np.ndarray)



def test_to_raster_reuse_mapping(gridpath, tmpdir):

Expand Down
4 changes: 2 additions & 2 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def to_raster(
_RasterAxAttrs,
)

_ensure_dimensions(self)
data = _ensure_dimensions(self)

if not isinstance(ax, GeoAxes):
raise TypeError("`ax` must be an instance of cartopy.mpl.geoaxes.GeoAxes")
Expand Down Expand Up @@ -405,7 +405,7 @@ def to_raster(
pixel_mapping = np.asarray(pixel_mapping, dtype=INT_DTYPE)

raster, pixel_mapping_np = _nearest_neighbor_resample(
self,
data,
ax,
pixel_ratio=pixel_ratio,
pixel_mapping=pixel_mapping,
Expand Down
18 changes: 11 additions & 7 deletions uxarray/plot/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@ def _ensure_dimensions(data: UxDataArray) -> UxDataArray:
ValueError
If the sole dimension is not named "n_face".
"""
# Check dimensionality
if data.ndim != 1:
# Allow extra singleton dimensions as long as there's exactly one non-singleton dim
non_trivial_dims = [dim for dim, size in zip(data.dims, data.shape) if size != 1]

if len(non_trivial_dims) != 1:
raise ValueError(
f"Expected a 1D DataArray over 'n_face', but got {data.ndim} dimensions: {data.dims}"
"Expected data with a single dimension (other axes may be length 1), "
f"but got dims {data.dims} with shape {data.shape}"
)

# Check dimension name
if data.dims[0] != "n_face":
raise ValueError(f"Expected dimension 'n_face', but got '{data.dims[0]}'")
sole_dim = non_trivial_dims[0]
if sole_dim != "n_face":
raise ValueError(f"Expected dimension 'n_face', but got '{sole_dim}'")

return data
# Squeeze any singleton axes to ensure we return a true 1D array over n_face
return data.squeeze()


class _RasterAxAttrs(NamedTuple):
Expand Down
Loading