diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 6ddc5d3..94ff71b 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -1,6 +1,7 @@ import pathlib import numpy as np +import pytest import xarray as xr DATA_FOLDER = pathlib.Path(__file__).parent / "data" @@ -49,6 +50,21 @@ def test_open_dataset_root() -> None: assert isinstance(res, xr.Dataset) +def test_open_dataset_root_accessor() -> None: + product_path = ( + DATA_FOLDER + / "S1B_IW_SLC__1SDV_20210401T052622_20210401T052650_026269_032297_EFA4.SAFE" + ) + res = xr.open_dataset(product_path, engine="sentinel-1") # type: ignore + + assert res.sentinel1.group is None + + res1 = xr.zeros_like(res, 0) + + with pytest.raises(TypeError): + res1.sentinel1 + + def test_open_dataset_orbit() -> None: manifest_path = ( DATA_FOLDER diff --git a/xarray_sentinel/sentinel1.py b/xarray_sentinel/sentinel1.py index 72c8745..14c307c 100644 --- a/xarray_sentinel/sentinel1.py +++ b/xarray_sentinel/sentinel1.py @@ -293,13 +293,13 @@ def compute_burst_centres(gcp: xr.Dataset) -> T.Tuple[np.ndarray, np.ndarray]: def open_dataset( - product_urlpath: esa_safe.PathType, + urlpath: esa_safe.PathType, drop_variables: T.Optional[T.Tuple[str]] = None, group: T.Optional[str] = None, chunks: T.Optional[T.Union[int, T.Dict[str, int]]] = None, fs: T.Optional[fsspec.AbstractFileSystem] = None, ) -> xr.Dataset: - fs, manifest_path = get_fs_path(product_urlpath, fs) + fs, manifest_path = get_fs_path(urlpath, fs) if fs.isdir(manifest_path): manifest_path = os.path.join(manifest_path, "manifest.safe") @@ -341,9 +341,24 @@ def open_dataset( annotation_path=annotation_path, chunks=chunks, ) + # add backend specific metadata in the Dataset enconding + ds.encoding = { + "engine": "sentinel-1", + "group": group, + "urlpath": urlpath, + } return ds +@xr.register_dataset_accessor("sentinel1") # type: ignore +class Sentinel1Accessor: + def __init__(self, xarray_obj: xr.Dataset) -> None: + if xarray_obj.encoding.get("engine") != "sentinel-1": + raise TypeError("not a 'sentinel-1' 'Dataset'") + self.urlpath = xarray_obj.encoding["urlpath"] + self.group = xarray_obj.encoding["group"] + + class Sentinel1Backend(xr.backends.common.BackendEntrypoint): def open_dataset( # type: ignore self,