diff --git a/src/harpy/shape/__init__.pyi b/src/harpy/shape/__init__.pyi
index 99ffd36..7546a93 100644
--- a/src/harpy/shape/__init__.pyi
+++ b/src/harpy/shape/__init__.pyi
@@ -1,5 +1,6 @@
 from ._cell_expansion import create_voronoi_boundaries
-from ._shape import add_shapes_layer, filter_shapes_layer, intersect_rectangles, vectorize
+from ._filters import filter_by_morphology, filter_by_shapes
+from ._shape import add_shapes_layer, filter_shapes_layer, intersect_rectangles, prep_region_annotations, vectorize
 
 __all__ = [
     "add_shapes_layer",
@@ -7,4 +8,7 @@ __all__ = [
     "filter_shapes_layer",
     "intersect_rectangles",
     "vectorize",
+    "prep_region_annotations",
+    "filter_by_morphology",
+    "filter_by_shapes",
 ]
diff --git a/src/harpy/shape/_filters.py b/src/harpy/shape/_filters.py
new file mode 100755
index 0000000..f7f8a5c
--- /dev/null
+++ b/src/harpy/shape/_filters.py
@@ -0,0 +1,367 @@
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+from shapely.geometry import MultiPolygon, Polygon
+from spatialdata import SpatialData
+
+from harpy.shape._shape import add_shapes_layer
+
+
+def filter_by_morphology(
+    sdata: SpatialData,
+    shapes_layer: str,
+    output_shapes_layer: str,
+    filters: dict[str, tuple] | None = None,
+    keep_unsupported: bool = False,
+    calculate_all_features: bool = False,
+    calculate_all_features_grouped: bool = False,
+    shape_names_column: str = "name",
+    pixel_size_um: float = 1.0,
+    overwrite: bool = True,
+):
+    """
+    Filter polygons in a SpatialData shapes layer based on morphological features.
+
+    It is recommended to run `hp.sh.prep_region_annotations` first when filtering region annotations that
+    may contain multipolygons, unnamed annotations, etc.
+
+    Parameters
+    ----------
+    sdata
+        SpatialData object containing the input shapes layer.
+    shapes_layer
+        Name of the shapes layer in `sdata.shapes` containing polygons.
+    output_shapes_layer
+        Name of the shapes layer to store the filtered polygons.
+    filters
+        Dictionary specifying filtering thresholds.
+        Each entry should be of the form:
+        `{ "feature_name": (min_value, max_value) }`
+        where `None` can be used to skip a bound, e.g.:
+        `{ "area": (50, 200), "convexity": (None, 0.8) }`
+
+        Supported features:
+        - `"area"`: Area of the polygon (px²).
+          → Use to filter out very small or very large polygons.
+        - `"perimeter"`: Perimeter length (px).
+          → Use to detect irregular boundaries or fragmented shapes.
+        - `"circularity"`: 4π * area / perimeter².
+          → 1 for a perfect circle; lower = irregular shape.
+        - `"compactness"`: perimeter² / area.
+          → Inverse compactness measure; higher values indicate less compact shapes.
+        - `"convex_area"`: Area of convex hull (px²).
+          → Useful with solidity and convexity.
+        - `"solidity"`: area / convex_area.
+          → Low values mean concave or fragmented shapes.
+        - `"convexity"`: convex_perimeter / perimeter.
+          → 1 for perfectly convex shapes. Lower for rough or spiky boundaries.
+        - `"centroid_dif"`: Distance between the polygon and convex hull centroids, normalized by the square root of the polygon area.
+          → Captures off-centered concavity or asymmetry.
+        - `"major_axis_length"`: Length of the longest side of the minimum rotated bounding box (px).
+          → Use to filter very long or very short shapes.
+        - `"minor_axis_length"`: Length of the shortest side of the minimum rotated bounding box (px).
+          → Use to filter very wide or very skinny shapes.
+        - `"major_minor_axis_ratio"`: major_axis_length / minor_axis_length.
+          → High values indicate elongated polygons; ~1 for round shapes.
+
+        You can use grouped filters by suffixing features with `-grouped`.
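+        For example (illustrative thresholds), `{ "area": (50, None), "solidity-grouped": (0.9, None) }`
+        filters on the area of each individual polygon and on the solidity of its merged group.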
+        Grouped filters merge polygons sharing the same name in `shape_names_column` before computing morphological features and
+        can be used interchangeably with regular filters.
+    keep_unsupported
+        Only Polygon and MultiPolygon are supported. Set keep_unsupported to True to skip any unsupported geometry types (e.g. Point), but
+        keep them in the output_shapes_layer. Set to False to remove them from the output_shapes_layer.
+    calculate_all_features
+        If True, computes all supported morphological features regardless of which ones are used in `filters` and saves them to `output_shapes_layer`.
+        If False, only computes and saves morphological features needed for `filters`.
+    calculate_all_features_grouped
+        If True, computes all morphological features for merged group geometries (by {shape_names_column}), regardless
+        of which ones are used in grouped filters.
+    shape_names_column
+        Column name in shapes layer containing geometry names. Required when using grouped filters.
+    pixel_size_um
+        Scale factor to convert geometric measurements from pixel units to microns.
+        - Applied to:
+          * "area" and "convex_area" → scaled by (pixel_size_um)²
+          * "perimeter", "major_axis_length", and "minor_axis_length" → scaled by (pixel_size_um)
+        - Dimensionless ratios (e.g., "circularity", "compactness", "solidity", "convexity", "major_minor_axis_ratio")
+          are unaffected by scaling but are computed using the scaled geometric quantities.
+        Defaults to 1.0 (no scaling, i.e. units remain in pixels).
+        Note that this affects the min/max values that need to be specified in `filters`.
+    overwrite
+        Whether to overwrite an existing shapes layer.
+
+    Returns
+    -------
+    SpatialData object with updated shapes layer containing only the filtered polygons.
+    """
+    # Get filters and split into individual and grouped filters
+    filters = filters or {}
+
+    grouped_filters = {k.replace("-grouped", ""): v for k, v in filters.items() if k.endswith("-grouped")}
+    individual_filters = {k: v for k, v in filters.items() if not k.endswith("-grouped")}
+
+    required_individual = set(individual_filters.keys())
+    required_grouped = set(grouped_filters.keys())
+
+    if calculate_all_features:
+        required_individual.update(
+            [
+                "area",
+                "perimeter",
+                "convex_area",
+                "circularity",
+                "compactness",
+                "solidity",
+                "convexity",
+                "centroid_dif",
+                "major_axis_length",
+                "minor_axis_length",
+                "major_minor_axis_ratio",
+            ]
+        )
+        print("Calculating all supported morphological features (per polygon).")
+
+    if calculate_all_features_grouped:
+        required_grouped.update(
+            [
+                "area",
+                "perimeter",
+                "convex_area",
+                "circularity",
+                "compactness",
+                "solidity",
+                "convexity",
+                "centroid_dif",
+                "major_axis_length",
+                "minor_axis_length",
+                "major_minor_axis_ratio",
+            ]
+        )
+        print("Calculating all supported morphological features (grouped).")
+
+    # Create copy of shapes layer
+    gdf = sdata.shapes[shapes_layer].copy()
+
+    # Filter out geometries that are not Polygon or MultiPolygon
+    supported_mask = gdf.geometry.apply(lambda x: isinstance(x, (Polygon, MultiPolygon)))
+
+    unsupported = len(gdf) - supported_mask.sum()
+    if unsupported > 0:
+        print(f"Found {unsupported} non-polygon geometries in {shapes_layer}.")
+        if unsupported == len(gdf):
+            print("No supported geometries (Polygon, MultiPolygon) found. Exiting...")
+            return sdata
+
+    if keep_unsupported:
+        gdf_skipped = gdf[~supported_mask].copy()
+    else:
+        gdf_skipped = gpd.GeoDataFrame(columns=gdf.columns, geometry=[])
+
+    gdf = gdf[supported_mask].copy()
+
+    # Dissolve polygons by name
+    grouped_gdf = None
+    if grouped_filters or calculate_all_features_grouped:
+        if shape_names_column not in gdf.columns:
+            raise ValueError(f"Grouped filters require a '{shape_names_column}' column in the shapes layer.")
+
+        grouped_gdf = gdf.dissolve(by=shape_names_column, as_index=False, aggfunc="first").copy()
+        grouped_gdf["geometry"] = gdf.groupby(shape_names_column).geometry.apply(lambda x: x.unary_union).values
+        print(f"Polygons merged by {shape_names_column}, resulting in {len(grouped_gdf)} groups.")
+
+    # Compute metrics
+    def _compute_morphological_features(gdf, required, pixel_size_um, suffix):
+        if any(feature in required for feature in ["area", "circularity", "compactness", "solidity", "centroid_dif"]):
+            gdf[f"area{suffix}"] = gdf.geometry.area * pixel_size_um**2
+
+        if any(feature in required for feature in ["perimeter", "circularity", "compactness", "convexity"]):
+            gdf[f"perimeter{suffix}"] = gdf.geometry.length * pixel_size_um
+
+        if any(feature in required for feature in ["convex_area", "solidity"]):
+            gdf[f"convex_area{suffix}"] = gdf.geometry.convex_hull.area * pixel_size_um**2
+
+        if "circularity" in required:
+            gdf[f"circularity{suffix}"] = 4 * np.pi * gdf[f"area{suffix}"] / (gdf[f"perimeter{suffix}"] ** 2)
+
+        if "compactness" in required:
+            gdf[f"compactness{suffix}"] = (gdf[f"perimeter{suffix}"] ** 2) / gdf[f"area{suffix}"]
+
+        if "solidity" in required:
+            gdf[f"solidity{suffix}"] = gdf[f"area{suffix}"] / gdf[f"convex_area{suffix}"]
+
+        if "centroid_dif" in required:
+            gdf[f"centroid_x{suffix}"] = gdf.geometry.centroid.x
+            gdf[f"centroid_y{suffix}"] = gdf.geometry.centroid.y
+            hull_centroids = gdf.geometry.convex_hull.centroid
+            gdf[f"centroid_dif{suffix}"] = np.sqrt(
+                (gdf[f"centroid_x{suffix}"] - hull_centroids.x) ** 2
+                + (gdf[f"centroid_y{suffix}"] - hull_centroids.y) ** 2
+            ) / np.sqrt(gdf[f"area{suffix}"])
+
+        if "convexity" in required:
+            gdf[f"convexity{suffix}"] = (gdf.geometry.convex_hull.length * pixel_size_um) / gdf[f"perimeter{suffix}"]
+
+        # Get bounding box lengths from minimum rotated rectangle
+        if any(feature in required for feature in ["major_axis_length", "minor_axis_length", "major_minor_axis_ratio"]):
+
+            def _rotated_rect_axes(geom):
+                try:
+                    rect = geom.minimum_rotated_rectangle
+                    x, y = rect.exterior.coords.xy
+                    # Compute the 4 edge lengths; opposite sides of the rectangle are equal
+                    edges = np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)
+                    edges = np.sort(edges)
+                    minor, major = edges[0], edges[-1]
+                    ratio = major / minor if minor > 0 else np.nan
+                    return major, minor, ratio
+                except Exception:
+                    return np.nan, np.nan, np.nan
+
+            results = np.array([_rotated_rect_axes(g) for g in gdf.geometry])
+            major_axis_length, minor_axis_length, major_minor_axis_ratio = results.T
+            if any(feature in required for feature in ["major_axis_length", "major_minor_axis_ratio"]):
+                gdf[f"major_axis_length{suffix}"] = major_axis_length * pixel_size_um
+            if any(feature in required for feature in ["minor_axis_length", "major_minor_axis_ratio"]):
+                gdf[f"minor_axis_length{suffix}"] = minor_axis_length * pixel_size_um
+            if "major_minor_axis_ratio" in required:
+                gdf[f"major_minor_axis_ratio{suffix}"] = major_minor_axis_ratio
+
+        return gdf
+
+    gdf = _compute_morphological_features(gdf, required_individual, pixel_size_um, "")
+    if grouped_gdf is not None:
+        grouped_gdf = _compute_morphological_features(grouped_gdf, required_grouped, pixel_size_um, "-grouped")
+
+    # Merge grouped metrics back into main gdf
+    if grouped_gdf is not None and not grouped_gdf.empty:
+        grouped_metrics = [col for col in grouped_gdf.columns if col.endswith("-grouped")]
+        gdf = gdf.merge(grouped_gdf[[shape_names_column] + grouped_metrics], on=shape_names_column, how="left")
+
+    # Apply filters
+    print(f"Applying morphological filter(s) on {len(gdf)} polygons...")
+    mask = np.ones(len(gdf), dtype=bool)
+    if filters:
+        for feature, (min_val, max_val) in filters.items():
+            print(f"\nFiltering by '{feature}': {min_val} ≤ value ≤ {max_val}")
+            if feature not in gdf.columns:
+                raise KeyError(f"Feature '{feature}' was not computed. Check your spelling or the supported feature list.")
+
+            # Apply lower bound
+            if min_val is not None:
+                to_remove = (gdf[feature] < min_val) & mask
+                removed_low = to_remove.sum()
+                mask &= gdf[feature] >= min_val
+                print(f"  - Removed {removed_low} polygons with {feature} < {min_val}")
+
+            # Apply upper bound
+            if max_val is not None:
+                to_remove = (gdf[feature] > max_val) & mask
+                removed_high = to_remove.sum()
+                mask &= gdf[feature] <= max_val
+                print(f"  - Removed {removed_high} polygons with {feature} > {max_val}")
+
+            remaining = mask.sum()
+            print(f"  → Remaining after filtering '{feature}': {remaining} polygons")
+
+    filtered_gdf = gdf[mask]
+    print(f"\nKept {len(filtered_gdf)} / {len(gdf)} polygons after morphological filters.")
+
+    # Add filtered shapes layer
+    input_gdf = pd.concat([filtered_gdf, gdf_skipped], ignore_index=False)
+    sdata = add_shapes_layer(
+        sdata,
+        input=input_gdf,
+        output_layer=output_shapes_layer,
+        overwrite=overwrite,
+    )
+
+    return sdata
+
+
+def filter_by_shapes(
+    sdata: SpatialData,
+    target_shapes_layer: str,
+    mask_shapes_layer: str,
+    output_shapes_layer: str,
+    shape_names_column: str | None = None,
+    shape_names: str | list[str] | None = None,
+    keep_intersecting: bool = True,
+    overwrite: bool = False,
+):
+    """
+    Filter polygons in a target shapes layer (typically containing segmentation boundaries) based on intersection with polygons in a mask layer
+    (typically containing region annotations).
+
+    Parameters
+    ----------
+    sdata
+        SpatialData object containing the target and mask shapes layers.
+    target_shapes_layer
+        Name of shapes layer whose polygons will be filtered.
+    mask_shapes_layer
+        Name of shapes layer used as mask for filtering.
+    output_shapes_layer
+        Name of the output shapes layer to store filtered polygons.
+    shape_names_column
+        Optional column in mask_shapes_layer to select specific polygons.
+    shape_names
+        Name or list of names of polygons in mask_shapes_layer to use for filtering. Ignored if shape_names_column is None.
+    keep_intersecting
+        If True, keeps polygons that intersect the mask.
+        If False, removes polygons that intersect the mask.
+    overwrite
+        If True, overwrites the output shapes layer if it exists.
+
+    Returns
+    -------
+    SpatialData object with filtered shapes layer.
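+
+    Examples
+    --------
+    A minimal sketch (layer and region names are illustrative), keeping only segmentation
+    boundaries that intersect a region annotation named "tumor":
+
+    >>> sdata = filter_by_shapes(
+    ...     sdata,
+    ...     target_shapes_layer="segmentation_mask_boundaries",
+    ...     mask_shapes_layer="region_annotations",
+    ...     output_shapes_layer="segmentation_mask_boundaries_tumor",
+    ...     shape_names_column="name",
+    ...     shape_names="tumor",
+    ...     keep_intersecting=True,
+    ...     overwrite=True,
+    ... )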
+ """ + # Copy target layer + target_gdf = sdata.shapes[target_shapes_layer].copy() + + # Copy mask layer + mask_gdf = sdata.shapes[mask_shapes_layer].copy() + + # Optionally select subset of mask polygons + if shape_names_column is not None and shape_names is not None: + if shape_names_column not in mask_gdf.columns: + raise ValueError(f"Column '{shape_names_column}' not found in mask layer '{mask_shapes_layer}'.") + if isinstance(shape_names, str): + shape_names = [shape_names] + mask_gdf = mask_gdf[mask_gdf[shape_names_column].isin(shape_names)].copy() + if mask_gdf.empty: + raise ValueError(f"No geometries found in '{mask_shapes_layer}' matching {shape_names}.") + + # Build mask union + mask_union = mask_gdf.unary_union + + # Compute intersection boolean + target_gdf["intersects_mask"] = target_gdf.geometry.intersects(mask_union) + + # Filter depending on mode + if keep_intersecting: + filtered_gdf = target_gdf[target_gdf["intersects_mask"]].copy() + else: + filtered_gdf = target_gdf[~target_gdf["intersects_mask"]].copy() + + filtered_gdf.drop(columns=["intersects_mask"], inplace=True) + removed = len(target_gdf) - len(filtered_gdf) + + if keep_intersecting: + print( + f"Kept {len(filtered_gdf)} / {len(target_gdf)} geometries intersecting '{mask_shapes_layer}' (removed {removed})." + ) + else: + print( + f"Removed {removed} / {len(target_gdf)} geometries intersecting '{mask_shapes_layer}' (kept {len(filtered_gdf)})." + ) + + # Add to SpatialData + sdata = add_shapes_layer( + sdata, + input=filtered_gdf, + output_layer=output_shapes_layer, + overwrite=overwrite, + ) + + return sdata diff --git a/src/harpy/shape/_shape.py b/src/harpy/shape/_shape.py index c5473db..f758110 100644 --- a/src/harpy/shape/_shape.py +++ b/src/harpy/shape/_shape.py @@ -1,8 +1,9 @@ from __future__ import annotations +import numpy as np from dask.array import Array from geopandas import GeoDataFrame -from shapely.geometry import GeometryCollection, MultiPolygon, Polygon +from shapely.geometry import GeometryCollection, MultiPolygon, Point, Polygon from spatialdata import SpatialData from spatialdata.models._utils import MappingToCoordinateSystem_t from spatialdata.transformations import get_transformation @@ -179,3 +180,114 @@ def intersect_rectangles(rect1: list[int | float], rect2: list[int | float]) -> return [x_min, x_max, y_min, y_max] else: return None + + +def prep_region_annotations( + sdata: SpatialData, + shapes_layer: str, + output_shapes_layer: str, + shape_names_column: str = "name", + unnamed: str = "unnamed", + unique_shape_names_column: str = "name-unique", + erosion: float = 0.5, + overwrite: bool = False, +): + """ + Prepares region annotations in a shapes layer for `hp.sh.filter_by_morphology`, `hp.tb.assign_cells_to_shapes` and `hp.tb.compute_distance_to_shapes`. + Operations performed: + - Ensures a shape name column exists and fills missing names. + - Converts Points with a `radius` column into circular polygons. Points without a `radius` column will be preserved as Points. + - Slightly erodes polygons to avoid shared edges. + - Explodes multipolygons into separate single polygons. + - Generates unique names for shapes with duplicate base names. + + Parameters + ---------- + sdata + The SpatialData object containing the input shapes layer. + shapes_layer + The shapes layer in `sdata.shapes` to use as input. + output_shapes_layer + The output shapes layer in `sdata.tables` to which the updated shapes layer will be written. 
+    shape_names_column
+        Column name in shapes layer containing geometry names. If not present, new names will be generated.
+    unnamed
+        Name to be assigned to any unnamed geometries in `shape_names_column`. Defaults to 'unnamed'.
+    unique_shape_names_column
+        Column in the updated shapes layer in which unique names will be stored. Unique names are created for single polygons by appending a counter to the original name in
+        `shape_names_column` for polygons with the same name. Note that multipolygons will be split into individual polygons and each will get a unique name based on the
+        original name of the multipolygon.
+    erosion
+        Number of pixels to erode polygons by. This can avoid problems with overlapping edges of geometries when calculating distances. Default is 0.5 (i.e. erosion by 0.5 pixels).
+    overwrite
+        If True, overwrites the `output_shapes_layer` if it already exists in `sdata`.
+
+    Returns
+    -------
+    Modified `sdata` object with the updated shapes layer.
+    """
+    # Create copy of shapes layer
+    gdf = sdata.shapes[shapes_layer].copy()
+    print(f"Found {len(gdf)} geometries in {shapes_layer}.")
+
+    # Ensure shape_names_column exists
+    if shape_names_column not in gdf.columns:
+        gdf[shape_names_column] = unnamed
+
+    gdf[shape_names_column] = gdf[shape_names_column].fillna("").astype(str)
+    unnamed_mask = gdf[shape_names_column] == ""
+    gdf.loc[unnamed_mask, shape_names_column] = unnamed
+
+    # Convert Points with a radius column to circular polygons
+    def _point_to_circle(geom, radius=None):
+        if isinstance(geom, Point) and radius is not None:
+            return geom.buffer(radius, resolution=16)
+        return geom
+
+    if "radius" in gdf.columns:
+        mask_points_with_radius = gdf.geometry.geom_type.eq("Point") & gdf["radius"].notna()
+        n_converted = mask_points_with_radius.sum()
+
+        if n_converted > 0:
+            print(f"Converting {n_converted} Point geometries with 'radius' to circular polygons.")
+            gdf.loc[mask_points_with_radius, "geometry"] = gdf.loc[mask_points_with_radius].apply(
+                lambda row: _point_to_circle(row.geometry, getattr(row, "radius", None)), axis=1
+            )
+
+    # Slightly erode all polygons (this avoids any shared borders between polygons)
+    gdf["geometry"] = gdf.geometry.apply(
+        lambda geom: geom.buffer(-erosion, join_style=2, resolution=16)
+        if geom.geom_type in ["Polygon", "MultiPolygon"]
+        else geom
+    )
+    polygon_mask = ~gdf.geometry.is_empty
+    removed = len(gdf) - polygon_mask.sum()
+    gdf = gdf[polygon_mask]  # drop any polygons that collapsed to empty
+    if removed > 0:
+        print(f"Removed {removed} polygons that collapsed to empty after erosion.")
+
+    # Explode multipolygons into single polygons (this allows us to treat multipolygons as unique polygons)
+    n_multipolygons = gdf.geometry.apply(lambda g: g.geom_type == "MultiPolygon").sum()
+    gdf = gdf.explode(index_parts=False).reset_index(drop=True)
+    n_after = len(gdf)
+    print(
+        f"Split {n_multipolygons} multipolygons into individual polygons. Total number of geometries after splitting multipolygons is {n_after}."
+    )
+
+    # Create unique names
+    sizes = gdf.groupby(shape_names_column)[shape_names_column].transform("size")
+    counter = gdf.groupby(shape_names_column).cumcount() + 1
+    gdf[unique_shape_names_column] = np.where(
+        sizes.eq(1), gdf[shape_names_column], gdf[shape_names_column] + counter.astype(str)
+    )
+
+    # Add updated shapes layer
+    sdata = add_shapes_layer(
+        sdata,
+        input=gdf,
+        output_layer=output_shapes_layer,
+        overwrite=overwrite,
+    )
+
+    return sdata
diff --git a/src/harpy/table/__init__.pyi b/src/harpy/table/__init__.pyi
old mode 100644
new mode 100755
index bd9b583..a254ba3
--- a/src/harpy/table/__init__.pyi
+++ b/src/harpy/table/__init__.pyi
@@ -4,8 +4,9 @@ from ._annotation import cluster_cleanliness, score_genes, score_genes_iter
 from ._clustering import kmeans, leiden
 from ._enrichment import nhood_enrichment
 from ._preprocess import preprocess_proteomics, preprocess_transcriptomics
+from ._region_annotations import assign_cells_to_shapes, compute_distance_to_shapes
 from ._regionprops import add_regionprop_features
-from ._table import add_table_layer, correct_marker_genes, filter_on_size
+from ._table import add_table_layer, correct_marker_genes, filter_categorical, filter_numerical, filter_on_size
 from .cell_clustering._clustering import flowsom
 from .cell_clustering._preprocess import cell_clustering_preprocess
 from .cell_clustering._weighted_channel_expression import weighted_channel_expression
@@ -16,6 +17,8 @@ __all__ = [
     "add_table_layer",
     "correct_marker_genes",
     "filter_on_size",
+    "filter_numerical",
+    "filter_categorical",
     "flowsom",
     "weighted_channel_expression",
     "cell_clustering_preprocess",
@@ -33,4 +36,6 @@ __all__ = [
     "allocate",
     "bin_counts",
     "allocate_intensity",
+    "assign_cells_to_shapes",
+    "compute_distance_to_shapes",
 ]
diff --git a/src/harpy/table/_region_annotations.py b/src/harpy/table/_region_annotations.py
new file mode 100755
index 0000000..23eb70c
--- /dev/null
+++ b/src/harpy/table/_region_annotations.py
@@ -0,0 +1,595 @@
+from typing import Literal
+
+import geopandas as gpd
+import pandas as pd
+from shapely.geometry import LineString, MultiLineString, MultiPolygon, Point, Polygon
+from spatialdata import SpatialData
+
+from harpy.table._table import add_table_layer
+from harpy.utils._keys import _REGION_KEY
+
+
+def assign_cells_to_shapes(
+    sdata: SpatialData,
+    shapes_layer: str,
+    table_layer: str,
+    output_table_layer: str,
+    shape_names_column: str = "name",
+    unique_shape_names_column: str = "name-unique",
+    output_column: str | None = None,
+    mode: Literal["original_names", "unique_names", "both"] = "original_names",
+    create_column_per_shape: bool = False,
+    overlap_tolerance: float = 0.1,
+    spatial_key: str = "spatial",
+    xy_columns: tuple | None = None,
+    overwrite: bool = False,
+):
+    """
+    Assign cells to polygons in a shapes layer and update the sdata table layer.
+
+    It is recommended to run `hp.sh.prep_region_annotations` first.
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object containing the input table layer and shapes layer.
+    shapes_layer
+        The shapes layer in `sdata.shapes` to use as input.
+    table_layer
+        The table layer in `sdata.tables` to use as input.
+    output_table_layer
+        The output table layer in `sdata.tables` to which the updated table layer will be written.
+    shape_names_column
+        Column name in shapes layer containing geometry names.
+    unique_shape_names_column
+        Column name in shapes layer containing unique geometry names.
+    output_column
+        Name of the output column in `sdata.tables[table_layer].obs` to store the shape name of the geometries a cell was found in (if `create_column_per_shape` is False).
+        For the `unique_names` mode, the output will be stored in `{output_column}-unique` in `sdata.tables[table_layer].obs` (if `create_column_per_shape` is False).
+        If create_column_per_shape is True and output_column is not None, then column names will be in the format `{output_column}-{shape_name}`.
+        If create_column_per_shape is True and output_column is None, then column names will be in the format `{shape_name}`.
+    mode
+        When set to `original_names`, original polygon names from `shape_names_column` will be used.
+        When set to `unique_names`, unique polygon names from `unique_shape_names_column` will be used. Use `both` to run both modes at the same time.
+    create_column_per_shape
+        If True, create one column (named according to the shape names) per shape indicating whether a cell is located inside it.
+    overlap_tolerance
+        Tolerance for detecting overlapping polygons (area units of geometry CRS).
+    spatial_key
+        Key in `sdata.tables[table_layer].obsm` containing spatial coordinates. Ignored if `xy_columns` is provided.
+    xy_columns
+        Tuple of column names in `sdata.tables[table_layer].obs` containing the x and y coordinates of the cells.
+        If None, defaults to using coordinates from `sdata.tables[table_layer].obsm[spatial_key]`.
+    overwrite
+        If True, overwrites the `output_table_layer` if it already exists in `sdata`.
+
+    Notes
+    -----
+    - Only `Polygon` and `MultiPolygon` geometries are supported. Non-polygon geometries (e.g., `Point`, `LineString`) are skipped.
+
+    Returns
+    -------
+    Modified `sdata` object with updated table layer.
+    """
+    if not output_column and not create_column_per_shape:
+        raise ValueError("Specify `output_column` or set `create_column_per_shape=True`.")
+
+    # Create copy of shapes layer
+    gdf = sdata.shapes[shapes_layer].copy()
+
+    # Create copy of table layer
+    adata = sdata.tables[table_layer].copy()
+
+    # Filter out geometries that are not Polygon or MultiPolygon
+    supported_mask = gdf.geometry.apply(lambda x: isinstance(x, (Polygon, MultiPolygon)))
+    skipped = len(gdf) - supported_mask.sum()
+    if skipped > 0:
+        print(f"Skipped {skipped} non-polygon geometries in {shapes_layer}.")
+        if skipped == len(gdf):
+            print("No supported geometries (Polygon, MultiPolygon) found.")
+            return sdata
+
+    gdf = gdf[supported_mask].copy().reset_index(drop=True)
+
+    # Check for overlapping polygons
+    total_area = gdf.geometry.area.sum()
+    union_area = gdf.geometry.unary_union.area
+    if total_area - union_area > overlap_tolerance and not create_column_per_shape:
+        raise ValueError(
+            f"Overlapping polygons detected in {shapes_layer}. Correct polygons or use create_column_per_shape."
+        )
+    elif total_area - union_area > 0 and not create_column_per_shape:
+        print(f"Overlaps detected (Δ={total_area - union_area:.3f}), below tolerance threshold {overlap_tolerance}.")
+
+    # Get cell coordinates
+    if xy_columns is not None:
+        x_col, y_col = xy_columns
+        coords = adata.obs[[x_col, y_col]].to_numpy()
+    else:
+        if spatial_key not in adata.obsm:
+            raise KeyError(f"No spatial coordinates found in `obsm['{spatial_key}']` and `xy_columns` not provided.")
+        coords = adata.obsm[spatial_key]
+        if coords.shape[1] != 2:
+            raise ValueError(f"`obsm['{spatial_key}']` must have shape (n_cells, 2).")
+
+    # Function to assign cell to a polygon
+    def _assign_region(x, y, gdf, column, sindex):
+        point = Point(x, y)
+
+        # Quick bounding box search to prefilter the cells
+        candidate_idx = list(sindex.intersection(point.bounds))
+        if not candidate_idx:
+            return None
+
+        # Slower point-in-polygon search
+        candidate_gdf = gdf.iloc[candidate_idx]
+        match = candidate_gdf[candidate_gdf.contains(point)]
+
+        if match.empty:
+            return None
+
+        return match[column].iloc[0]
+
+    # Assign cells to polygons
+    if create_column_per_shape:
+        # Collect all relevant name columns based on mode
+        name_columns = []
+        if mode in ("original_names", "both"):
+            name_columns.append(shape_names_column)
+        if mode in ("unique_names", "both"):
+            name_columns.append(unique_shape_names_column)
+
+        # Collect all unique names across both columns
+        name_to_column = {}
+        for name_column in name_columns:
+            for name in gdf[name_column].dropna().astype(str).unique():
+                if name not in name_to_column:
+                    name_to_column[name] = name_column
+
+        # Create column per shape name
+        for shape_name, name_column in name_to_column.items():
+            col_name = f"{output_column}-{shape_name}" if output_column else shape_name
+
+            gdf_shape = gdf[gdf[name_column] == shape_name]
+            if gdf_shape.empty:
+                continue
+
+            sidx_shape = gdf_shape.sindex  # Build spatial index
+            assigned = [_assign_region(x, y, gdf_shape, name_column, sidx_shape) for x, y in coords]
+            adata.obs[col_name] = assigned
+
+    else:
+        # One output column per mode
+        run_assign_dict = {}
+        if mode in ("original_names", "both"):
+            run_assign_dict[output_column] = shape_names_column
+        if mode in ("unique_names", "both"):
+            run_assign_dict[f"{output_column}-unique"] = unique_shape_names_column
+
+        for output_column_name, name_column in run_assign_dict.items():
+            # Assign one shape name per cell (to be avoided with overlapping regions)
+            sidx = gdf.sindex  # Build spatial index
+            assigned = [_assign_region(x, y, gdf, name_column, sidx) for x, y in coords]
+            adata.obs[output_column_name] = assigned
+
+    # Add table layer
+    sdata = add_table_layer(
+        sdata,
+        adata=adata,
+        output_layer=output_table_layer,
+        region=adata.obs[_REGION_KEY].cat.categories.to_list(),
+        overwrite=overwrite,
+    )
+
+    return sdata
+
+
+def compute_distance_to_shapes(
+    sdata: SpatialData,
+    shapes_layer: str,
+    table_layer: str,
+    output_table_layer: str,
+    shape_names_column: str = "name",
+    unique_shape_names_column: str = "name-unique",
+    output_name: str | None = None,
+    pixel_size_um: float = 1.0,
+    modes: list[
+        Literal[
+            "nearest_edge",
+            "nearest_edge_grouped",
+            "all_edges",
+            "nearest_outer_edge",
+            "nearest_outer_edge_grouped",
+            "all_outer_edges",
+            "nearest_inner_edge",
+            "nearest_inner_edge_grouped",
+            "nearest_inner_edge_grouped_unique",
+            "all_inner_edges",
+            "nearest_centroid",
+            "nearest_centroid_grouped",
+            "all_centroids",
+            "nearest_point",
+            "nearest_point_grouped",
+            "all_points",
+        ]
+    ]
+    | None = None,
+    spatial_key: str = "spatial",
+    xy_columns: tuple | None = None,
+    overwrite: bool = False,
+):
+    """
+    Compute distances from cells to polygons in a shapes layer and update the sdata table layer.
+
+    It is recommended to run `hp.sh.prep_region_annotations` first.
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object containing the input table layer and shapes layer.
+    shapes_layer
+        The shapes layer in `sdata.shapes` to use as input.
+    table_layer
+        The table layer in `sdata.tables` to use as input.
+    output_table_layer
+        The output table layer in `sdata.tables` to which the updated table layer will be written.
+    shape_names_column
+        Column name in shapes layer containing geometry names.
+    unique_shape_names_column
+        Column name in shapes layer containing unique geometry names.
+    output_name
+        Prefix for new distance columns in `.obs`. If None, no prefix is added.
+    pixel_size_um
+        Scale factor to convert distances to microns. Defaults to 1 (i.e. distances are in pixels).
+    modes
+        Which distance features to calculate. Options include:
+
+        Edge distances (for Polygon and MultiPolygon geometries)
+        - `"nearest_edge"`: Distance to the nearest edge (outer or inner) of any polygon.
+          Creates `{output_name}-distance_to_nearest_edge` and `{output_name}-name_of_nearest_edge` columns.
+        - `"nearest_edge_grouped"`: For each group of shapes with the same name, compute the distance to the edge that is nearest.
+          Creates `{output_name}-distance_to_nearest_edge_of_` and `{output_name}-name_of_nearest_edge_of_` columns.
+        - `"all_edges"`: For each individual polygon, compute the distance to its edge.
+          Creates `{output_name}-distance_to_edge_of_` column.
+
+        Outer edge distances (for Polygon and MultiPolygon geometries)
+        - `"nearest_outer_edge"`: Distance to the nearest outer edge of any polygon.
+          Creates `{output_name}-distance_to_nearest_outer_edge` and `{output_name}-name_of_nearest_outer_edge` columns.
+        - `"nearest_outer_edge_grouped"`: For each group of shapes with the same name, compute the distance to the outer edge that is nearest.
+          Creates `{output_name}-distance_to_nearest_outer_edge_of_` and `{output_name}-name_of_nearest_outer_edge_of_` columns.
+        - `"all_outer_edges"`: For each individual polygon, compute the distance to its outer edge.
+          Creates `{output_name}-distance_to_outer_edge_of_` column.
+
+        Inner edge (hole) distances (for Polygon and MultiPolygon geometries)
+        - `"nearest_inner_edge"`: Distance to the nearest interior edge (“hole”) of any polygon.
+          Creates `{output_name}-distance_to_nearest_inner_edge` and `{output_name}-name_of_nearest_inner_edge` columns.
+        - `"nearest_inner_edge_grouped"`: For each group of shapes with the same name, compute the distance to the nearest inner edge of that group.
+          Creates `{output_name}-distance_to_nearest_inner_edge_of_` and `{output_name}-name_of_nearest_inner_edge_of_` columns, suffixed with the group name.
+        - `"nearest_inner_edge_grouped_unique"`: For each individual polygon, compute the distance to the nearest inner edge of that polygon.
+          Creates `{output_name}-distance_to_nearest_inner_edge_of_` and `{output_name}-name_of_nearest_inner_edge_of_` columns, suffixed with the polygon's unique name.
+        - `"all_inner_edges"`: For all holes, compute the distance to each inner edge.
+          Creates `{output_name}-distance_to_inner_edge_of_-hole` column.
+
+        Centroid distances (for Polygon and MultiPolygon geometries)
+        - `"nearest_centroid"`: Distance to the nearest centroid of any polygon.
+          Creates `{output_name}-distance_to_nearest_centroid` and `{output_name}-name_of_nearest_centroid` columns.
+        - `"nearest_centroid_grouped"`: For each group of shapes with the same name, compute the distance to the centroid that is nearest.
+          Creates `{output_name}-distance_to_nearest_centroid_of_` and `{output_name}-name_of_nearest_centroid_of_` columns.
+        - `"all_centroids"`: For each individual polygon, compute the distance to its centroid.
+          Creates `{output_name}-distance_to_centroid_of_` column.
+
+        Point distances (for Point geometries)
+        - `"nearest_point"`: Distance to the nearest point.
+          Creates `{output_name}-distance_to_nearest_point` and `{output_name}-name_of_nearest_point` columns.
+        - `"nearest_point_grouped"`: For each group of points with the same name, compute the nearest distance.
+          Creates `{output_name}-distance_to_nearest_point_of_` and `{output_name}-name_of_nearest_point_of_` columns.
+        - `"all_points"`: For each individual Point, compute the distance to the point coordinates.
+          Creates `{output_name}-distance_to_point_` column.
+
+    spatial_key
+        Key in `sdata.tables[table_layer].obsm` containing spatial coordinates. Ignored if `xy_columns` is provided.
+    xy_columns
+        Tuple of column names in `sdata.tables[table_layer].obs` containing the x and y coordinates of the cells.
+        If None, defaults to using coordinates from `sdata.tables[table_layer].obsm[spatial_key]`.
+    overwrite
+        If True, overwrites the `output_table_layer` if it already exists in `sdata`.
+
+    Notes
+    -----
+    - Only `Polygon`, `MultiPolygon` and `Point` geometries are supported. Other geometries (e.g., `LineString`, `MultiPoint`) are skipped.
+
+    Returns
+    -------
+    Modified `sdata` object with updated table layer.
+    """
+    if modes is None:
+        modes = ["all_edges"]
+    if output_name is not None:
+        output_name = f"{output_name}-"
+    else:
+        output_name = ""
+
+    # Create copy of shapes layer
+    gdf = sdata.shapes[shapes_layer].copy()
+
+    # Create copy of table layer
+    adata = sdata.tables[table_layer].copy()
+
+    # Filter out geometries that are not Polygon, MultiPolygon or Point
+    supported_mask = gdf.geometry.apply(lambda x: isinstance(x, (Polygon, MultiPolygon, Point)))
+    skipped = len(gdf) - supported_mask.sum()
+    if skipped > 0:
+        print(f"Skipped {skipped} geometries in {shapes_layer} that are not Polygon, MultiPolygon or Point.")
+
+    gdf = gdf[supported_mask].copy().reset_index(drop=True)
+
+    # Separate gdf by type
+    gdf_polygons = gdf[gdf.geometry.geom_type.isin(["Polygon", "MultiPolygon"])].copy()
+    gdf_points = gdf[gdf.geometry.geom_type == "Point"].copy()
+
+    has_polygons = not gdf_polygons.empty
+    has_points = not gdf_points.empty
+
+    if not has_polygons and not has_points:
+        print("No supported geometries (Polygon, MultiPolygon, Point) found.")
+        return sdata
+
+    # Get cell coordinates
+    if xy_columns is not None:
+        x_col, y_col = xy_columns
+        coords = adata.obs[[x_col, y_col]].to_numpy()
+    else:
+        if spatial_key not in adata.obsm:
+            raise KeyError(f"No spatial coordinates found in `obsm['{spatial_key}']` and `xy_columns` not provided.")
+        coords = adata.obsm[spatial_key]
+        if coords.shape[1] != 2:
+            raise ValueError(f"`obsm['{spatial_key}']` must have shape (n_cells, 2).")
+
+    # Build point GeoDataFrame
+    pts = gpd.GeoSeries([Point(x, y) for x, y in coords], crs=gdf.crs)
+
+    pts_gdf = gpd.GeoDataFrame({"geometry": pts}, crs=gdf.crs).set_index(adata.obs.index)
+
+    # Polygon and MultiPolygon
+    if has_polygons:
+        # Extract exterior and interior boundaries
+        def _extract_exterior_lines(geom):
+            if isinstance(geom, Polygon):
+                return [geom.exterior]
+            elif isinstance(geom, MultiPolygon):
+                return [poly.exterior for poly in geom.geoms]
+            return []
+
+        def _extract_interior_lines(geom):
+            if isinstance(geom, Polygon):
+                return list(geom.interiors)
+            elif isinstance(geom, MultiPolygon):
+                lines = []
+                for poly in geom.geoms:
+                    lines.extend(poly.interiors)
+                return lines
+            return []
+
+        names = gdf_polygons[unique_shape_names_column].tolist()
+        exteriors = [MultiLineString(_extract_exterior_lines(geom)) for geom in gdf_polygons.geometry]
+        interiors = [MultiLineString(_extract_interior_lines(geom)) for geom in gdf_polygons.geometry]
+
+        ext_gdf = gpd.GeoDataFrame({unique_shape_names_column: names}, geometry=exteriors, crs=gdf.crs)
+        int_gdf = gpd.GeoDataFrame({unique_shape_names_column: names}, geometry=interiors, crs=gdf.crs)
+
+        # Edges (outer and inner)
+        if "nearest_edge" in modes:
+            print("Calculating 'nearest_edge' distances")
+            all_edges_gdf = gpd.GeoDataFrame(pd.concat([ext_gdf, int_gdf], ignore_index=True), crs=gdf.crs)
+            joined = gpd.sjoin_nearest(pts_gdf, all_edges_gdf, how="left", distance_col="dist")
+            adata.obs[f"{output_name}distance_to_nearest_edge"] = joined["dist"].to_numpy() * pixel_size_um
+            adata.obs[f"{output_name}name_of_nearest_edge"] = joined[unique_shape_names_column]
+
+        if "nearest_edge_grouped" in modes:
+            print("Calculating 'nearest_edge_grouped' distances")
+            for name, group in gdf_polygons.groupby(shape_names_column):
+                edges = []
+                labels = []
+                for _idx, row in group.iterrows():
+                    edge_lines = _extract_exterior_lines(row.geometry) + _extract_interior_lines(row.geometry)
+                    edges.extend(edge_lines)
+                    labels.extend([row[unique_shape_names_column]] * len(edge_lines))
+
+                edge_gdf = gpd.GeoDataFrame({"geometry": edges, unique_shape_names_column: labels}, crs=gdf.crs)
+                joined = gpd.sjoin_nearest(pts_gdf, edge_gdf, how="left", distance_col="dist")
+
+                adata.obs[f"{output_name}distance_to_nearest_edge_of_{name}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_edge_of_{name}"] = joined[unique_shape_names_column]
+
+        if "all_edges" in modes:
+            print("Calculating 'all_edges' distances")
+            for _, feat in gdf_polygons.reset_index(drop=True).iterrows():
+                adata.obs[f"{output_name}distance_to_edge_of_{feat[unique_shape_names_column]}"] = (
+                    pts.distance(feat.geometry.boundary).to_numpy() * pixel_size_um
+                )
+
+        # Outer edges
+        if "nearest_outer_edge" in modes:
+            print("Calculating 'nearest_outer_edge' distances")
+            joined = gpd.sjoin_nearest(pts_gdf, ext_gdf, how="left", distance_col="dist")
+            adata.obs[f"{output_name}distance_to_nearest_outer_edge"] = joined["dist"].to_numpy() * pixel_size_um
+            adata.obs[f"{output_name}name_of_nearest_outer_edge"] = joined[unique_shape_names_column]
+
+        if "nearest_outer_edge_grouped" in modes:
+            print("Calculating 'nearest_outer_edge_grouped' distances")
+            for name, group in gdf_polygons.groupby(shape_names_column):
+                edges = []
+                labels = []
+                for _idx, row in group.iterrows():
+                    edge_lines = _extract_exterior_lines(row.geometry)
+                    edges.extend(edge_lines)
+                    labels.extend([row[unique_shape_names_column]] * len(edge_lines))
+
+                edge_gdf = gpd.GeoDataFrame({"geometry": edges, unique_shape_names_column: labels}, crs=gdf.crs)
+                joined = gpd.sjoin_nearest(pts_gdf, edge_gdf, how="left", distance_col="dist")
+
+                adata.obs[f"{output_name}distance_to_nearest_outer_edge_of_{name}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_outer_edge_of_{name}"] = joined[unique_shape_names_column]
+
+        if "all_outer_edges" in modes:
+            print("Calculating 'all_outer_edges' distances")
+            for _, feat in gdf_polygons.reset_index(drop=True).iterrows():
+                # use all exterior rings so MultiPolygons (which have no single .exterior) are handled as well
+                adata.obs[f"{output_name}distance_to_outer_edge_of_{feat[unique_shape_names_column]}"] = (
+                    pts.distance(MultiLineString(_extract_exterior_lines(feat.geometry))).to_numpy() * pixel_size_um
+                )
+
+        # Inner edges
+        hole_geoms = []
+        hole_names = []
+
+        for _, row in gdf_polygons.reset_index(drop=True).iterrows():
+            base = row[unique_shape_names_column]
+            geom = row.geometry
+
+            holes = _extract_interior_lines(geom)
+
+            if len(holes) == 0:
+                continue
+
+            if len(holes) == 1:
+                hole = holes[0]
+                hole_geoms.append(LineString(hole.coords))
+                hole_names.append(base)
+
+            else:
+                for i, hole in enumerate(holes, start=1):
+                    hole_geoms.append(LineString(hole.coords))
+                    hole_names.append(f"{base}-hole{i}")
+
+        holes_gdf = gpd.GeoDataFrame({unique_shape_names_column: hole_names}, geometry=hole_geoms, crs=gdf.crs)
+
+        if "nearest_inner_edge" in modes and not holes_gdf.empty:
+            print("Calculating 'nearest_inner_edge' distances")
+            joined = gpd.sjoin_nearest(pts_gdf, holes_gdf, how="left", distance_col="dist")
+            adata.obs[f"{output_name}distance_to_nearest_inner_edge"] = joined["dist"].to_numpy() * pixel_size_um
+            adata.obs[f"{output_name}name_of_nearest_inner_edge"] = joined[unique_shape_names_column]
+
+        if "nearest_inner_edge_grouped" in modes:
+            print("Calculating 'nearest_inner_edge_grouped' distances")
+            for name, group in gdf_polygons.groupby(shape_names_column):
+                hole_lines = []
+                hole_labels = []
+
+                for _idx, row in group.iterrows():
+                    base = row[unique_shape_names_column]
+                    for i, hole in enumerate(_extract_interior_lines(row.geometry), start=1):
+                        hole_lines.append(LineString(hole.coords))
+                        hole_labels.append(f"{base}-hole{i}")
+
+                if not hole_lines:
+                    continue
+
+                hole_gdf = gpd.GeoDataFrame({"geometry": hole_lines, "hole_name": hole_labels}, crs=gdf.crs)
+
+                joined = gpd.sjoin_nearest(pts_gdf, hole_gdf, how="left", distance_col="dist")
+
+                adata.obs[f"{output_name}distance_to_nearest_inner_edge_of_{name}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_inner_edge_of_{name}"] = joined["hole_name"]
+
+        if "nearest_inner_edge_grouped_unique" in modes:
+            print("Calculating 'nearest_inner_edge_grouped_unique' distances")
+            for _idx, row in gdf_polygons.iterrows():
+                base = row[unique_shape_names_column]
+                holes = _extract_interior_lines(row.geometry)
+
+                if not holes:
+                    continue
+
+                hole_geoms = [LineString(hole.coords) for hole in holes]
+                hole_names = [f"{base}-hole{i + 1}" for i in range(len(holes))]
+
+                hole_gdf = gpd.GeoDataFrame({"geometry": hole_geoms, "hole_name": hole_names}, crs=gdf.crs)
+
+                joined = gpd.sjoin_nearest(pts_gdf, hole_gdf, how="left", distance_col="dist")
+
+                adata.obs[f"{output_name}distance_to_nearest_inner_edge_of_{base}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_inner_edge_of_{base}"] = joined["hole_name"]
+
+        if "all_inner_edges" in modes and not holes_gdf.empty:
+            print("Calculating 'all_inner_edges' distances")
+            for _, hole in holes_gdf.iterrows():
+                hole_name = hole[unique_shape_names_column]
+                col = f"{output_name}distance_to_inner_edge_of_{hole_name}"
+                adata.obs[col] = pts.distance(hole.geometry).to_numpy() * pixel_size_um
+
+        # Centroids
+        centroids_df = gdf_polygons.reset_index(drop=True)[[unique_shape_names_column]].copy()
+        centroids_df["geometry"] = gdf_polygons.geometry.centroid.reset_index(drop=True)
+
+        centroids_gdf = gpd.GeoDataFrame(centroids_df, geometry="geometry", crs=gdf.crs)
+
+        if "nearest_centroid" in modes:
+            print("Calculating 'nearest_centroid' distances")
+            joined = gpd.sjoin_nearest(pts_gdf, centroids_gdf, how="left", distance_col="dist")
+            adata.obs.loc[joined.index, f"{output_name}distance_to_nearest_centroid"] = (
+                joined["dist"].to_numpy() * pixel_size_um
+            )
+            adata.obs.loc[joined.index, f"{output_name}name_of_nearest_centroid"] = joined[
+                unique_shape_names_column
+            ].to_numpy()
+
+        if "nearest_centroid_grouped" in modes:
+            print("Calculating 'nearest_centroid_grouped' distances")
+            for name, group in gdf_polygons.groupby(shape_names_column):
+                centroids = group.geometry.centroid
+                labels = group[unique_shape_names_column].tolist()
+
+                centroid_gdf = gpd.GeoDataFrame({"geometry": centroids, unique_shape_names_column: labels}, crs=gdf.crs)
+
+                joined = gpd.sjoin_nearest(pts_gdf, centroid_gdf, how="left", distance_col="dist")
+
+                adata.obs[f"{output_name}distance_to_nearest_centroid_of_{name}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_centroid_of_{name}"] = joined[unique_shape_names_column]
+
+        if "all_centroids" in modes:
+            print("Calculating 'all_centroids' distances")
+            for _idx, feat in gdf_polygons.reset_index(drop=True).iterrows():
+                adata.obs[f"{output_name}distance_to_centroid_of_{feat[unique_shape_names_column]}"] = (
+                    pts.distance(feat.geometry.centroid).to_numpy() * pixel_size_um
+                )
+
+    # Point
+    if has_points:
+        if "nearest_point" in modes:
+            print("Calculating 'nearest_point' distances")
+            joined = gpd.sjoin_nearest(pts_gdf, gdf_points, how="left", distance_col="dist")
+            adata.obs[f"{output_name}distance_to_nearest_point"] = joined["dist"].to_numpy() * pixel_size_um
+            adata.obs[f"{output_name}name_of_nearest_point"] = joined[unique_shape_names_column]
+
+        if "nearest_point_grouped" in modes:
+            print("Calculating 'nearest_point_grouped' distances")
+            for name, group in gdf_points.groupby(shape_names_column):
+                joined = gpd.sjoin_nearest(pts_gdf, group, how="left", distance_col="dist")
+                adata.obs[f"{output_name}distance_to_nearest_point_of_{name}"] = (
+                    joined["dist"].to_numpy() * pixel_size_um
+                )
+                adata.obs[f"{output_name}name_of_nearest_point_of_{name}"] = joined[unique_shape_names_column]
+
+        if "all_points" in modes:
+            print("Calculating 'all_points' distances")
+            for _, row in gdf_points.iterrows():
+                adata.obs[f"{output_name}distance_to_point_{row[unique_shape_names_column]}"] = (
+                    pts.distance(row.geometry).to_numpy() * pixel_size_um
+                )
+
+    # Add table layer
+    sdata = add_table_layer(
+        sdata,
+        adata=adata,
+        output_layer=output_table_layer,
+        region=adata.obs[_REGION_KEY].cat.categories.to_list(),
+        overwrite=overwrite,
+    )
+
+    return sdata
diff --git a/src/harpy/table/_table.py b/src/harpy/table/_table.py
index bc467b0..e91318b 100644
--- a/src/harpy/table/_table.py
+++ b/src/harpy/table/_table.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 
 import numpy as np
+import pandas as pd
 from anndata import AnnData
 from spatialdata import SpatialData
 from spatialdata.models import TableModel
@@ -223,15 +224,17 @@ def filter_on_size(
     labels_layer: list[str],
     table_layer: str,
     output_layer: str,
-    min_size: int = 100,
-    max_size: int = 100000,
-    update_shapes_layers: bool = True,
     cellsize_key=_CELLSIZE_KEY,
+    min_size: float | int | None = None,
+    max_size: float | int | None = None,
+    update_shapes_layers: bool = True,
+    prefix_filtered_shapes_layer: str = "filtered_size",
     overwrite: bool = False,
 ) -> SpatialData:
     """Returns the updated SpatialData object.
 
     All cells with a size outside of the min and max size range are removed using the `cellsize_key` in `.obs`. Run e.g. `harpy.tb.preprocess_transcriptomics` or `harpy.tb.preprocess_proteomics` to obtain cell sizes.
+    Cells are kept if min_size ≤ cellsize_key ≤ max_size.
 
     Parameters
     ----------
@@ -246,16 +249,82 @@ def filter_on_size(
         The table layer in `sdata`.
     output_layer
         The output table layer in `sdata`.
+    cellsize_key
+        Column in `sdata.tables[table_layer].obs` containing cell sizes.
     min_size
-        minimum size in pixels.
+        minimum size in pixels. If None, this value is not used for filtering.
     max_size
-        maximum size in pixels.
+        maximum size in pixels. If None, this value is not used for filtering.
     update_shapes_layers
         Whether to filter the shapes layers associated with `labels_layer`.
         If set to `True`, cells that do not appear in resulting `output_layer` (with `_REGION_KEY` equal to `labels_layer`) will be removed from the shapes layers (via `_INSTANCE_KEY`) in the `sdata` object.
-        Filtered shapes will be added to `sdata` with prefix 'filtered_size'.
-    cellsize_key
-        Column in `sdata.tables[table_layer].obs` containing cell sizes.
+        Filtered shapes will be added to `sdata` with prefix `prefix_filtered_shapes_layer`.
+    prefix_filtered_shapes_layer
+        prefix to use for filtered shapes layer if update_shapes_layers is True. Defaults to 'filtered_size'.
+    overwrite
+        If True, overwrites the `output_layer` if it already exists in `sdata`.
+
+    Returns
+    -------
+    The updated SpatialData object.
+    """
+    sdata = filter_numerical(
+        sdata=sdata,
+        labels_layer=labels_layer,
+        table_layer=table_layer,
+        output_layer=output_layer,
+        numerical_column=cellsize_key,
+        min_value=min_size,
+        max_value=max_size,
+        update_shapes_layers=update_shapes_layers,
+        prefix_filtered_shapes_layer=prefix_filtered_shapes_layer,
+        overwrite=overwrite,
+    )
+
+    return sdata
+
+
+def filter_numerical(
+    sdata: SpatialData,
+    labels_layer: list[str],
+    table_layer: str,
+    output_layer: str,
+    numerical_column: str,
+    min_value: float | int | None = None,
+    max_value: float | int | None = None,
+    update_shapes_layers: bool = True,
+    prefix_filtered_shapes_layer: str = "filtered",
+    overwrite: bool = False,
+) -> SpatialData:
+    """Returns the updated SpatialData object.
+
+    All cells with a value of `numerical_column` in `.obs` outside of the min and max range are removed (cells are kept if min_value ≤ numerical_column ≤ max_value).
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object.
+    labels_layer
+        The labels layer(s) of `sdata` used to select the cells via the _REGION_KEY in `sdata.tables[table_layer].obs`.
+        Note that if `output_layer` is equal to `table_layer` and overwrite is True,
+        cells in `sdata.tables[table_layer]` linked to other `labels_layer` (via the _REGION_KEY) will be removed from `sdata.tables[table_layer]`
+        (also from the backing zarr store if it is backed).
+    table_layer
+        The table layer in `sdata`.
+    output_layer
+        The output table layer in `sdata`.
+    numerical_column
+        Name of numerical column in `sdata.tables[table_layer].obs` to use for filtering with `min_value` and `max_value`.
+    min_value
+        minimum value of `numerical_column`. If None, this value is not used for filtering.
+    max_value
+        maximum value of `numerical_column`. If None, this value is not used for filtering.
+    update_shapes_layers
+        Whether to filter the shapes layers associated with `labels_layer`.
+        If set to `True`, cells that do not appear in resulting `output_layer` (with `_REGION_KEY` equal to `labels_layer`) will be removed from the shapes layers (via `_INSTANCE_KEY`) in the `sdata` object.
+        Filtered shapes will be added to `sdata` with prefix `prefix_filtered_shapes_layer`.
+    prefix_filtered_shapes_layer
+        prefix to use for filtered shapes layer if update_shapes_layers is True. Defaults to 'filtered'.
     overwrite
         If True, overwrites the `output_layer` if it already exists in `sdata`.
@@ -264,13 +333,30 @@ def filter_on_size(
     The updated SpatialData object.
     """
     process_table_instance = ProcessTable(sdata, labels_layer=labels_layer, table_layer=table_layer)
-    adata = process_table_instance._get_adata()
+    adata = process_table_instance._get_adata().copy()
     start = adata.shape[0]
 
-    # Filter cells based on size and distance
-    # need to do the copy because we pop the spatialdata_attrs in add_table_layer, otherwise it would not be updated inplace
-    adata = adata[adata.obs[cellsize_key] < max_size, :].copy()
-    adata = adata[adata.obs[cellsize_key] > min_size, :].copy()
+    if numerical_column not in adata.obs.columns:
+        raise ValueError(f"Column '{numerical_column}' not found in '{table_layer}.obs'.")
+
+    if not np.issubdtype(adata.obs[numerical_column].dtype, np.number):
+        raise ValueError(
+            f"Column '{numerical_column}' must be numeric, but dtype is {adata.obs[numerical_column].dtype}."
+        )
+
+    # Filter cells based on min and max values
+    mask = pd.Series(True, index=adata.obs.index)
+
+    if min_value is not None:
+        below = (adata.obs[numerical_column] < min_value).sum()
+        log.info(f"Removed {below} cells with {numerical_column} below {min_value}.")
+        mask &= adata.obs[numerical_column] >= min_value
+    if max_value is not None:
+        above = (adata.obs[numerical_column] > max_value).sum()
+        log.info(f"Removed {above} cells with {numerical_column} above {max_value}.")
+        mask &= adata.obs[numerical_column] <= max_value
+
+    adata = adata[mask, :].copy()
 
     sdata = add_table_layer(
         sdata,
@@ -286,11 +372,125 @@
                 sdata,
                 table_layer=output_layer,
                 labels_layer=_labels_layer,
-                prefix_filtered_shapes_layer="filtered_size",
+                prefix_filtered_shapes_layer=prefix_filtered_shapes_layer,
             )
 
     filtered = start - adata.shape[0]
-    log.info(f"{filtered} cells were filtered out based on size.")
+    log.info(
+        f"Removed {filtered} / {start} cells based on '{numerical_column}' (min={min_value}, max={max_value}) and kept {adata.shape[0]}."
+    )
+
+    return sdata
+
+
+def filter_categorical(
+    sdata: SpatialData,
+    labels_layer: list[str],
+    table_layer: str,
+    output_layer: str,
+    categorical_column: str,
+    include_values: str | Sequence[str] | None = None,
+    exclude_values: str | Sequence[str] | None = None,
+    update_shapes_layers: bool = True,
+    prefix_filtered_shapes_layer: str = "filtered",
+    overwrite: bool = False,
+) -> SpatialData:
+    """Filter cells based on categorical values in `.obs`.
+
+    Removes or keeps cells based on specific values in a categorical column of
+    `sdata.tables[table_layer].obs`.
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object.
+    labels_layer
+        The labels layer(s) of `sdata` used to select the cells via the _REGION_KEY in `sdata.tables[table_layer].obs`.
+        Note that if `output_layer` is equal to `table_layer` and overwrite is True,
+        cells in `sdata.tables[table_layer]` linked to other `labels_layer` (via the _REGION_KEY) will be removed from `sdata.tables[table_layer]`
+        (also from the backing zarr store if it is backed).
+    table_layer
+        The table layer in `sdata`.
+    output_layer
+        The output table layer in `sdata`.
+    categorical_column
+        Name of the categorical column in `.obs` to use for filtering.
+    include_values
+        Value(s) to keep. Only cells whose `categorical_column` matches one of these
+        values will be kept. Mutually exclusive with `exclude_values`.
+    exclude_values
+        Value(s) to remove. Cells whose `categorical_column` matches one of these
+        values will be removed. Mutually exclusive with `include_values`.
+    update_shapes_layers
+        Whether to filter the shapes layers associated with `labels_layer`.
+        If set to `True`, cells that do not appear in resulting `output_layer` (with `_REGION_KEY` equal to `labels_layer`) will be removed from the shapes layers (via `_INSTANCE_KEY`) in the `sdata` object.
+        Filtered shapes will be added to `sdata` with prefix `prefix_filtered_shapes_layer`.
+    prefix_filtered_shapes_layer
+        prefix to use for filtered shapes layer if update_shapes_layers is True. Defaults to 'filtered'.
+    overwrite
+        If True, overwrites the `output_layer` if it already exists in `sdata`.
+
+    Returns
+    -------
+    The updated SpatialData object.
+    """
+    if include_values is not None and exclude_values is not None:
+        raise ValueError("Specify only one of 'include_values' or 'exclude_values'.")
+
+    process_table_instance = ProcessTable(sdata, labels_layer=labels_layer, table_layer=table_layer)
+    adata = process_table_instance._get_adata().copy()
+    start = adata.shape[0]
+
+    if categorical_column not in adata.obs.columns:
+        raise ValueError(f"Column '{categorical_column}' not found in '{table_layer}.obs'.")
+
+    # Ensure include/exclude are lists
+    if isinstance(include_values, str):
+        include_values = [include_values]
+    if isinstance(exclude_values, str):
+        exclude_values = [exclude_values]
+
+    # Filter
+    mask = pd.Series(True, index=adata.obs.index)
+
+    if include_values is not None:
+        kept = adata.obs[categorical_column].isin(include_values).sum()
+        removed = (~adata.obs[categorical_column].isin(include_values)).sum()
+        log.info(f"Found {kept} cells in {include_values} to keep ({removed} to remove).")
+        mask &= adata.obs[categorical_column].isin(include_values)
+
+    elif exclude_values is not None:
+        removed = adata.obs[categorical_column].isin(exclude_values).sum()
+        kept = (~adata.obs[categorical_column].isin(exclude_values)).sum()
+        log.info(f"Found {removed} cells in {exclude_values} to remove ({kept} to keep).")
+        mask &= ~adata.obs[categorical_column].isin(exclude_values)
+
+    adata = adata[mask, :].copy()
+
+    sdata = add_table_layer(
+        sdata,
+        adata=adata,
+        output_layer=output_layer,
+        region=process_table_instance.labels_layer,
+        overwrite=overwrite,
+    )
+
+    if update_shapes_layers:
+        for _labels_layer in process_table_instance.labels_layer:
+            sdata = filter_shapes_layer(
+                sdata,
+                table_layer=output_layer,
+                labels_layer=_labels_layer,
+                prefix_filtered_shapes_layer=prefix_filtered_shapes_layer,
+            )
+
+    filtered = start - adata.shape[0]
+    log.info(
+        f"Removed {filtered} / {start} cells based on '{categorical_column}' "
+        f"({'included' if include_values is not None else 'excluded'}: "
+        f"{include_values if include_values is not None else exclude_values}) "
+        f"and kept {adata.shape[0]}."
+    )
+
+    return sdata
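Example usage of the new table filters (a minimal sketch, not part of the diff; layer and
column names are illustrative and assume a table annotated by a labels layer):

    import harpy as hp

    # Keep cells whose size lies between 50 and 5000 px². filter_numerical generalizes
    # filter_on_size to any numeric .obs column; "shapeSize" is a hypothetical column name.
    sdata = hp.tb.filter_numerical(
        sdata,
        labels_layer=["segmentation_mask"],
        table_layer="table_transcriptomics",
        output_layer="table_filtered",
        numerical_column="shapeSize",
        min_value=50,
        max_value=5000,
        overwrite=True,
    )

    # Drop cells assigned to an "artifact" region, e.g. via a column previously created
    # by hp.tb.assign_cells_to_shapes.
    sdata = hp.tb.filter_categorical(
        sdata,
        labels_layer=["segmentation_mask"],
        table_layer="table_filtered",
        output_layer="table_filtered",
        categorical_column="region_annotation",
        exclude_values="artifact",
        overwrite=True,
    )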