diff --git a/README.md b/README.md index d79401ba..121e99cd 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,16 @@ Quick links: Setup ----- -rslearn requires Python 3.10+ (Python 3.12 is recommended). +rslearn requires Python 3.11+ (Python 3.12 is recommended). ``` git clone https://github.com/allenai/rslearn.git cd rslearn -pip install .[extra] +uv venv --python 3.11 +source .venv/bin/activate +uv sync +uv pip install -e ".[extra]" +uv pip install -e ".[dev]" # If running tests ``` diff --git a/docs/DatasetConfig.md b/docs/DatasetConfig.md index 32de6995..06128601 100644 --- a/docs/DatasetConfig.md +++ b/docs/DatasetConfig.md @@ -1163,6 +1163,51 @@ Available bands: - B10 - B11 +### rslearn.data_sources.zarr.ZarrDataSource + +This data source reads spatio-temporal cubes that are stored in a Zarr hierarchy. It can +either ingest items into the dataset tile store or act as the tile store itself when +`ingest` is set to false. Access to the underlying cube requires the optional +dependencies installed via `pip install rslearn[extra]`. + +```jsonc +{ + // Required URI pointing to the root of the Zarr store. Any fsspec-compatible URI is + // supported. + "store_uri": "s3://bucket/path/to/datacube.zarr", + // Optional variable name inside the store. If omitted, the store must contain a + // single data variable. + "data_variable": "reflectance", + // Required CRS of the cube, expressed as an EPSG code or WKT string. + "crs": "EPSG:32633", + // Required pixel size. Provide either a scalar (identical resolutions) or an object + // with explicit x and y values. + "pixel_size": 10, + // Required origin of pixel (0, 0) expressed as [min_x, max_y] in CRS units. + "origin": [500000.0, 4200000.0], + // Required mapping from conceptual axes to dimension names in the Zarr array. + "axis_names": {"x": "x", "y": "y", "time": "time", "band": "band"}, + // Required list of bands. The length must match the band dimension when present. + "bands": ["B02", "B03", "B04"], + // Required numpy dtype string that matches the underlying Zarr array. + "dtype": "float32", + // Optional nodata value applied when writing tiles and returned during direct reads. + "nodata": 0.0, + // Optional override for how the cube is broken into items. Each value is the number + // of pixels per chunk along that axis. + "chunk_shape": {"y": 1024, "x": 1024}, + // Optional fsspec storage options passed to xarray.open_zarr. + "storage_options": {"anon": true}, + // Optional flag toggling consolidated metadata support. Defaults to true. + "consolidated": true +} +``` + +The Zarr data source currently creates one item per time step. When you skip ingestion +(`"ingest": false` on the layer), the source acts as a read-only tile store so windows +can be materialized directly from the Zarr cube. + +======= ### rslearn.data_sources.worldcover.WorldCover This data source is for the ESA WorldCover 2021 land cover map. diff --git a/docs/examples/ZarrDataSource.md b/docs/examples/ZarrDataSource.md new file mode 100644 index 00000000..c9307992 --- /dev/null +++ b/docs/examples/ZarrDataSource.md @@ -0,0 +1,54 @@ +# Zarr Data Source Example + +The snippet below demonstrates how to reference a spatio-temporal Zarr cube from a +raster layer. Install the optional dependencies before running the dataset workflow: + +```bash +uv pip install -e ".[extra]" +``` + +Add a layer similar to the following in your dataset's `config.json`: + +```jsonc +"sentinel2": { + "type": "raster", + "bands": [ + { + "name": "B02", + "dtype": "float32", + "nodata": 0.0 + }, + { + "name": "B03", + "dtype": "float32", + "nodata": 0.0 + }, + { + "name": "B04", + "dtype": "float32", + "nodata": 0.0 + } + ], + "data_source": { + "name": "rslearn.data_sources.zarr.ZarrDataSource", + "store_uri": "s3://bucket/path/to/datacube.zarr", + "data_variable": "reflectance", + "crs": "EPSG:32633", + "pixel_size": 10, + "origin": [500000.0, 4200000.0], + "axis_names": {"x": "x", "y": "y", "time": "time", "band": "band"}, + "bands": ["B02", "B03", "B04"], + "dtype": "float32", + "nodata": 0.0, + "chunk_shape": {"y": 1024, "x": 1024}, + "storage_options": {"anon": true} + }, + // Set to false to stream directly from the cube instead of ingesting. + "ingest": true +} +``` + +When `ingest` is left at the default `true`, run `rslearn dataset ingest` to cache each +chunk into your tile store. If you flip `ingest` to `false`, `rslearn dataset +materialize` will read the necessary portions directly from the Zarr store instead. + diff --git a/pyproject.toml b/pyproject.toml index 6c64091b..42fbfc65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,12 +46,15 @@ extra = [ "planetary_computer>=1.0", "pycocotools>=2.0", "pystac_client>=0.9", + "rioxarray>=0.15", "rtree>=1.4", "s3fs==2025.3.0", "satlaspretrain_models>=0.3", "scipy>=1.16", "terratorch>=1.0.2", "transformers>=4.55", + "xarray>=2024.1", + "zarr>=2.17", "wandb>=0.21", ] diff --git a/rslearn/data_sources/__init__.py b/rslearn/data_sources/__init__.py index cf255046..64ff7f90 100644 --- a/rslearn/data_sources/__init__.py +++ b/rslearn/data_sources/__init__.py @@ -19,6 +19,7 @@ from rslearn.log_utils import get_logger from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource +from .zarr import ZarrDataSource, ZarrItem logger = get_logger(__name__) @@ -47,5 +48,7 @@ def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource: "Item", "ItemLookupDataSource", "RetrieveItemDataSource", + "ZarrDataSource", + "ZarrItem", "data_source_from_config", ) diff --git a/rslearn/data_sources/zarr.py b/rslearn/data_sources/zarr.py new file mode 100644 index 00000000..3f5decb6 --- /dev/null +++ b/rslearn/data_sources/zarr.py @@ -0,0 +1,668 @@ +"""Data source for reading spatio-temporal cubes stored in Zarr.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Iterable + +import numpy as np +import numpy.typing as npt +import shapely +from rasterio.crs import CRS +from rasterio.enums import Resampling + +from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig +from rslearn.dataset import Window +from rslearn.dataset.materialize import RasterMaterializer +from rslearn.log_utils import get_logger +from rslearn.tile_stores import TileStore, TileStoreWithLayer +from rslearn.utils.grid_index import GridIndex +from rslearn.utils.geometry import Projection, STGeometry, shp_intersects +from rslearn.utils.raster_format import get_transform_from_projection_and_bounds + +from .data_source import DataSource, Item +from .utils import match_candidate_items_to_window + +logger = get_logger(__name__) + + +def _import_zarr_deps() -> tuple[Any, Any]: + """Import dependencies required for interacting with Zarr stores.""" + + try: + import xarray as xr # type: ignore + except ImportError as exc: # pragma: no cover - import guard + raise ImportError( + "ZarrDataSource requires xarray; install rslearn with the 'extra' extra" + ) from exc + + try: + import zarr # noqa: F401 # type: ignore + except ImportError as exc: # pragma: no cover - import guard + raise ImportError( + "ZarrDataSource requires zarr; install rslearn with the 'extra' extra" + ) from exc + + return xr, None + + +def _ensure_utc(dt: datetime) -> datetime: + """Attach UTC timezone if missing and normalize to UTC.""" + + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def _to_datetime(value: Any) -> datetime: + """Convert various time coordinate values emitted by xarray to datetime.""" + + if isinstance(value, datetime): + return _ensure_utc(value) + + try: + import cftime # type: ignore + + if isinstance(value, cftime.datetime): + return _ensure_utc(value.to_datetime()) + except ImportError: # pragma: no cover - optional dependency + pass + + if isinstance(value, np.datetime64): + iso = np.datetime_as_string(value, timezone="UTC") + if iso.endswith("Z"): + iso = iso[:-1] + "+00:00" + dt = datetime.fromisoformat(iso) + return _ensure_utc(dt) + + if isinstance(value, str): + dt = datetime.fromisoformat(value) + return _ensure_utc(dt) + + raise TypeError(f"Unsupported time coordinate type: {type(value)!r}") + + +def _compute_time_boundaries(times: list[datetime], fallback: timedelta) -> list[datetime]: + """Compute interval boundaries for a monotonic list of timestamps.""" + + if len(times) == 0: + return [] + + # Ensure timestamps are sorted to avoid negative durations. + sorted_times = sorted(times) + boundaries: list[datetime] = [] + + if len(sorted_times) == 1: + half = fallback / 2 + boundaries.append(sorted_times[0] - half) + boundaries.append(sorted_times[0] + half) + return boundaries + + for idx in range(len(sorted_times) + 1): + if idx == 0: + delta = sorted_times[1] - sorted_times[0] + boundaries.append(sorted_times[0] - delta / 2) + elif idx == len(sorted_times): + delta = sorted_times[-1] - sorted_times[-2] + boundaries.append(sorted_times[-1] + delta / 2) + else: + prev_time = sorted_times[idx - 1] + next_time = sorted_times[idx] + boundaries.append(prev_time + (next_time - prev_time) / 2) + + return boundaries + + +class ZarrItem(Item): + """Represents a spatio-temporal chunk inside the Zarr data cube.""" + + x_range: tuple[int, int] + y_range: tuple[int, int] + time_range_indexes: tuple[int, int] | None + dim_slices: dict[str, tuple[int, int]] + x_offset: int + y_offset: int + + def __init__( + self, + name: str, + geometry: STGeometry, + x_range: tuple[int, int], + y_range: tuple[int, int], + time_range_indexes: tuple[int, int] | None, + dim_slices: dict[str, tuple[int, int]], + x_offset: int, + y_offset: int, + ) -> None: + super().__init__(name, geometry) + self.x_range = x_range + self.y_range = y_range + self.time_range_indexes = time_range_indexes + self.dim_slices = dim_slices + self.x_offset = x_offset + self.y_offset = y_offset + + @property + def pixel_bounds(self) -> tuple[int, int, int, int]: + """Return bounds in pixel coordinates (x0, y0, x1, y1).""" + + return ( + self.x_offset + self.x_range[0], + self.y_offset + self.y_range[0], + self.x_offset + self.x_range[1], + self.y_offset + self.y_range[1], + ) + + def serialize(self) -> dict: + """Serialize this item for storing inside a window.""" + + data = super().serialize() + data.update( + { + "x_range": list(self.x_range), + "y_range": list(self.y_range), + "x_offset": self.x_offset, + "y_offset": self.y_offset, + "time_range_indexes": list(self.time_range_indexes) + if self.time_range_indexes + else None, + "dim_slices": { + dim: list(range_pair) for dim, range_pair in self.dim_slices.items() + }, + } + ) + return data + + @staticmethod + def deserialize(data: dict) -> "ZarrItem": + """Deserialize a serialized ZarrItem.""" + + base_item = Item.deserialize(data) + dim_slices = { + dim: (range_pair[0], range_pair[1]) for dim, range_pair in data["dim_slices"].items() + } + time_range_indexes = None + if data["time_range_indexes"] is not None: + time_range_indexes = ( + data["time_range_indexes"][0], + data["time_range_indexes"][1], + ) + return ZarrItem( + name=base_item.name, + geometry=base_item.geometry, + x_range=(data["x_range"][0], data["x_range"][1]), + y_range=(data["y_range"][0], data["y_range"][1]), + time_range_indexes=time_range_indexes, + dim_slices=dim_slices, + x_offset=data.get("x_offset", 0), + y_offset=data.get("y_offset", 0), + ) + + +class ZarrDataSource(DataSource[ZarrItem], TileStore): + """DataSource for reading raster cubes stored in a Zarr hierarchy.""" + + DEFAULT_SINGLE_TIME_INTERVAL = timedelta(hours=1) + + def __init__( + self, + *, + store_uri: str, + data_variable: str, + projection: Projection, + band_names: list[str], + dtype: np.dtype, + nodata: float | int | None, + axis_names: dict[str, str], + items: list[ZarrItem], + time_boundaries: list[datetime] | None, + storage_options: dict[str, Any], + consolidated: bool, + ) -> None: + self.store_uri = store_uri + self.data_variable = data_variable + self.projection = projection + self.band_names = band_names + self.dtype = dtype + self.nodata = nodata + self.axis_names = axis_names + self._items = items + self.storage_options = storage_options + self.consolidated = consolidated + self._time_boundaries = time_boundaries + + self._item_by_name = {item.name: item for item in self._items} + grid_size = 1 + for item in self._items: + width = item.x_range[1] - item.x_range[0] + height = item.y_range[1] - item.y_range[0] + grid_size = max(grid_size, min(width, height)) + + self._spatial_index = GridIndex(size=grid_size) + for item in self._items: + self._spatial_index.insert(item.geometry.shp.bounds, item) + + self._data_array = None + + @staticmethod + def from_config(config: LayerConfig, ds_path: Any) -> "ZarrDataSource": + """Create a ZarrDataSource from a RasterLayer configuration.""" + + if not isinstance(config, RasterLayerConfig): + raise ValueError("ZarrDataSource requires a raster layer") + if config.data_source is None: + raise ValueError("data_source configuration required for ZarrDataSource") + + cfg = config.data_source.config_dict + + required_keys = [ + "store_uri", + "axis_names", + "pixel_size", + "origin", + "bands", + "crs", + "dtype", + ] + for key in required_keys: + if key not in cfg: + raise ValueError(f"Missing required Zarr data source config key: {key}") + + axis_names = cfg["axis_names"] + x_dim = axis_names.get("x") + y_dim = axis_names.get("y") + if not x_dim or not y_dim: + raise ValueError("axis_names must map 'x' and 'y' dimensions") + + time_dim = axis_names.get("time") + band_dim = axis_names.get("band") + + pixel_size = cfg["pixel_size"] + if isinstance(pixel_size, dict): + x_resolution = float(pixel_size.get("x")) + y_resolution = -float(pixel_size.get("y")) + else: + x_resolution = float(pixel_size) + y_resolution = -float(pixel_size) + + origin = cfg["origin"] + if not isinstance(origin, (list, tuple)) or len(origin) != 2: + raise ValueError("origin must be a two element array [x_min, y_max]") + + projection = Projection(CRS.from_string(cfg["crs"]), x_resolution, y_resolution) + + band_names = cfg["bands"] + if not isinstance(band_names, list) or len(band_names) == 0: + raise ValueError("bands must be a non-empty list") + + dtype = np.dtype(cfg["dtype"]) + nodata = cfg.get("nodata") + + data_variable = cfg.get("data_variable") + storage_options = cfg.get("storage_options") + consolidated = cfg.get("consolidated", True) + + xr, _ = _import_zarr_deps() + open_kwargs = dict(consolidated=consolidated) + if storage_options is not None: + open_kwargs["storage_options"] = storage_options + dataset = xr.open_zarr(cfg["store_uri"], **open_kwargs) + if data_variable is None: + if len(dataset.data_vars) != 1: + raise ValueError( + "data_variable must be specified when Zarr store has multiple variables" + ) + data_variable = next(iter(dataset.data_vars)) + data_array = dataset[data_variable] + + time_chunk_size = cfg.get("time_chunk_size", 1) + chunk_shape_cfg = cfg.get("chunk_shape", {}) + + if band_dim and band_dim in data_array.dims: + band_size = int(data_array.sizes[band_dim]) + if band_size != len(band_names): + raise ValueError( + "Configured bands do not match Zarr band dimension size" + ) + elif len(band_names) != 1: + raise ValueError( + "Zarr data without an explicit band dimension must configure exactly one band" + ) + + def _dimension_chunk(dim: str, default: int) -> int: + if dim in chunk_shape_cfg: + return int(chunk_shape_cfg[dim]) + if getattr(data_array, "chunks", None) and data_array.chunksizes.get(dim): + return int(data_array.chunksizes[dim][0]) + return default + + y_size = int(data_array.sizes[y_dim]) + x_size = int(data_array.sizes[x_dim]) + y_chunk = max(1, min(_dimension_chunk(y_dim, y_size), y_size)) + x_chunk = max(1, min(_dimension_chunk(x_dim, x_size), x_size)) + + time_boundaries: list[datetime] | None = None + time_indexes: Iterable[int] = [0] + if time_dim in data_array.dims: + time_size = int(data_array.sizes[time_dim]) + if time_chunk_size != 1: + logger.warning( + "ZarrDataSource currently treats one time-step per item; " + "ignoring time_chunk_size=%s", time_chunk_size + ) + times = [_to_datetime(value) for value in data_array[time_dim].values] + fallback = ZarrDataSource.DEFAULT_SINGLE_TIME_INTERVAL + if len(times) >= 2: + first_delta = times[1] - times[0] + if first_delta.total_seconds() > 0: + fallback = first_delta + time_boundaries = _compute_time_boundaries(times, fallback) + time_indexes = range(time_size) + else: + time_dim = None + + items: list[ZarrItem] = [] + origin_x, origin_y = float(origin[0]), float(origin[1]) + x_offset = int(round(origin_x / x_resolution)) + y_offset = int(round(origin_y / y_resolution)) + + x_ranges = [ + (x_start, min(x_start + x_chunk, x_size)) + for x_start in range(0, x_size, x_chunk) + ] + y_ranges = [ + (y_start, min(y_start + y_chunk, y_size)) + for y_start in range(0, y_size, y_chunk) + ] + + for t_index in time_indexes: + for y_range in y_ranges: + for x_range in x_ranges: + if time_dim is None: + time_range = None + time_index_range = None + else: + assert time_boundaries is not None + start = time_boundaries[t_index] + end = time_boundaries[t_index + 1] + time_range = (start, end) + time_index_range = (t_index, t_index + 1) + + geometry = STGeometry( + projection, + shapely.box( + x_offset + x_range[0], + y_offset + y_range[0], + x_offset + x_range[1], + y_offset + y_range[1], + ), + time_range, + ) + name = f"t{t_index}_y{y_range[0]}_{y_range[1]}_x{x_range[0]}_{x_range[1]}" + dim_slices = { + y_dim: y_range, + x_dim: x_range, + } + if time_dim: + dim_slices[time_dim] = time_index_range + if band_dim and band_dim in data_array.dims: + dim_slices[band_dim] = (0, len(band_names)) + + items.append( + ZarrItem( + name=name, + geometry=geometry, + x_range=x_range, + y_range=y_range, + time_range_indexes=time_index_range, + dim_slices=dim_slices, + x_offset=x_offset, + y_offset=y_offset, + ) + ) + + return ZarrDataSource( + store_uri=cfg["store_uri"], + data_variable=data_variable, + projection=projection, + band_names=band_names, + dtype=dtype, + nodata=nodata, + axis_names={ + "x": x_dim, + "y": y_dim, + "time": time_dim, + "band": band_dim, + }, + items=items, + time_boundaries=time_boundaries, + storage_options=storage_options, + consolidated=consolidated, + ) + + # ------------------------------------------------------------------ + # DataSource interface + # ------------------------------------------------------------------ + + def _load_data_array(self): + if self._data_array is not None: + return self._data_array + + xr, _ = _import_zarr_deps() + dataset = xr.open_zarr( + self.store_uri, + consolidated=self.consolidated, + storage_options=self.storage_options, + ) + data_array = dataset[self.data_variable] + self._data_array = data_array + return data_array + + def _iter_candidate_items(self, geometry: STGeometry) -> list[ZarrItem]: + geometry_in_projection = geometry.to_projection(self.projection) + candidates = self._spatial_index.query(geometry_in_projection.shp.bounds) + filtered: list[ZarrItem] = [] + for item in candidates: + item_geom = item.geometry + if not shp_intersects(item_geom.shp, geometry_in_projection.shp): + continue + if geometry.time_range and not geometry.intersects_time_range( + item_geom.time_range + ): + continue + filtered.append(item) + return filtered + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[ZarrItem]]]: + groups: list[list[list[ZarrItem]]] = [] + for geometry in geometries: + if geometry.time_range is not None: + geometry = STGeometry( + geometry.projection, + geometry.shp, + ( + _ensure_utc(geometry.time_range[0]), + _ensure_utc(geometry.time_range[1]), + ), + ) + + candidates = self._iter_candidate_items(geometry) + geometry_in_projection = geometry.to_projection(self.projection) + cur_groups = match_candidate_items_to_window( + geometry_in_projection, + candidates, + query_config, + ) + groups.append(cur_groups) + return groups + + def deserialize_item(self, serialized_item: Any) -> ZarrItem: + return ZarrItem.deserialize(serialized_item) + + def _read_item_array(self, item: ZarrItem) -> npt.NDArray[Any]: + data_array = self._load_data_array() + indexers: dict[str, slice] = {} + for dim, range_pair in item.dim_slices.items(): + indexers[dim] = slice(range_pair[0], range_pair[1]) + + selected = data_array.isel(**indexers) + + time_dim = self.axis_names.get("time") + if time_dim and time_dim in selected.dims: + selected = selected.squeeze(dim=time_dim, drop=True) + + band_dim = self.axis_names.get("band") + if band_dim and band_dim in selected.dims: + selected = selected.transpose(band_dim, self.axis_names["y"], self.axis_names["x"]) + array = selected.values + else: + selected = selected.transpose(self.axis_names["y"], self.axis_names["x"]) + array = selected.values[None, :, :] + + return np.asarray(array, dtype=self.dtype, order="C") + + def ingest( + self, + tile_store: TileStoreWithLayer, + items: list[ZarrItem], + geometries: list[list[STGeometry]], + ) -> None: + for item in items: + if tile_store.is_raster_ready(item.name, self.band_names): + continue + + array = self._read_item_array(item) + tile_store.write_raster( + item.name, + self.band_names, + self.projection, + item.pixel_bounds, + array, + ) + + def materialize( + self, + window: Window, + item_groups: list[list[ZarrItem]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + if not isinstance(layer_cfg, RasterLayerConfig): + raise ValueError("ZarrDataSource only supports raster materialization") + RasterMaterializer().materialize( + TileStoreWithLayer(self, layer_name), + window, + layer_name, + layer_cfg, + item_groups, + ) + + # ------------------------------------------------------------------ + # TileStore interface (read-only) + # ------------------------------------------------------------------ + + def is_raster_ready( + self, layer_name: str, item_name: str, bands: list[str] + ) -> bool: + if bands != self.band_names: + return False + return item_name in self._item_by_name + + def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]: + if item_name not in self._item_by_name: + return [] + return [self.band_names] + + def get_raster_bounds( + self, layer_name: str, item_name: str, bands: list[str], projection: Projection + ) -> tuple[int, int, int, int]: + if item_name not in self._item_by_name: + raise ValueError(f"Unknown item {item_name}") + item = self._item_by_name[item_name] + geom = item.geometry.to_projection(projection) + bounds = geom.shp.bounds + return ( + int(math.floor(bounds[0])), + int(math.floor(bounds[1])), + int(math.ceil(bounds[2])), + int(math.ceil(bounds[3])), + ) + + def read_raster( + self, + layer_name: str, + item_name: str, + bands: list[str], + projection: Projection, + bounds: tuple[int, int, int, int], + resampling: Resampling = Resampling.bilinear, + ) -> npt.NDArray[Any]: + if bands != self.band_names: + raise ValueError( + f"ZarrDataSource stores bands {self.band_names}, requested {bands}" + ) + + if item_name not in self._item_by_name: + raise ValueError(f"Unknown item {item_name}") + + item = self._item_by_name[item_name] + + request_geometry = STGeometry(projection, shapely.box(*bounds), None) + request_in_native = request_geometry.to_projection(self.projection) + + intersection = request_in_native.shp.intersection(item.geometry.shp) + if intersection.is_empty: + height = bounds[3] - bounds[1] + width = bounds[2] - bounds[0] + fill_value = self.nodata if self.nodata is not None else 0 + return np.full((len(self.band_names), height, width), fill_value, self.dtype) + + read_bounds = ( + math.floor(intersection.bounds[0]), + math.floor(intersection.bounds[1]), + math.ceil(intersection.bounds[2]), + math.ceil(intersection.bounds[3]), + ) + + crop = self._read_item_array(item) + x0 = read_bounds[0] - item.pixel_bounds[0] + x1 = x0 + (read_bounds[2] - read_bounds[0]) + y0 = read_bounds[1] - item.pixel_bounds[1] + y1 = y0 + (read_bounds[3] - read_bounds[1]) + crop = crop[:, y0:y1, x0:x1] + + if self.projection == projection and read_bounds == bounds: + return crop + + src_transform = get_transform_from_projection_and_bounds( + self.projection, read_bounds + ) + dst_transform = get_transform_from_projection_and_bounds(projection, bounds) + dst_array = np.full( + (len(self.band_names), bounds[3] - bounds[1], bounds[2] - bounds[0]), + self.nodata if self.nodata is not None else 0, + dtype=self.dtype, + ) + + import rasterio.warp + + rasterio.warp.reproject( + source=crop, + src_crs=self.projection.crs, + src_transform=src_transform, + destination=dst_array, + dst_crs=projection.crs, + dst_transform=dst_transform, + resampling=resampling, + src_nodata=self.nodata, + dst_nodata=self.nodata, + ) + + return dst_array + + +__all__ = ["ZarrDataSource", "ZarrItem"] diff --git a/tests/unit/data_sources/test_zarr.py b/tests/unit/data_sources/test_zarr.py new file mode 100644 index 00000000..bb35fa28 --- /dev/null +++ b/tests/unit/data_sources/test_zarr.py @@ -0,0 +1,164 @@ +"""Tests for the Zarr data source.""" + +from __future__ import annotations + +from datetime import datetime, timedelta + +import numpy as np +import pytest +import shapely.geometry +from rasterio.crs import CRS +from upath import UPath + +from rslearn.config.dataset import RasterLayerConfig +from rslearn.tile_stores import DefaultTileStore, TileStoreWithLayer +from rslearn.utils.geometry import Projection, STGeometry + + +@pytest.fixture() +def sample_zarr_store(tmp_path): + """Create a small Zarr cube suitable for testing.""" + + xr = pytest.importorskip("xarray") + pytest.importorskip("zarr") + + times = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[ns]") + bands = ["B02", "B03", "B04"] + y = np.arange(4) + x = np.arange(4) + + data = np.zeros((len(times), len(bands), len(y), len(x)), dtype=np.float32) + for ti in range(len(times)): + for bi in range(len(bands)): + data[ti, bi, :, :] = ti * 10.0 + float(bi) + + array = xr.DataArray( + data, + coords={"time": times, "band": bands, "y": y, "x": x}, + dims=("time", "band", "y", "x"), + name="reflectance", + ) + dataset = array.to_dataset() + store_path = tmp_path / "cube.zarr" + dataset.to_zarr(store_path, mode="w") + return store_path + + +def build_layer_config(store_path) -> RasterLayerConfig: + """Construct a raster layer configuration for the test cube.""" + + config_dict = { + "type": "raster", + "band_sets": [ + { + "dtype": "float32", + "bands": ["B02", "B03", "B04"], + "nodata_vals": [0.0, 0.0, 0.0], + } + ], + "data_source": { + "name": "rslearn.data_sources.zarr.ZarrDataSource", + "store_uri": str(store_path), + "data_variable": "reflectance", + "crs": "EPSG:32633", + "pixel_size": 1, + "origin": [0.0, 0.0], + "axis_names": { + "x": "x", + "y": "y", + "time": "time", + "band": "band", + }, + "bands": ["B02", "B03", "B04"], + "dtype": "float32", + "nodata": 0.0, + "chunk_shape": {"y": 4, "x": 4}, + }, + } + return RasterLayerConfig.from_config(config_dict) + + +def test_zarr_ingest_and_read(tmp_path, sample_zarr_store): + """Zarr data source should ingest chunks and support direct reads.""" + + layer_cfg = build_layer_config(sample_zarr_store) + + from rslearn.data_sources.zarr import ZarrDataSource + + data_source = ZarrDataSource.from_config( + layer_cfg, UPath(tmp_path / "dataset") + ) + + projection = Projection(CRS.from_epsg(32633), 1, -1) + window_time = ( + datetime(2024, 1, 1), + datetime(2024, 1, 1) + timedelta(hours=1), + ) + window_geom = STGeometry( + projection, + shapely.geometry.box(0, 0, 4, 4), + window_time, + ) + + groups = data_source.get_items( + [window_geom], layer_cfg.data_source.query_config + ) + assert len(groups) == 1 + assert len(groups[0]) == 1 + assert len(groups[0][0]) == 1 + item = groups[0][0][0] + + # Ensure serialization round-trips. + reconstructed = data_source.deserialize_item(item.serialize()) + assert reconstructed.pixel_bounds == item.pixel_bounds + + layer_name = "zarr_layer" + tile_store = DefaultTileStore() + tile_store.set_dataset_path(UPath(tmp_path / "tiles")) + + data_source.ingest( + TileStoreWithLayer(tile_store, layer_name), + [item], + [[window_geom]], + ) + + bands = ["B02", "B03", "B04"] + assert tile_store.is_raster_ready(layer_name, item.name, bands) + ingested = tile_store.read_raster( + layer_name, + item.name, + bands, + data_source.projection, + item.pixel_bounds, + ) + assert ingested.shape == (3, 4, 4) + assert np.allclose(ingested[0], 0.0) + assert np.allclose(ingested[1], 1.0) + assert np.allclose(ingested[2], 2.0) + + # Direct reads without ingestion should also succeed on partial bounds. + partial_bounds = ( + item.pixel_bounds[0], + item.pixel_bounds[1], + item.pixel_bounds[0] + 2, + item.pixel_bounds[1] + 2, + ) + direct = data_source.read_raster( + layer_name, + item.name, + bands, + data_source.projection, + partial_bounds, + ) + assert direct.shape == (3, 2, 2) + assert np.allclose(direct[0], 0.0) + assert np.allclose(direct[1], 1.0) + assert np.allclose(direct[2], 2.0) + + # Tile store metadata helpers should reflect the item structure. + assert data_source.is_raster_ready(layer_name, item.name, bands) + assert data_source.get_raster_bands(layer_name, item.name) == [bands] + assert ( + data_source.get_raster_bounds(layer_name, item.name, bands, data_source.projection) + == item.pixel_bounds + ) diff --git a/uv.lock b/uv.lock index 31aa7ff1..f4d3063c 100644 --- a/uv.lock +++ b/uv.lock @@ -5438,4 +5438,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, -] +] \ No newline at end of file