From d1945f2b6bb345be02564902c461aed063982fa3 Mon Sep 17 00:00:00 2001 From: khersameesh24 Date: Fri, 8 May 2026 13:47:04 +0200 Subject: [PATCH] removed redundant files under modules/name/resource/usr/bin --- ...ata.py => utility_extract_preview_data.py} | 0 .../resources/usr/bin/create_dataset.py | 96 --- .../usr/bin/preprocess_transcripts.py | 126 --- .../resources/usr/bin/ficture_preprocess.py | 101 --- .../resources/usr/bin/run_create_dataset.py | 253 ------ .../predict/resources/usr/bin/run_predict.py | 137 --- .../resources/usr/bin/spatialdata_merge.py | 82 -- .../resources/usr/bin/spatialdata_meta.py | 126 --- .../resources/usr/bin/spatialdata_write.py | 156 ---- .../utility/extract_preview_data/main.nf | 2 +- .../resources/usr/bin/extract_data.py | 208 ----- .../resources/usr/bin/get_coordinates.py | 60 -- .../resources/usr/bin/parquet_to_csv.py | 70 -- .../resources/usr/bin/resize_tif.py | 134 --- .../segger2xr/resources/usr/bin/segger2xr.py | 247 ------ .../resources/usr/bin/split_transcripts.py | 109 --- .../resources/usr/bin/stitch_transcripts.py | 808 ------------------ 17 files changed, 1 insertion(+), 2714 deletions(-) rename bin/{utility_extract_data.py => utility_extract_preview_data.py} (100%) delete mode 100755 modules/local/baysor/create_dataset/resources/usr/bin/create_dataset.py delete mode 100755 modules/local/baysor/preprocess/resources/usr/bin/preprocess_transcripts.py delete mode 100755 modules/local/ficture/preprocess/resources/usr/bin/ficture_preprocess.py delete mode 100755 modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py delete mode 100755 modules/local/segger/predict/resources/usr/bin/run_predict.py delete mode 100755 modules/local/spatialdata/merge/resources/usr/bin/spatialdata_merge.py delete mode 100755 modules/local/spatialdata/meta/resources/usr/bin/spatialdata_meta.py delete mode 100755 modules/local/spatialdata/write/resources/usr/bin/spatialdata_write.py delete mode 100755 modules/local/utility/extract_preview_data/resources/usr/bin/extract_data.py delete mode 100755 modules/local/utility/get_coordinates/resources/usr/bin/get_coordinates.py delete mode 100755 modules/local/utility/parquet_to_csv/resources/usr/bin/parquet_to_csv.py delete mode 100755 modules/local/utility/resize_tif/resources/usr/bin/resize_tif.py delete mode 100755 modules/local/utility/segger2xr/resources/usr/bin/segger2xr.py delete mode 100755 modules/local/utility/split_transcripts/resources/usr/bin/split_transcripts.py delete mode 100755 modules/local/xenium_patch/stitch/resources/usr/bin/stitch_transcripts.py diff --git a/bin/utility_extract_data.py b/bin/utility_extract_preview_data.py similarity index 100% rename from bin/utility_extract_data.py rename to bin/utility_extract_preview_data.py diff --git a/modules/local/baysor/create_dataset/resources/usr/bin/create_dataset.py b/modules/local/baysor/create_dataset/resources/usr/bin/create_dataset.py deleted file mode 100755 index 4e5a263a..00000000 --- a/modules/local/baysor/create_dataset/resources/usr/bin/create_dataset.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 -""" -Create a sampled dataset for Baysor preview mode. - -Reads a CSV transcript file and randomly samples a fraction of rows, -writing the result to a new CSV file. 
-""" - -import argparse -import csv -import os -import random -from pathlib import Path - - -class BaysorPreview(): - """ - Utility class to generate baysor preview dataset - """ - @staticmethod - def generate_dataset( - transcripts: Path, - sampled_transcripts: Path, - sample_fraction: float = 0.3, - random_state: int = 42, - prefix: str = "" - ) -> None: - """ - Reads a csv file & randomly samples a fraction of rows, - and writes the result to a .csv file. - - Args: - transcripts: unziped transcripts.csv from xenium bundle - sampled_transcripts: randomly subsampled transcripts.csv file - sample_fraction: Fraction of rows to sample - random_state: Seed for reproducibility - prefix: Output directory prefix - """ - - random.seed(random_state) - output_path = f"{prefix}/{sampled_transcripts}" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(transcripts, mode='rt', newline='') as infile, \ - open(output_path, mode='wt', newline='') as outfile: - - reader = csv.reader(infile) - writer = csv.writer(outfile) - - # get the header line - header = next(reader) - writer.writerow(header) - - # randomize csv rows to write - for row in reader: - if random.random() < float(sample_fraction): - writer.writerow(row) - - return None - - -def main() -> None: - """ - Run create dataset as nf module - """ - parser = argparse.ArgumentParser( - description="Create sampled dataset for Baysor preview" - ) - parser.add_argument( - "--transcripts", required=True, - help="Path to transcripts CSV file" - ) - parser.add_argument( - "--sample-fraction", required=True, type=float, - help="Fraction of rows to sample" - ) - parser.add_argument( - "--prefix", required=True, - help="Output directory prefix" - ) - args = parser.parse_args() - - sampled_transcripts = "sampled_transcripts.csv" - - # generate dataset - BaysorPreview.generate_dataset( - transcripts=args.transcripts, - sampled_transcripts=sampled_transcripts, - sample_fraction=args.sample_fraction, - prefix=args.prefix - ) - - return None - - -if __name__ == "__main__": - main() diff --git a/modules/local/baysor/preprocess/resources/usr/bin/preprocess_transcripts.py b/modules/local/baysor/preprocess/resources/usr/bin/preprocess_transcripts.py deleted file mode 100755 index 2662f83c..00000000 --- a/modules/local/baysor/preprocess/resources/usr/bin/preprocess_transcripts.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 -""" -Preprocess Xenium transcripts for Baysor segmentation. - -Filters transcripts based on quality score and spatial coordinate thresholds, -removes negative control probes, and outputs filtered CSV for Baysor compatibility. -""" - -import argparse -import os - -import pandas as pd - - -def filter_transcripts( - transcripts: str, - min_qv: float = 20.0, - min_x: float = 0.0, - max_x: float = 24000.0, - min_y: float = 0.0, - max_y: float = 24000.0, - prefix: str = "", -) -> None: - """ - Filter transcripts based on the specified thresholds. 
- - Args: - transcripts: Path to transcripts parquet file - min_qv: Minimum Q-Score to pass filtering - min_x: Minimum x-coordinate threshold - max_x: Maximum x-coordinate threshold - min_y: Minimum y-coordinate threshold - max_y: Maximum y-coordinate threshold - prefix: Output directory prefix - """ - df = pd.read_parquet(transcripts, engine="pyarrow") - - # filter transcripts df with thresholds, ignore negative controls - filtered_df = df[ - (df["qv"] >= min_qv) - & (df["x_location"] >= min_x) - & (df["x_location"] <= max_x) - & (df["y_location"] >= min_y) - & (df["y_location"] <= max_y) - & (~df["feature_name"].str.startswith("NegControlProbe_")) - & (~df["feature_name"].str.startswith("antisense_")) - & (~df["feature_name"].str.startswith("NegControlCodeword_")) - & (~df["feature_name"].str.startswith("BLANK_")) - ] - - # change cell_id of cell-free transcripts to "0" (Baysor's no-cell sentinel). - # Modern Xenium stores cell_id as a string ("UNASSIGNED" for cell-free transcripts); - # legacy Xenium used integer -1. Normalize to string and handle both cases — pandas 3 - # rejects mixing int values into a string-dtype column. - filtered_df["cell_id"] = filtered_df["cell_id"].astype(str) - neg_cell_row = filtered_df["cell_id"].isin(["-1", "UNASSIGNED"]) - filtered_df.loc[neg_cell_row, "cell_id"] = "0" - - # Output filtered transcripts as CSV for Baysor 0.7.1 compatibility. - # Baysor's Julia Parquet.jl cannot read modern pyarrow Parquet files - # (pyarrow 15+ writes size_statistics Thrift field 16 unconditionally, - # which Baysor's old Thrift deserializer doesn't recognize). - os.makedirs(prefix, exist_ok=True) - filtered_df.to_csv(f"{prefix}/filtered_transcripts.csv", index=False) - - return None - - -def main() -> None: - """ - Run preprocess transcripts as nf module. 
- """ - parser = argparse.ArgumentParser( - description="Preprocess Xenium transcripts for Baysor" - ) - parser.add_argument( - "--transcripts", required=True, help="Path to transcripts parquet file" - ) - parser.add_argument("--prefix", required=True, help="Output directory prefix") - parser.add_argument( - "--min-qv", - type=float, - default=20.0, - help="Minimum Q-Score threshold (default: 20.0)", - ) - parser.add_argument( - "--min-x", - type=float, - default=0.0, - help="Minimum x-coordinate threshold (default: 0.0)", - ) - parser.add_argument( - "--max-x", - type=float, - default=24000.0, - help="Maximum x-coordinate threshold (default: 24000.0)", - ) - parser.add_argument( - "--min-y", - type=float, - default=0.0, - help="Minimum y-coordinate threshold (default: 0.0)", - ) - parser.add_argument( - "--max-y", - type=float, - default=24000.0, - help="Maximum y-coordinate threshold (default: 24000.0)", - ) - args = parser.parse_args() - - filter_transcripts( - transcripts=args.transcripts, - min_qv=args.min_qv, - min_x=args.min_x, - max_x=args.max_x, - min_y=args.min_y, - max_y=args.max_y, - prefix=args.prefix, - ) - - return None - - -if __name__ == "__main__": - main() diff --git a/modules/local/ficture/preprocess/resources/usr/bin/ficture_preprocess.py b/modules/local/ficture/preprocess/resources/usr/bin/ficture_preprocess.py deleted file mode 100755 index 2e0c687c..00000000 --- a/modules/local/ficture/preprocess/resources/usr/bin/ficture_preprocess.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -"""Preprocess Xenium transcripts for FICTURE analysis.""" - -import argparse -import gzip -import logging -import os -import re -import sys - -import pandas as pd - - -def parse_args(): - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Preprocess Xenium transcripts for FICTURE" - ) - parser.add_argument( - "--transcripts", required=True, help="Path to transcripts file (CSV)" - ) - parser.add_argument( - "--features", default="", help="Path to features file (optional)" - ) - parser.add_argument( - "--negative-control-regex", default="", help="Regex for negative control probes" - ) - return parser.parse_args() - - -def main(): - """Run FICTURE preprocessing.""" - args = parse_args() - print("[START]") - - negctrl_regex = "BLANK|NegCon" - if args.negative_control_regex: - negctrl_regex = args.negative_control_regex - - unit_info = ["X", "Y", "gene", "cell_id", "overlaps_nucleus"] - oheader = unit_info + ["Count"] - - feature = pd.DataFrame() - xmin = sys.maxsize - xmax = 0 - ymin = sys.maxsize - ymax = 0 - - output = "processed_transcripts.tsv.gz" - feature_file = "feature.clean.tsv.gz" - min_phred_score = 15 - - with gzip.open(output, "wt") as wf: - wf.write("\t".join(oheader) + "\n") - - for chunk in pd.read_csv(args.transcripts, header=0, chunksize=500000): - chunk = chunk.loc[(chunk.qv > min_phred_score)] - chunk.rename(columns={"feature_name": "gene"}, inplace=True) - if negctrl_regex != "": - chunk = chunk[ - ~chunk.gene.str.contains(negctrl_regex, flags=re.IGNORECASE, regex=True) - ] - chunk.rename(columns={"x_location": "X", "y_location": "Y"}, inplace=True) - chunk["Count"] = 1 - chunk[oheader].to_csv( - output, sep="\t", mode="a", index=False, header=False, float_format="%.2f" - ) - logging.info(f"{chunk.shape[0]}") - feature = pd.concat( - [feature, chunk.groupby(by="gene").agg({"Count": "sum"}).reset_index()] - ) - x0 = chunk.X.min() - x1 = chunk.X.max() - y0 = chunk.Y.min() - y1 = chunk.Y.max() - xmin = min(int(xmin), int(x0)) - 
xmax = max(int(xmax), int(x1)) - ymin = min(int(ymin), int(y0)) - ymax = max(int(ymax), int(y1)) - - if os.path.exists(args.features): - feature_list = [] - with open(args.features, "r") as ff: - for line in ff: - feature_list.append(line.strip("\n")) - feature = feature.groupby(by="gene").agg({"Count": "sum"}).reset_index() - feature = feature[[x in feature_list for x in feature["gene"]]] - feature.to_csv(feature_file, sep="\t", index=False) - - f = os.path.join(os.path.dirname(output), "coordinate_minmax.tsv") - with open(f, "w") as wf: - wf.write(f"xmin\t{xmin}\n") - wf.write(f"xmax\t{xmax}\n") - wf.write(f"ymin\t{ymin}\n") - wf.write(f"ymax\t{ymax}\n") - - print("[FINISH]") - - -if __name__ == "__main__": - main() diff --git a/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py b/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py deleted file mode 100755 index c73ab006..00000000 --- a/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py +++ /dev/null @@ -1,253 +0,0 @@ -#!/usr/bin/env python3 -""" -Run segger create_dataset with spatialxe-specific preprocessing and workarounds. - -Wraps segger's create_dataset_fast.py with: - - bundle_local symlink prep (handles read-only S3/Fusion mounts) - - parquet column statistics (segger needs these) - - WORKAROUND: filter trainable tiles from test_tiles when segger commit 0787167 mis-splits - - WORKAROUND: replace NaN bd.x with zeros after get_polygon_props produces NaN - -Each WORKAROUND should be removable when the upstream segger bug is fixed. -""" - -import argparse -import os -import shutil -import subprocess -import sys -from pathlib import Path - -# imports for actual work (used in functions below) -import pyarrow.parquet as pq -import pyarrow.compute as pc -import torch - - -SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/create_dataset_fast.py" - - -def parse_args(): - p = argparse.ArgumentParser() - p.add_argument("--bundle-dir", required=True) - p.add_argument("--output-dir", required=True) - p.add_argument("--sample-type", required=True, choices=["xenium"]) - p.add_argument("--tile-width", type=int, required=True) - p.add_argument("--tile-height", type=int, required=True) - p.add_argument("--n-workers", type=int, required=True) - # remaining args forwarded to segger CLI - args, extra = p.parse_known_args() - return args, extra - - -def prepare_bundle(bundle_dir): - """Create local bundle dir with absolute symlinks (S3/Fusion read-only-safe).""" - Path("bundle_local").mkdir(exist_ok=True) - for item in Path(bundle_dir).iterdir(): - try: - abs_path = item.resolve() - except Exception: - abs_path = item - target = Path("bundle_local") / item.name - if target.exists() or target.is_symlink(): - target.unlink() - target.symlink_to(abs_path) - - # Segger expects nucleus_boundaries.parquet but Xenium bundles have cell_boundaries.parquet - nb = Path("bundle_local/nucleus_boundaries.parquet") - cb = Path("bundle_local/cell_boundaries.parquet") - if not nb.exists() and cb.exists(): - print( - "Creating nucleus_boundaries.parquet symlink from cell_boundaries.parquet" - ) - nb.symlink_to(cb.resolve()) - - print("Bundle contents:") - for item in sorted(Path("bundle_local").iterdir()): - print(f" {item.name}") - - -def add_parquet_stats(): - """Rewrite key parquet files with column statistics (segger requires them).""" - Path("bundle_stats").mkdir(exist_ok=True) - for fname in ["transcripts.parquet", "nucleus_boundaries.parquet"]: - src = Path("bundle_local") / fname - dst = 
Path("bundle_stats") / fname - if not src.exists(): - print(f" Skip {src}") - continue - t = pq.read_table(str(src)) - pq.write_table(t, str(dst), write_statistics=True, compression="snappy") - print(f" Done {fname} ({len(t)} rows)") - - # Symlink everything else from bundle_local into bundle_stats - for item in Path("bundle_local").iterdir(): - dst = Path("bundle_stats") / item.name - if not dst.exists(): - dst.symlink_to(item.resolve()) - - # Debug: check overlaps_nucleus column in transcripts - print("\n=== Debugging overlaps_nucleus data ===") - tx = pq.read_table("bundle_stats/transcripts.parquet") - bd = pq.read_table("bundle_stats/nucleus_boundaries.parquet") - if "overlaps_nucleus" in tx.column_names: - col = tx.column("overlaps_nucleus") - print(f"overlaps_nucleus dtype: {col.type}") - unique_vals = pc.unique(col) - print(f"overlaps_nucleus unique values: {unique_vals.to_pylist()[:10]}") - val_counts = pc.value_counts(col) - print(f"overlaps_nucleus value_counts: {val_counts.to_pylist()}") - else: - print("WARNING: overlaps_nucleus column NOT FOUND in transcripts.parquet") - - if "cell_id" in tx.column_names and "cell_id" in bd.column_names: - tx_cells = set(pc.unique(tx.column("cell_id")).to_pylist()) - bd_cells = set(pc.unique(bd.column("cell_id")).to_pylist()) - overlap = tx_cells & bd_cells - print(f"Transcripts unique cell_ids: {len(tx_cells)}") - print(f"Boundaries unique cell_ids: {len(bd_cells)}") - print(f"Overlapping cell_ids: {len(overlap)}") - print("=== End Debug ===\n") - - -def run_segger_cli(args, extra): - cmd = [ - "python3", - SEGGER_CLI, - "--base_dir", - "bundle_stats", - "--data_dir", - args.output_dir, - "--sample_type", - args.sample_type, - "--tile_width", - str(args.tile_width), - "--tile_height", - str(args.tile_height), - "--n_workers", - str(args.n_workers), - *extra, - ] - print(f"Running: {' '.join(cmd)}") - result = subprocess.run(cmd) - if result.returncode != 0: - sys.exit(result.returncode) - - -def filter_trainable_tiles_if_needed(prefix): - """ - WORKAROUND: segger commit 0787167 has a bug where all tiles end up in test_tiles - regardless of test_prob/val_prob settings. Move ONLY trainable tiles (those with - edge_label_index) from test_tiles to train_tiles. - - Remove this function once segger >= 0.1.x is bumped with the upstream fix. 
- """ - train_dir = Path(prefix) / "train_tiles" / "processed" - test_dir = Path(prefix) / "test_tiles" / "processed" - val_dir = Path(prefix) / "val_tiles" / "processed" - - train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 - test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 - val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 - print( - f"Dataset split (before fix): train={train_count} val={val_count} test={test_count}" - ) - - if train_count == 0 and test_count > 0: - print( - "Applying workaround: filtering trainable tiles from test_tiles (segger split bug)" - ) - moved = 0 - skipped = 0 - for tile_path in list(test_dir.iterdir()): - if not tile_path.name.endswith(".pt"): - continue - try: - tile = torch.load(str(tile_path), weights_only=False) - edge_store = tile["tx", "belongs", "bd"] - if ( - hasattr(edge_store, "edge_label_index") - and edge_store.edge_label_index.numel() > 0 - ): - shutil.move(str(tile_path), str(train_dir / tile_path.name)) - moved += 1 - else: - skipped += 1 - except Exception as e: - print(f"Warning: Could not process {tile_path.name}: {e}") - skipped += 1 - print(f"Moved {moved} trainable tiles to train_tiles") - print(f"Skipped {skipped} test-only tiles (no edge_label_index)") - - train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 - test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 - val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 - print( - f"Dataset split (after fix): train={train_count} val={val_count} test={test_count}" - ) - - if train_count == 0: - print(f"ERROR: No trainable tiles were created in {train_dir}", file=sys.stderr) - print( - "This usually means no transcripts overlap with nucleus boundaries in the dataset.", - file=sys.stderr, - ) - print( - "Check if the Xenium bundle contains valid overlaps_nucleus data in transcripts.parquet.", - file=sys.stderr, - ) - sys.exit(1) - print(f"Successfully created {train_count} trainable tiles") - - -def fix_bd_x_nan(prefix): - """ - WORKAROUND: segger's get_polygon_props() produces NaN boundary features (bd.x) - when polygon geometries have zero area or index misalignment during GeoDataFrame - construction. Replace NaN bd.x with zeros so BCEWithLogitsLoss doesn't propagate NaN. - - Remove this function once segger >= 0.1.x is bumped with the upstream fix. 
- """ - fixed = 0 - total = 0 - for split in ["train_tiles", "test_tiles", "val_tiles"]: - tile_dir = Path(prefix) / split / "processed" - if not tile_dir.is_dir(): - continue - for tile_path in tile_dir.iterdir(): - if not tile_path.name.endswith(".pt"): - continue - total += 1 - tile = torch.load(str(tile_path), weights_only=False) - bd_x = tile["bd"].x - if bd_x.isnan().any(): - tile["bd"].x = torch.nan_to_num(bd_x, nan=0.0) - torch.save(tile, str(tile_path)) - fixed += 1 - print(f"Fixed NaN bd.x in {fixed}/{total} tiles") - - -def main(): - args, extra = parse_args() - - # Ensure numba cache dir is writable (env var should be set by caller, but belt-and-suspenders) - os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) - os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) - - prepare_bundle(args.bundle_dir) - print("Adding statistics to parquet files...") - add_parquet_stats() - - # Sanity-check bundle_stats - print("bundle_stats contents:") - for item in sorted(Path("bundle_stats").iterdir()): - print(f" {item.name}") - - run_segger_cli(args, extra) - - filter_trainable_tiles_if_needed(args.output_dir) - fix_bd_x_nan(args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/modules/local/segger/predict/resources/usr/bin/run_predict.py b/modules/local/segger/predict/resources/usr/bin/run_predict.py deleted file mode 100755 index 56a77ffc..00000000 --- a/modules/local/segger/predict/resources/usr/bin/run_predict.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -""" -Run segger predict with spatialxe-specific preprocessing. - -Wraps segger's predict_fast.py with: - - GPU enumeration (replaces inline python3 -c torch check) - - WORKAROUND: patch predict_parquet.py at runtime to add torch.no_grad() for ~30-50% VRAM savings - - WORKAROUND: seed random.choice for deterministic GPU assignment (avoids stochastic OOM) - -Both WORKAROUNDs should be removable once the patches are upstreamed to segger. -""" - -import argparse -import os -import subprocess -import sys - - -SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/predict_fast.py" - - -def parse_args(): - p = argparse.ArgumentParser() - p.add_argument("--models-dir", required=True) - p.add_argument("--segger-data-dir", required=True) - p.add_argument("--transcripts-file", required=True) - p.add_argument("--benchmarks-dir", required=True) - p.add_argument("--batch-size", type=int, required=True) - p.add_argument("--use-cc", required=True) - p.add_argument("--knn-method", required=True) - p.add_argument("--num-workers", type=int, required=True) - args, extra = p.parse_known_args() - return args, extra - - -def detect_gpus(): - """Return comma-separated list of available CUDA device ids (or "0" if none).""" - import torch - - print("=== GPU Detection (SEGGER_PREDICT) ===") - print(f"PyTorch CUDA available: {torch.cuda.is_available()}") - n = torch.cuda.device_count() - print(f"CUDA device count: {n}") - print("======================================") - if n > 0: - return ",".join(str(i) for i in range(n)) - return "0" - - -def patch_predict_parquet(): - """ - WORKAROUND: patch segger.prediction.predict_parquet at runtime. - - Avoids rebuilding the segger Docker image. Two patches: - 1. Add torch.no_grad() to disable gradient graphs during inference (~30-50% VRAM savings). - 2. Seed random for deterministic GPU assignment (avoids stochastic OOM). - - Remove this function once the patches are upstreamed to segger. 
- """ - import segger.prediction.predict_parquet as m - - pred_py = m.__file__ - print(f"Patching {pred_py}: torch.no_grad() + round-robin GPU assignment") - # Use sed via subprocess for in-place edit (matches the original behavior exactly) - subprocess.run( - [ - "sed", - "-i", - "s/with cp.cuda.Device(gpu_id):/with cp.cuda.Device(gpu_id), torch.no_grad():/", - pred_py, - ], - check=True, - ) - subprocess.run( - [ - "sed", - "-i", - "s/gpu_id = random.choice(gpu_ids)/random.seed(0); gpu_id = random.choice(gpu_ids)/", - pred_py, - ], - check=True, - ) - - -def run_segger_cli(args, extra, gpu_ids): - cmd = [ - "python3", - SEGGER_CLI, - "--models_dir", - args.models_dir, - "--segger_data_dir", - args.segger_data_dir, - "--transcripts_file", - args.transcripts_file, - "--benchmarks_dir", - args.benchmarks_dir, - "--batch_size", - str(args.batch_size), - "--use_cc", - str(args.use_cc), - "--knn_method", - args.knn_method, - "--num_workers", - str(args.num_workers), - "--gpu_ids", - gpu_ids, - *extra, - ] - print(f"Running: {' '.join(cmd)}") - result = subprocess.run(cmd) - if result.returncode != 0: - sys.exit(result.returncode) - - -def main(): - args, extra = parse_args() - - # Limit cupy GPU memory to 80% so PyTorch has headroom for graph attention ops - os.environ.setdefault("CUPY_GPU_MEMORY_LIMIT", "80%") - # Belt-and-suspenders: ensure PyTorch uses expandable segments - os.environ.setdefault( - "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512" - ) - # Numba cache directory - os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) - os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) - - gpu_ids = detect_gpus() - print(f"Using GPUs: {gpu_ids}") - - patch_predict_parquet() - - run_segger_cli(args, extra, gpu_ids) - - -if __name__ == "__main__": - main() diff --git a/modules/local/spatialdata/merge/resources/usr/bin/spatialdata_merge.py b/modules/local/spatialdata/merge/resources/usr/bin/spatialdata_merge.py deleted file mode 100755 index 409d8c00..00000000 --- a/modules/local/spatialdata/merge/resources/usr/bin/spatialdata_merge.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python3 -"""Merge two spatialdata bundles to create a layered spatialdata object.""" - -import argparse -import json -import os -import shutil - -import spatialdata - - -def parse_args(): - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Merge two spatialdata bundles") - parser.add_argument("--raw-bundle", required=True, help="Path to raw spatialdata bundle") - parser.add_argument("--redefined-bundle", required=True, help="Path to redefined spatialdata bundle") - parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") - parser.add_argument("--output-folder", required=True, help="Output folder name") - return parser.parse_args() - - -def main(): - """Run spatialdata merge.""" - args = parse_args() - print("[START]") - - output_dir = f"spatialdata/{args.prefix}/{args.output_folder}" - - # Ensure the output folder exists - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir) - - # Copy the entire reference bundle as is - for root, _, files in os.walk(args.raw_bundle): - rel_path = os.path.relpath(root, args.raw_bundle) - target_path = os.path.join(output_dir, rel_path) - os.makedirs(target_path, exist_ok=True) - for file in files: - shutil.copy(os.path.join(root, file), os.path.join(target_path, file)) - - # Rename folders in Points, Shapes, and Tables to raw_* - for 
category in ["points", "shapes", "tables"]: - category_path = os.path.join(output_dir, category) - if os.path.exists(category_path): - for folder in next(os.walk(category_path))[1]: - old_path = os.path.join(category_path, folder) - print(folder) - new_path = os.path.join(category_path, f"raw_{folder}") - os.rename(old_path, new_path) - - # Copy folders from redefined_bundle and rename them as redefined_* - for category in ["points", "shapes", "tables"]: - add_category_path = os.path.join(args.redefined_bundle, category) - output_category_path = os.path.join(output_dir, category) - os.makedirs(output_category_path, exist_ok=True) - - if os.path.exists(add_category_path): - for folder in next(os.walk(add_category_path))[1]: - src_folder = os.path.join(add_category_path, folder) - dest_folder = os.path.join(output_category_path, f"redefined_{folder}") - shutil.copytree(src_folder, dest_folder) - - # Invalidate consolidated metadata in zarr.json -- the directory renames above - # made the element paths in the metadata stale (e.g., 'points/transcripts' -> - # 'points/raw_transcripts'). Without consolidated metadata, sd.read_zarr() - # discovers elements by scanning the filesystem directly. - zarr_json = os.path.join(output_dir, "zarr.json") - if os.path.exists(zarr_json): - with open(zarr_json) as f: - meta = json.load(f) - if "consolidated_metadata" in meta: - del meta["consolidated_metadata"] - with open(zarr_json, "w") as f: - json.dump(meta, f) - print("[NOTE] Removed stale consolidated metadata from zarr.json") - - print("[FINISH]") - - -if __name__ == "__main__": - main() diff --git a/modules/local/spatialdata/meta/resources/usr/bin/spatialdata_meta.py b/modules/local/spatialdata/meta/resources/usr/bin/spatialdata_meta.py deleted file mode 100755 index 935f39b2..00000000 --- a/modules/local/spatialdata/meta/resources/usr/bin/spatialdata_meta.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 -"""Add metadata to SpatialData bundle.""" - -import argparse -import json -import sys - -import pandas as pd -import spatialdata as sd -import zarr - -# Fix zarr v3 + anndata + numcodecs incompatibility: -# anndata's string writer passes numcodecs.VLenUTF8 to zarr.Group.create_array, -# but zarr v3 only accepts ArrayArrayCodec types. OME-Zarr 0.5 requires zarr v3 -# for images, so we can't downgrade the store format. Instead, we intercept -# create_array to strip numcodecs codecs and let zarr v3 handle strings natively. 
-import numcodecs -import zarr.core.group as _zarr_group - -_orig_create_array = _zarr_group.Group.create_array - - -def _v3_compat_create_array(self, *args, **kwargs): - """Strip numcodecs VLenUTF8 from codec params for zarr v3 compatibility.""" - for param in ("filters", "compressor", "object_codec"): - val = kwargs.get(param) - if val is None: - continue - if isinstance(val, numcodecs.vlen.VLenUTF8): - del kwargs[param] - elif isinstance(val, (list, tuple)): - cleaned = [v for v in val if not isinstance(v, numcodecs.vlen.VLenUTF8)] - if len(cleaned) != len(val): - if cleaned: - kwargs[param] = cleaned - else: - del kwargs[param] - return _orig_create_array(self, *args, **kwargs) - - -_zarr_group.Group.create_array = _v3_compat_create_array - - -def _is_arrow_backed(dtype): - """Check if a pandas dtype is backed by PyArrow.""" - return isinstance(dtype, pd.ArrowDtype) or ( - hasattr(dtype, "storage") and getattr(dtype, "storage", None) == "pyarrow" - ) or "pyarrow" in str(dtype) - - -def _convert_df_arrow_to_numpy(df): - """Convert Arrow-backed dtypes in a DataFrame to numpy object dtype.""" - for col in df.columns: - dtype = df[col].dtype - if _is_arrow_backed(dtype): - df[col] = df[col].astype("object") - elif isinstance(dtype, pd.CategoricalDtype): - cats = dtype.categories - if cats is not None and _is_arrow_backed(cats.dtype): - df[col] = df[col].cat.rename_categories(cats.astype("object")) - if _is_arrow_backed(df.index.dtype): - df.index = pd.Index(df.index.astype("object")) - - -def convert_arrow_to_numpy(sdata): - """Convert Arrow-backed dtypes to numpy for anndata zarr write compatibility.""" - for table_key in list(sdata.tables.keys()): - adata = sdata.tables[table_key] - _convert_df_arrow_to_numpy(adata.obs) - _convert_df_arrow_to_numpy(adata.var) - - -def parse_args(): - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Add metadata to SpatialData bundle") - parser.add_argument("--spatialdata-bundle", required=True, help="Path to spatialdata bundle") - parser.add_argument("--xenium-bundle", required=True, help="Path to xenium bundle") - parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") - parser.add_argument("--metadata", required=True, help="Metadata string from Nextflow meta map") - parser.add_argument("--output-folder", required=True, help="Output folder name") - return parser.parse_args() - - -def main(): - """Run spatialdata metadata addition.""" - args = parse_args() - print("[START]") - - sdata = sd.read_zarr(args.spatialdata_bundle) - - # Convert metadata into dict - print("[NOTE] Read in provenance ...") - metadata = args.metadata.strip("[]") # Remove square brackets - pairs = metadata.split(", ") # Split by comma and space - metadata = {k: v for k, v in (pair.split(":") for pair in pairs)} # Create dictionary - - for key in metadata: - if key not in sdata['raw_table'].uns['spatialdata_attrs']: - sdata['raw_table'].uns['spatialdata_attrs'][key] = metadata[key] - else: - print(f'[ERROR] {key} already exist in sdata[raw_table].uns[spatialdata_attrs].', file=sys.stderr) - - # Add experimental metadata - print("[NOTE] Read in experiment metadata ...") - sdata['raw_table'].uns['experiment_xenium'] = '' - metadata_experiment = f'{args.xenium_bundle}/experiment.xenium' - with open(metadata_experiment, "r") as f: - metadata_experiment = json.load(f) - sdata['raw_table'].uns['experiment_xenium'] = json.dumps(metadata_experiment) - - # Add gene panel metadata - print("[NOTE] Read in gene panel metadata ...") - 
sdata['raw_table'].uns['gene_panel'] = '' - metadata_gene_panel = f'{args.xenium_bundle}/gene_panel.json' - with open(metadata_gene_panel, "r") as f: - metadata_gene_panel = json.load(f) - sdata['raw_table'].uns['gene_panel'] = json.dumps(metadata_gene_panel) - - convert_arrow_to_numpy(sdata) - sdata.write(f"spatialdata/{args.prefix}/{args.output_folder}", overwrite=True, consolidate_metadata=True, sdata_formats=None) - - print("[FINISH]") - - -if __name__ == "__main__": - main() diff --git a/modules/local/spatialdata/write/resources/usr/bin/spatialdata_write.py b/modules/local/spatialdata/write/resources/usr/bin/spatialdata_write.py deleted file mode 100755 index 421e830f..00000000 --- a/modules/local/spatialdata/write/resources/usr/bin/spatialdata_write.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 -"""Write spatialdata object from segmentation format.""" - -import argparse -import sys - -import pandas as pd -import spatialdata -from spatialdata_io import xenium - -# Fix zarr v3 + anndata + numcodecs incompatibility: -# anndata's string writer passes numcodecs.VLenUTF8 to zarr.Group.create_array, -# but zarr v3 only accepts ArrayArrayCodec types. OME-Zarr 0.5 requires zarr v3 -# for images, so we can't downgrade the store format. Instead, we intercept -# create_array to strip numcodecs codecs and let zarr v3 handle strings natively. -import numcodecs -import zarr.core.group as _zarr_group - -_orig_create_array = _zarr_group.Group.create_array - - -def _v3_compat_create_array(self, *args, **kwargs): - """Strip numcodecs VLenUTF8 from codec params for zarr v3 compatibility.""" - for param in ("filters", "compressor", "object_codec"): - val = kwargs.get(param) - if val is None: - continue - if isinstance(val, numcodecs.vlen.VLenUTF8): - del kwargs[param] - elif isinstance(val, (list, tuple)): - cleaned = [v for v in val if not isinstance(v, numcodecs.vlen.VLenUTF8)] - if len(cleaned) != len(val): - if cleaned: - kwargs[param] = cleaned - else: - del kwargs[param] - return _orig_create_array(self, *args, **kwargs) - - -_zarr_group.Group.create_array = _v3_compat_create_array - - -def _is_arrow_backed(dtype): - """Check if a pandas dtype is backed by PyArrow.""" - return ( - isinstance(dtype, pd.ArrowDtype) - or (hasattr(dtype, "storage") and getattr(dtype, "storage", None) == "pyarrow") - or "pyarrow" in str(dtype) - ) - - -def _convert_df_arrow_to_numpy(df): - """Convert Arrow-backed dtypes in a DataFrame to numpy object dtype. - - Handles three cases: - 1. Regular columns with Arrow-backed dtypes - 2. Categorical columns whose categories are Arrow-backed - 3. 
Index with Arrow-backed dtype - """ - for col in df.columns: - dtype = df[col].dtype - if _is_arrow_backed(dtype): - df[col] = df[col].astype("object") - elif isinstance(dtype, pd.CategoricalDtype): - cats = dtype.categories - if cats is not None and _is_arrow_backed(cats.dtype): - df[col] = df[col].cat.rename_categories(cats.astype("object")) - if _is_arrow_backed(df.index.dtype): - df.index = pd.Index(df.index.astype("object")) - - -def convert_arrow_to_numpy(sdata): - """Convert Arrow-backed dtypes to numpy for anndata zarr write compatibility.""" - for table_key in list(sdata.tables.keys()): - adata = sdata.tables[table_key] - _convert_df_arrow_to_numpy(adata.obs) - _convert_df_arrow_to_numpy(adata.var) - - -def parse_args(): - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Write spatialdata object from segmentation format") - parser.add_argument("--bundle", required=True, help="Path to input bundle") - parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") - parser.add_argument("--output-folder", required=True, help="Output folder name") - parser.add_argument("--segmented-object", required=True, help="Segmented object type (cells, nuclei, cells_and_nuclei)") - parser.add_argument("--coordinate-space", required=True, help="Coordinate space (pixels, microns)") - parser.add_argument("--format", required=True, help="Input format (xenium)") - return parser.parse_args() - - -def main(): - """Run spatialdata write.""" - args = parse_args() - print("[START]") - - cells_as_circles = False - cells_boundaries = False - nucleus_boundaries = False - cells_labels = False - nucleus_labels = False - - if args.segmented_object == "cells": - cells_boundaries = True - cells_labels = True - elif args.segmented_object == "nuclei": - nucleus_boundaries = True - nucleus_labels = True - elif args.segmented_object == "cells_and_nuclei": - cells_boundaries = True - nucleus_boundaries = True - cells_labels = True - nucleus_labels = True - else: - cells_as_circles = False - - # set sd variables based on the coordinate space - if args.coordinate_space == "pixels": - cells_labels = True - nucleus_labels = True - # Labels are sufficient in pixel space; boundaries can contain - # degenerate polygons (< 4 vertices) from XeniumRanger that - # crash spatialdata_io's shapely LinearRing parser. 
- cells_boundaries = False - nucleus_boundaries = False - - if args.coordinate_space == "microns": - cells_labels = False - cells_boundaries = True - nucleus_boundaries = False - nucleus_labels = False - cells_as_circles = False - - if args.format == "xenium": - sd_xenium_obj = xenium( - args.bundle, - cells_as_circles=cells_as_circles, - cells_boundaries=cells_boundaries, - nucleus_boundaries=nucleus_boundaries, - cells_labels=cells_labels, - nucleus_labels=nucleus_labels, - transcripts=True, - morphology_mip=True, - morphology_focus=True, - ) - print(sd_xenium_obj) - convert_arrow_to_numpy(sd_xenium_obj) - sd_xenium_obj.write(f"spatialdata/{args.prefix}/{args.output_folder}") - else: - sys.exit("[ERROR] Format not found") - - print("[FINISH]") - - -if __name__ == "__main__": - main() diff --git a/modules/local/utility/extract_preview_data/main.nf b/modules/local/utility/extract_preview_data/main.nf index 1240ddbf..821effc5 100644 --- a/modules/local/utility/extract_preview_data/main.nf +++ b/modules/local/utility/extract_preview_data/main.nf @@ -26,7 +26,7 @@ process EXTRACT_PREVIEW_DATA { prefix = task.ext.prefix ?: "${meta.id}" """ - utility_extract_data.py \\ + utility_extract_preview_data.py \\ --preview-html ${preview_html} \\ --prefix ${prefix} """ diff --git a/modules/local/utility/extract_preview_data/resources/usr/bin/extract_data.py b/modules/local/utility/extract_preview_data/resources/usr/bin/extract_data.py deleted file mode 100755 index 0ea737c2..00000000 --- a/modules/local/utility/extract_preview_data/resources/usr/bin/extract_data.py +++ /dev/null @@ -1,208 +0,0 @@ -#!/usr/bin/env python3 -""" -Extract preview data from Baysor preview HTML reports. - -Parses embedded Vega-Lite spec variables and base64 PNG images from the -Baysor preview.html file, writing MultiQC-compatible TSV and PNG files. -""" - -import argparse -import base64 -import html -import json -import re -import sys -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import pandas as pd -from bs4 import BeautifulSoup - - -def get_png_files(soup: BeautifulSoup, outdir: Path) -> None: - """Get png base64 images following specific h1 tags in preview.html""" - target_ids = ["Transcript_Plots", "Noise_Level"] - outdir.mkdir(parents=True, exist_ok=True) - - for h1_id in target_ids: - h1_tag = soup.find("h1", id=h1_id) - if not h1_tag: - print(f"[WARN] No
<h1>
with id {h1_id} found") - continue - - # Look for the first after the h1 in the DOM - img_tag = h1_tag.find_next("img") - if not img_tag or not img_tag.get("src"): - print(f"[WARN] No found after h1#{h1_id}") - continue - - img_src = img_tag["src"] - if img_src.startswith("data:image/png;base64,"): - base64_data = img_src.split(",", 1)[1] - data = base64.b64decode(base64_data) - else: - print(f"[WARN] img src is not base64 PNG for h1#{h1_id}") - continue - - # save png files with _mqc suffix for MultiQC integration - img_name = f"{h1_id}_mqc.png".lower() - out_path = outdir / img_name - with open(out_path, "wb") as f: - f.write(data) - - print(f"[INFO] Saved {img_name}") - - return None - - -def extract_js_object(text: str, start_idx: int) -> Tuple[Optional[str], int]: - """Extract json-like object starting at start_idx.""" - if start_idx >= len(text) or text[start_idx] != "{": - return None, start_idx - - stack, in_str, escape, quote = [], False, False, None - for i in range(start_idx, len(text)): - ch = text[i] - if in_str: - if escape: - escape = False - elif ch == "\\": - escape = True - elif ch == quote: - in_str = False - else: - if ch in ('"', "'"): - in_str, quote = True, ch - elif ch == "{": - stack.append("{") - elif ch == "}": - stack.pop() - if not stack: - return text[start_idx : i + 1], i + 1 - elif ch == "/" and i + 1 < len(text): - # skip js comments - nxt = text[i + 1] - if nxt == "/": - end = text.find("\n", i + 2) - i = len(text) - 1 if end == -1 else end - elif nxt == "*": - end = text.find("*/", i + 2) - if end == -1: - break - i = end + 1 - - return None, start_idx - - -def js_to_json(js: str) -> str: - """Convert a JS object string to valid JSON.""" - # Remove comments - js = re.sub(r"/\*.*?\*/", "", js, flags=re.S) - js = re.sub(r"//[^\n]*", "", js) - - # Convert single-quoted strings to double-quoted strings - js = re.sub( - r"'((?:\\.|[^'\\])*)'", - lambda m: '"' + m.group(1).replace('"', '\\"') + '"', - js, - ) - - # Remove trailing commas - js = re.sub(r",\s*(?=[}\]])", "", js) - js = re.sub(r",\s*,+", ",", js) - - return js.strip() - - -def find_variables(script_text: str) -> Dict[str, str]: - """Find all 'var|let|const specN =' declarations and extract their objects.""" - specs: Dict[str, str] = {} - script_text = html.unescape(script_text) - pattern = re.compile(r"(?:var|let|const)\s+(spec\d+)\s*=\s*{", re.I) - - for match in pattern.finditer(script_text): - var = match.group(1) - obj, _ = extract_js_object(script_text, match.end() - 1) - if obj: - specs[var] = obj - else: - print(f"[WARN] Could not extract object for {var}") - return specs - - -def write_tsvs(specs: Dict[str, str], outdir: Path) -> List[Path]: - """Convert extracted json to tsv.""" - outdir.mkdir(parents=True, exist_ok=True) - written: List[Path] = [] - - for var, js_obj in specs.items(): - try: - data = json.loads(js_to_json(js_obj)) - values = data.get("data", {}).get("values", []) - if not values: - print(f"[WARN] No data.values found in {var}") - continue - - df = pd.DataFrame(values) - outpath = outdir / f"{var}_mqc.tsv" - - with open(outpath, "w") as f: - f.write("# plot_type: linegraph\n") - f.write(f"# section_name: {var}\n") - f.write("# description: Extracted preview data\n") - df.to_csv(f, sep="\t", index=False) - - written.append(outpath) - print(f"[INFO] Wrote {outpath} ({len(df)} rows × {len(df.columns)} cols)") - except Exception as e: - print(f"[ERROR] Failed to process {var}: {e}") - - return written - - -def parse_args() -> argparse.Namespace: - """Parse command-line 
arguments.""" - parser = argparse.ArgumentParser( - description="Extract preview data from Baysor preview HTML reports." - ) - parser.add_argument( - "--preview-html", - required=True, - help="Path to Baysor preview HTML file", - ) - parser.add_argument( - "--prefix", - required=True, - help="Output directory prefix (sample ID)", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - input_path: Path = Path(args.preview_html) - outdir: Path = Path(args.prefix) - - text = input_path.read_text(encoding="utf-8", errors="ignore") - soup = BeautifulSoup(text, "html.parser") - - # get the script section - if " argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Get transcript coordinate bounds from a Parquet file." - ) - parser.add_argument( - "--transcripts", - required=True, - help="Path to transcripts parquet file", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - result = get_coordinates(args.transcripts) - print(",".join(str(v) for v in result)) diff --git a/modules/local/utility/parquet_to_csv/resources/usr/bin/parquet_to_csv.py b/modules/local/utility/parquet_to_csv/resources/usr/bin/parquet_to_csv.py deleted file mode 100755 index bfa19c40..00000000 --- a/modules/local/utility/parquet_to_csv/resources/usr/bin/parquet_to_csv.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -""" -Convert a Parquet file to CSV format. - -Reads a Parquet file and writes it as CSV, optionally gzip-compressed. -""" - -import argparse -from pathlib import Path - -import pandas as pd - - -def convert_parquet( - transcripts: str, - extension: str = ".csv", - prefix: str = "", -) -> None: - """ - Convert a Parquet file to CSV or CSV.GZ format. - - Args: - transcripts: Filename of the input parquet file - extension: Output extension ('.csv' or '.gz' for gzip) - prefix: Output directory prefix - """ - df = pd.read_parquet(transcripts, engine="pyarrow") - - Path(prefix).mkdir(parents=True, exist_ok=True) - - if extension == ".gz": - output = transcripts.replace(".parquet", ".csv.gz") - df.to_csv(f"{prefix}/{output}", compression="gzip", index=False) - else: - output = transcripts.replace(".parquet", ".csv") - df.to_csv(f"{prefix}/{output}", index=False) - - return None - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Convert a Parquet file to CSV format." - ) - parser.add_argument( - "--transcripts", - required=True, - help="Input parquet filename", - ) - parser.add_argument( - "--extension", - default=".csv", - help="Output extension: '.csv' or '.gz' (default: .csv)", - ) - parser.add_argument( - "--prefix", - required=True, - help="Output directory prefix (sample ID)", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - convert_parquet( - transcripts=args.transcripts, - extension=args.extension, - prefix=args.prefix, - ) diff --git a/modules/local/utility/resize_tif/resources/usr/bin/resize_tif.py b/modules/local/utility/resize_tif/resources/usr/bin/resize_tif.py deleted file mode 100755 index 6cca640d..00000000 --- a/modules/local/utility/resize_tif/resources/usr/bin/resize_tif.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3 -""" -Resize a segmentation TIFF mask to match transcript coordinates. - -This script rescales a segmentation mask image to match the coordinate -space of Xenium transcript data using microns-per-pixel metadata. 
-""" - -import argparse -import json -import os -from typing import Tuple - -import numpy as np -import pandas as pd -import tifffile -from skimage.transform import resize - - -def read_mask(mask_path: str) -> np.ndarray: - """Read the segmentation mask from a TIFF file.""" - print(f"Reading mask: {mask_path}") - mask = tifffile.imread(mask_path) - print(f"Mask shape: {mask.shape}, dtype: {mask.dtype}") - return mask - - -def read_transcript_bounds(transcript_path: str) -> Tuple[float, float, float, float]: - """Read transcript coordinates and return their bounding box.""" - print(f"Reading transcripts: {transcript_path}") - if transcript_path.endswith(".parquet"): - transcripts = pd.read_parquet(transcript_path, columns=["x_location", "y_location"]) - else: - transcripts = pd.read_csv(transcript_path) - - if "x_location" not in transcripts.columns or "y_location" not in transcripts.columns: - raise ValueError("Transcript file must contain 'x_location' and 'y_location' columns.") - - x_min, x_max = transcripts["x_location"].min(), transcripts["x_location"].max() - y_min, y_max = transcripts["y_location"].min(), transcripts["y_location"].max() - - print(f"Transcript bounds: X=({x_min:.2f}, {x_max:.2f}), Y=({y_min:.2f}, {y_max:.2f})") - return x_min, x_max, y_min, y_max - - -def read_microns_per_pixel(metadata_path: str) -> float: - """Extract microns_per_pixel or pixel_size from metadata JSON.""" - print(f"Reading metadata: {metadata_path}") - with open(metadata_path, "r") as f: - metadata = json.load(f) - - mpp = metadata.get("microns_per_pixel") or metadata.get("pixel_size") - if mpp is None: - raise KeyError("Metadata JSON must contain 'microns_per_pixel' or 'pixel_size'.") - - print(f"Microns per pixel: {mpp}") - return float(mpp) - - -def compute_target_size( - x_min: float, x_max: float, y_min: float, y_max: float, microns_per_pixel: float -) -> Tuple[int, int]: - """Compute new image size (in pixels) to cover given coordinates.""" - new_width = int(round((x_max - x_min) / microns_per_pixel)) - new_height = int(round((y_max - y_min) / microns_per_pixel)) - print(f"Target image size: {new_width} x {new_height} pixels") - return new_height, new_width - - -def resize_mask(mask: np.ndarray, new_shape: Tuple[int, int]) -> np.ndarray: - """Resize mask using nearest-neighbor interpolation (preserve labels).""" - print("Resizing mask...") - resized = resize( - mask, - new_shape, - order=0, # nearest neighbor to preserve segmentation labels - preserve_range=True, - anti_aliasing=False, - ).astype(mask.dtype) - print(f"Resized shape: {resized.shape}") - return resized - - -def main(mask_path: str, transcripts_path: str, metadata_path: str, output_path: str) -> None: - """Resize segmentation mask to match Xenium coordinate space.""" - # Validate input files - for path in [mask_path, transcripts_path, metadata_path]: - if not os.path.exists(path): - raise FileNotFoundError(f"File not found: {path}") - - # Load data - mask = read_mask(mask_path) - x_min, x_max, y_min, y_max = read_transcript_bounds(transcripts_path) - microns_per_pixel = read_microns_per_pixel(metadata_path) - - # Compute physical mask size - height, width = mask.shape - print(f"Original mask size: {width * microns_per_pixel:.2f} x {height * microns_per_pixel:.2f} um") - - # Compute target size - new_height, new_width = compute_target_size(x_min, x_max, y_min, y_max, microns_per_pixel) - - # Resize and save - resized_mask = resize_mask(mask, (new_height, new_width)) - tifffile.imwrite(output_path, resized_mask) - - print(f"Saved 
resized mask -> {output_path}") - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Resize a segmentation TIFF mask to match transcript coordinates." - ) - parser.add_argument("--mask", required=True, help="Path to segmentation mask TIFF") - parser.add_argument("--transcripts", required=True, help="Path to transcripts file") - parser.add_argument("--metadata", required=True, help="Path to metadata JSON") - parser.add_argument("--prefix", required=True, help="Output directory prefix") - parser.add_argument("--mask-filename", required=True, help="Original mask filename for output naming") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - os.makedirs(args.prefix, exist_ok=True) - output_mask: str = os.path.join(args.prefix, f"resized_{args.mask_filename}.tif") - - main( - mask_path=args.mask, - transcripts_path=args.transcripts, - metadata_path=args.metadata, - output_path=output_mask, - ) diff --git a/modules/local/utility/segger2xr/resources/usr/bin/segger2xr.py b/modules/local/utility/segger2xr/resources/usr/bin/segger2xr.py deleted file mode 100755 index 22889e82..00000000 --- a/modules/local/utility/segger2xr/resources/usr/bin/segger2xr.py +++ /dev/null @@ -1,247 +0,0 @@ -#!/usr/bin/env python3 -""" -Convert Segger prediction output to XeniumRanger-compatible format. - -Reads Segger PREDICT output (transcripts.parquet with segger_cell_id), -produces Baysor-format segmentation CSV, refined transcripts parquet, -and GeoJSON cell boundary polygons for xeniumranger import-segmentation. -""" - -import argparse -import json -from pathlib import Path -from typing import List - -import pandas as pd -from scipy.spatial import ConvexHull - -# Expected columns in transcripts.parquet -REQUIRED_COLUMNS: List[str] = [ - "transcript_id", - "cell_id", - "overlaps_nucleus", - "feature_name", - "x_location", - "y_location", - "z_location", - "qv", -] - -# Column name for segger cell assignment (varies by segger version) -SEGGER_ID_CANDIDATES: List[str] = ["segger_cell_id", "segger_id"] - - -def refine_transcripts(parquet_path: str) -> pd.DataFrame: - """ - Read segger PREDICT output and extract cell assignments. - Supports both 'segger_cell_id' (newer) and 'segger_id' (older) column names. - """ - parquet_file = Path(parquet_path) - if not parquet_file.exists(): - raise FileNotFoundError(f"File not found: {parquet_path}") - - df = pd.read_parquet(parquet_file, engine="pyarrow") - - missing_cols = [col for col in REQUIRED_COLUMNS if col not in df.columns] - if missing_cols: - raise ValueError(f"Missing required columns: {missing_cols}") - - # Find segger cell assignment column - segger_col = None - for candidate in SEGGER_ID_CANDIDATES: - if candidate in df.columns: - segger_col = candidate - break - if segger_col is None: - raise ValueError( - f"No segger cell assignment column found. " - f"Expected one of {SEGGER_ID_CANDIDATES}, got columns: {list(df.columns)}" - ) - - # Replace cell_id with segger assignment - cell_id_index = df.columns.get_loc("cell_id") - df = df.drop(columns=["cell_id"]) - segger_series = df.pop(segger_col) - df.insert(cell_id_index, "cell_id", segger_series) - - return df - - -def build_cell_map(df: pd.DataFrame, min_transcripts: int = 3) -> dict: - """ - Build a mapping from raw segger cell IDs to non-numeric string IDs. 
- - Only includes cells that have: - - >= min_transcripts assigned transcripts - - At least one transcript with valid (non-NaN) x/y coordinates - - Cell IDs use "cell-N" format (hyphen + integer) as required by - xeniumranger's cell ID parser. Non-numeric to avoid polars Int64 inference. - """ - cell_ids = df["cell_id"].fillna("UNASSIGNED").astype(str) - is_unassigned = (cell_ids == "UNASSIGNED") | (cell_ids == "") | (cell_ids == "0") - assigned = cell_ids[~is_unassigned] - counts = assigned.value_counts() - enough_tx = set(counts[counts >= min_transcripts].index) - - # Exclude cells with all-NaN coordinates (no spatial info = useless) - has_coords = df.dropna(subset=["x_location", "y_location"]) - has_coords_ids = set(has_coords["cell_id"].fillna("UNASSIGNED").astype(str)) - valid_cells = sorted(enough_tx & has_coords_ids) - - return {cell: f"cell-{i + 1}" for i, cell in enumerate(valid_cells)} - - -def to_baysor_csv(df: pd.DataFrame, output_path: str, cell_map: dict) -> None: - """ - Convert transcript DataFrame to Baysor-compatible CSV format. - - xeniumranger 4.0 import-segmentation --transcript-assignment expects a - Baysor segmentation CSV with at minimum: transcript_id, cell, is_noise, - x, y columns. This function maps Xenium/Segger columns to Baysor format. - """ - baysor_df = pd.DataFrame() - baysor_df["transcript_id"] = df["transcript_id"] - baysor_df["x"] = df["x_location"] - baysor_df["y"] = df["y_location"] - baysor_df["z"] = df["z_location"] - baysor_df["gene"] = df["feature_name"] - - cell_ids = df["cell_id"].fillna("UNASSIGNED").astype(str) - is_unassigned = (cell_ids == "UNASSIGNED") | (cell_ids == "") | (cell_ids == "0") - baysor_df["cell"] = cell_ids.map(cell_map).fillna("") - baysor_df["is_noise"] = is_unassigned.astype(int) - - baysor_df.to_csv(output_path, index=False) - - n_assigned = (~is_unassigned).sum() - n_noise = is_unassigned.sum() - n_cells = len(cell_map) - print( - f"Baysor CSV: {n_assigned} assigned, {n_noise} noise, {n_cells} cells -> {output_path}" - ) - - -def _make_buffer_polygon(cx: float, cy: float, radius: float = 0.5) -> list: - """Create a small square polygon around a centroid as fallback.""" - return [ - [cx - radius, cy - radius], - [cx + radius, cy - radius], - [cx + radius, cy + radius], - [cx - radius, cy + radius], - [cx - radius, cy - radius], # close ring - ] - - -def generate_viz_polygons(df: pd.DataFrame, output_path: str, cell_map: dict) -> None: - """ - Generate a GeoJSON file with cell boundary polygons. - - Uses ConvexHull when possible; falls back to a small buffer polygon around - the centroid for cells with < 3 unique points or collinear points. - - Required by xeniumranger import-segmentation when using --transcript-assignment. - Each feature MUST have a top-level "id" field (xeniumranger reads item["id"]). - Cell IDs must match those in the Baysor CSV. 
- """ - assigned = df[ - df["cell_id"].notna() - & (df["cell_id"].astype(str) != "UNASSIGNED") - & (df["cell_id"].astype(str) != "") - ].copy() - - features = [] - grouped = assigned.groupby("cell_id") - - for cell_id, group in grouped: - mapped_id = cell_map.get(str(cell_id)) - if mapped_id is None: - continue - - coords = group[["x_location", "y_location"]].dropna().values - - polygon_coords = None - if len(coords) >= 3: - try: - hull = ConvexHull(coords) - hull_points = coords[hull.vertices].tolist() - hull_points.append(hull_points[0]) # close polygon ring - polygon_coords = hull_points - except Exception: - pass - - # Fallback: buffer polygon around centroid - if polygon_coords is None: - cx, cy = coords.mean(axis=0).astype(float) - polygon_coords = _make_buffer_polygon(cx, cy) - - features.append( - { - "type": "Feature", - "id": mapped_id, - "geometry": { - "type": "Polygon", - "coordinates": [polygon_coords], - }, - "properties": {"cell_id": mapped_id}, - } - ) - - geojson = {"type": "FeatureCollection", "features": features} - - with open(output_path, "w") as f: - json.dump(geojson, f) - - print(f"Generated {len(features)} cell polygons in {output_path}") - - -def main(input_file: str, prefix: str, min_transcripts: int = 3) -> None: - """Run the full segger-to-xeniumranger conversion pipeline.""" - Path(prefix).mkdir(parents=True, exist_ok=True) - transcripts = refine_transcripts(input_file) - - # Build cell ID mapping, filtering cells with < min_transcripts - cell_map = build_cell_map(transcripts, min_transcripts=min_transcripts) - - # xeniumranger 4.0 expects Baysor-format CSV (not parquet) with is_noise column - to_baysor_csv(transcripts, f"{prefix}/segmentation.csv", cell_map) - - # Also save the refined parquet for downstream use - transcripts.to_parquet(f"{prefix}/transcripts.parquet", engine="pyarrow") - - # Generate cell boundary polygons (required companion to --transcript-assignment) - # Uses ConvexHull when possible; falls back to buffer polygon for edge cases - generate_viz_polygons(transcripts, f"{prefix}/segmentation_polygons.json", cell_map) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Convert Segger prediction output to XeniumRanger-compatible format." - ) - parser.add_argument( - "--transcripts", - required=True, - help="Path to Segger output transcripts parquet file", - ) - parser.add_argument( - "--prefix", - required=True, - help="Output directory prefix (sample ID)", - ) - parser.add_argument( - "--min-transcripts", - type=int, - default=3, - help="Minimum transcripts per cell (default: 3)", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - main( - input_file=args.transcripts, - prefix=args.prefix, - min_transcripts=args.min_transcripts, - ) diff --git a/modules/local/utility/split_transcripts/resources/usr/bin/split_transcripts.py b/modules/local/utility/split_transcripts/resources/usr/bin/split_transcripts.py deleted file mode 100755 index 275fbab1..00000000 --- a/modules/local/utility/split_transcripts/resources/usr/bin/split_transcripts.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -Split transcript coordinates into spatial tiles. - -Reads a Xenium transcripts.parquet file and computes quantile-based spatial -tiles, writing a splits.csv with tile boundaries. 
-""" - -import argparse -import os -from typing import List - -import pandas as pd - - -def compute_quantile_ranges(df: pd.DataFrame, col: str, n_bins: int) -> List: - """ - Compute the bin edges for `df[col]` such that each of the n_bins - has ~equal count of points. Returns a list of (min, max) tuples. - """ - _, bins = pd.qcut(df[col], q=n_bins, retbins=True, duplicates="drop") - - ranges = [(bins[i], bins[i + 1]) for i in range(len(bins) - 1)] - - return ranges - - -def make_tiles(df: pd.DataFrame, x_bins: int, y_bins: int) -> pd.DataFrame: - """ - Produce a DataFrame with one row per tile: - tile_id, x_min, x_max, y_min, y_max - """ - x_ranges = compute_quantile_ranges(df, "x_location", x_bins) - y_ranges = compute_quantile_ranges(df, "y_location", y_bins) - - tiles = [] - for ix, (x_min, x_max) in enumerate(x_ranges, start=1): - for iy, (y_min, y_max) in enumerate(y_ranges, start=1): - tiles.append( - { - "tile_id": f"{ix}_{iy}", - "x_min": x_min, - "x_max": x_max, - "y_min": y_min, - "y_max": y_max, - } - ) - - return pd.DataFrame(tiles) - - -def main( - transcripts: str, - x_bins: int = 10, - y_bins: int = 10, - prefix: str = "", -) -> None: - """Generate spatial tile splits from transcript coordinates.""" - # read parquet file - df = pd.read_parquet(transcripts, engine="fastparquet") - - # compute tiles - tiles_df = make_tiles(df, x_bins, y_bins) - - # save csv file - os.makedirs(prefix, exist_ok=True) - tiles_df.to_csv(f"{prefix}/splits.csv", index=False) - - return None - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments.""" - parser = argparse.ArgumentParser( - description="Split transcript coordinates into spatial tiles." - ) - parser.add_argument( - "--transcripts", - required=True, - help="Path to transcripts parquet file", - ) - parser.add_argument( - "--x-bins", - type=int, - required=True, - help="Number of bins along X axis", - ) - parser.add_argument( - "--y-bins", - type=int, - required=True, - help="Number of bins along Y axis", - ) - parser.add_argument( - "--prefix", - required=True, - help="Output directory prefix", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - main( - transcripts=args.transcripts, - x_bins=args.x_bins, - y_bins=args.y_bins, - prefix=args.prefix, - ) diff --git a/modules/local/xenium_patch/stitch/resources/usr/bin/stitch_transcripts.py b/modules/local/xenium_patch/stitch/resources/usr/bin/stitch_transcripts.py deleted file mode 100755 index d9fb8d41..00000000 --- a/modules/local/xenium_patch/stitch/resources/usr/bin/stitch_transcripts.py +++ /dev/null @@ -1,808 +0,0 @@ -#!/usr/bin/env python3 -"""Stitch per-patch Baysor segmentation results into unified output. - -Standalone script that replaces the xenium_patch CLI package's stitch -functionality. Uses sopa's solve_conflicts() for overlap resolution. 
-""" - -from __future__ import annotations - -import argparse -import json -import os -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass -from pathlib import Path - -import geopandas as gpd -import numpy as np -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.csv as pa_csv -import shapely -from shapely.affinity import translate -from shapely.geometry import mapping, shape -from sopa.segmentation.resolve import solve_conflicts - -# --------------------------------------------------------------------------- -# Geometry helpers -# --------------------------------------------------------------------------- - - -def _ensure_polygon(geom) -> "shapely.Polygon | None": - """Extract a single Polygon from any geometry, or return None. - - XeniumRanger only accepts Polygon. make_valid() and solve_conflicts - can produce MultiPolygon, GeometryCollection, MultiLineString, etc. - """ - if geom is None or geom.is_empty: - return None - if geom.geom_type == "Polygon": - return geom - if geom.geom_type == "MultiPolygon": - return max(geom.geoms, key=lambda g: g.area) - if geom.geom_type == "GeometryCollection": - polys = [g for g in geom.geoms if g.geom_type == "Polygon"] - return max(polys, key=lambda g: g.area) if polys else None - # LineString, MultiLineString, Point, etc. — not a polygon - return None - - -# --------------------------------------------------------------------------- -# Inline types (from _types.py) -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True) -class Bounds: - """Axis-aligned bounding box in either pixel or micron coordinates.""" - - x_min: float - x_max: float - y_min: float - y_max: float - - -@dataclass(frozen=True) -class PatchInfo: - """Metadata for a single patch in the grid.""" - - patch_id: str - row: int - col: int - global_bounds_px: Bounds - global_bounds_um: Bounds - core_bounds_px: Bounds - core_bounds_um: Bounds - - -@dataclass -class PatchGridMetadata: - """Full grid metadata, serializable to JSON.""" - - version: str - bundle_path: str - image_height_px: int - image_width_px: int - pixel_size_um: float - transcript_extent_um: Bounds - grid_rows: int - grid_cols: int - overlap_um: float - overlap_px: int - patches: list[PatchInfo] - grid_type: str = "uniform" - - -# --------------------------------------------------------------------------- -# Internal result containers -# --------------------------------------------------------------------------- - - -@dataclass -class _PatchGeoResult: - """Result of parallel GeoJSON processing for a single patch.""" - - features: list[dict] - cell_ids: list[str] - - -@dataclass -class _PatchCsvResult: - """Result of parallel CSV reading for a single patch.""" - - table: pa.Table - has_cell_col: bool - has_x_col: bool - has_y_col: bool - has_gene_col: bool = False - has_feature_name_col: bool = False - - -# --------------------------------------------------------------------------- -# Grid metadata I/O (from grid.py) -# --------------------------------------------------------------------------- - - -def _dict_to_bounds(d: dict) -> Bounds: - return Bounds(d["x_min"], d["x_max"], d["y_min"], d["y_max"]) - - -def load_grid_metadata(input_path: Path) -> PatchGridMetadata: - """Deserialize PatchGridMetadata from JSON. - - Args: - input_path: Path to JSON file to read. - - Returns: - Reconstructed PatchGridMetadata. 
- """ - with open(input_path) as f: - data = json.load(f) - - patches = [ - PatchInfo( - patch_id=p["patch_id"], - row=p["row"], - col=p["col"], - global_bounds_px=_dict_to_bounds(p["global_bounds_px"]), - global_bounds_um=_dict_to_bounds(p["global_bounds_um"]), - core_bounds_px=_dict_to_bounds(p["core_bounds_px"]), - core_bounds_um=_dict_to_bounds(p["core_bounds_um"]), - ) - for p in data["patches"] - ] - - return PatchGridMetadata( - version=data["version"], - bundle_path=data["bundle_path"], - image_height_px=data["image_height_px"], - image_width_px=data["image_width_px"], - pixel_size_um=data["pixel_size_um"], - transcript_extent_um=_dict_to_bounds(data["transcript_extent_um"]), - grid_rows=data["grid_rows"], - grid_cols=data["grid_cols"], - overlap_um=data["overlap_um"], - overlap_px=data["overlap_px"], - grid_type=data.get("grid_type", "uniform"), - patches=patches, - ) - - -# --------------------------------------------------------------------------- -# GeoJSON I/O (from polygon_io.py) -# --------------------------------------------------------------------------- - - -def _normalize_geometry_collection(geojson: dict) -> dict: - """Convert a GeometryCollection to a FeatureCollection. - - proseg-to-baysor produces a non-standard GeoJSON GeometryCollection where - each geometry object has a custom ``cell`` key (bare integer) instead of - using Feature wrappers. This normalises it to a standard FeatureCollection - with ``id`` and ``properties.cell_id`` on each feature, using the - ``"cell-{N}"`` format that matches the companion CSV. - - Args: - geojson: Parsed GeoJSON dict with type GeometryCollection. - - Returns: - Standard FeatureCollection dict. - """ - features = [] - for geom in geojson.get("geometries", []): - cell_raw = geom.get("cell", "") - cell_id = str(cell_raw) - clean_geom = {k: v for k, v in geom.items() if k != "cell"} - feature = { - "type": "Feature", - "id": cell_id, - "geometry": clean_geom, - "properties": {"cell_id": cell_id}, - } - features.append(feature) - return {"type": "FeatureCollection", "features": features} - - -def read_geojson(geojson_path: Path) -> dict: - """Read a GeoJSON file and normalise to FeatureCollection. - - Handles both standard FeatureCollections and the GeometryCollection - format produced by proseg-to-baysor. - - Args: - geojson_path: Path to the GeoJSON file. - - Returns: - Parsed GeoJSON dict (always a FeatureCollection). - """ - with open(geojson_path) as f: - data = json.load(f) - if data.get("type") == "GeometryCollection": - return _normalize_geometry_collection(data) - return data - - -def transform_polygons(geojson: dict, offset_x: float, offset_y: float) -> dict: - """Shift all polygon coordinates by (offset_x, offset_y). - - Args: - geojson: Input FeatureCollection. - offset_x: Translation in x. - offset_y: Translation in y. - - Returns: - New FeatureCollection with shifted geometries. - """ - features = [] - for feat in geojson.get("features", []): - geom = shape(feat["geometry"]) - shifted = translate(geom, xoff=offset_x, yoff=offset_y) - new_feat = {**feat, "geometry": mapping(shifted)} - features.append(new_feat) - return {"type": "FeatureCollection", "features": features} - - -def write_geojson(geojson: dict, output_path: Path) -> None: - """Write a GeoJSON FeatureCollection. - - Args: - geojson: GeoJSON dict to write. - output_path: Destination path (parent dirs created automatically). 
- """ - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: - json.dump(geojson, f) - - -# --------------------------------------------------------------------------- -# Arrow utilities (from _arrow_utils.py) -# --------------------------------------------------------------------------- - - -def float_str_array(f64_array: pa.Array) -> pa.Array: - """Convert a float64 pyarrow array to string using Python's str(float) format. - - pyarrow's built-in cast omits trailing '.0' for whole numbers. This - function ensures output matches str(float(...)) for CSV compatibility. - - Args: - f64_array: Float64 pyarrow array to convert. - - Returns: - String pyarrow array with Python-formatted float values. - """ - return pa.array( - [str(v) if v is not None else None for v in f64_array.to_pylist()], - type=pa.string(), - ) - - -# --------------------------------------------------------------------------- -# Parallel I/O -# --------------------------------------------------------------------------- - - -def _read_and_transform_geojson( - patch: PatchInfo, - patches_dir: Path, - geojson_filename: str, -) -> _PatchGeoResult | None: - """Read, transform GeoJSON for a single patch (no core clipping). - - Args: - patch: Patch metadata. - patches_dir: Root patches directory. - geojson_filename: GeoJSON filename within each patch directory. - - Returns: - _PatchGeoResult with features and cell IDs, or None if no GeoJSON. - """ - geojson_path = patches_dir / patch.patch_id / geojson_filename - if not geojson_path.exists(): - return None - - geojson = read_geojson(geojson_path) - - offset_x = patch.global_bounds_um.x_min - offset_y = patch.global_bounds_um.y_min - geojson = transform_polygons(geojson, offset_x, offset_y) - - features = geojson.get("features", []) - seen: set[str] = set() - cell_ids: list[str] = [] - for feat in features: - old_id = str(feat.get("id", feat.get("properties", {}).get("cell_id", ""))) - if old_id not in seen: - seen.add(old_id) - cell_ids.append(old_id) - - return _PatchGeoResult(features=features, cell_ids=cell_ids) - - -def _read_patch_csv( - patch: PatchInfo, - patches_dir: Path, - csv_filename: str, -) -> _PatchCsvResult | None: - """Read a patch CSV into a pyarrow Table. - - All columns are read as strings to preserve exact formatting. - - Args: - patch: Patch metadata. - patches_dir: Root patches directory. - csv_filename: CSV filename within each patch directory. - - Returns: - _PatchCsvResult with the table and column presence flags, or None. 
- """ - csv_path = patches_dir / patch.patch_id / csv_filename - if not csv_path.exists(): - return None - - with open(csv_path) as fh: - header_line = fh.readline().strip() - col_names = header_line.split(",") - all_string_types = {name: pa.string() for name in col_names} - - table = pa_csv.read_csv( - csv_path, - convert_options=pa_csv.ConvertOptions( - column_types=all_string_types, - strings_can_be_null=False, - ), - read_options=pa_csv.ReadOptions(use_threads=True), - ) - - return _PatchCsvResult( - table=table, - has_cell_col="cell" in table.column_names, - has_x_col="x" in table.column_names, - has_y_col="y" in table.column_names, - has_gene_col="gene" in table.column_names, - has_feature_name_col="feature_name" in table.column_names, - ) - - -# --------------------------------------------------------------------------- -# CSV processing -# --------------------------------------------------------------------------- - - -def _transform_patch_coords( - csv_result: _PatchCsvResult, - offset_x: float, - offset_y: float, -) -> pa.Table: - """Shift transcript coordinates from local patch space to global space. - - Args: - csv_result: The raw CSV table and column flags. - offset_x: X offset for coordinate transform (microns). - offset_y: Y offset for coordinate transform (microns). - - Returns: - Table with x, y columns shifted to global coordinates. - """ - table = csv_result.table - - if table.num_rows == 0: - return table - - if csv_result.has_x_col: - x_f64 = pc.add( - table.column("x").cast(pa.float64()), - pa.scalar(offset_x, type=pa.float64()), - ) - table = table.set_column( - table.schema.get_field_index("x"), - "x", - float_str_array(x_f64), - ) - if csv_result.has_y_col: - y_f64 = pc.add( - table.column("y").cast(pa.float64()), - pa.scalar(offset_y, type=pa.float64()), - ) - table = table.set_column( - table.schema.get_field_index("y"), - "y", - float_str_array(y_f64), - ) - - return table - - -# --------------------------------------------------------------------------- -# Sopa conflict resolution -# --------------------------------------------------------------------------- - - -def _stitch_sopa_resolve( - metadata: PatchGridMetadata, - geo_results: list[_PatchGeoResult | None], - csv_results: list[_PatchCsvResult | None], - all_geojson_features: list[dict], - all_tables: list[pa.Table], - threshold: float = 0.5, -) -> set[str]: - """Stitch per-patch segmentation using spatial containment assignment. - - 1. Collect ALL non-empty polygons from all patches (no transcript filtering). - 2. Resolve overlapping polygons via sopa's solve_conflicts(). - 3. Assign sequential global cell IDs (cell-1, cell-2, ...). - 4. Spatially assign transcripts to resolved polygons using STRtree. - 5. Noise transcripts (outside all polygons) kept only from their core patch. - - This approach works regardless of whether Baysor's CSV ``cell`` column - matches GeoJSON cell IDs -- all assignment is done by spatial containment. - - Args: - metadata: Grid metadata with patch list. - geo_results: Per-patch GeoJSON results (already in global coords). - csv_results: Per-patch CSV results. - all_geojson_features: Output list to append resolved GeoJSON features. - all_tables: Output list to append processed CSV tables. - threshold: Overlap threshold for sopa's solve_conflicts (0-1). - - Returns: - Set of global cell IDs created by merging overlapping cells. 
- """ - # --- Phase 1: Collect all polygons from all patches --- - all_polygons: list = [] - patch_indices_list: list[int] = [] - - for i, patch in enumerate(metadata.patches): - geo_result = geo_results[i] - if geo_result is None: - continue - - for feat in geo_result.features: - polygon = shape(feat["geometry"]) - if polygon.is_empty: - continue - if not polygon.is_valid: - polygon = shapely.make_valid(polygon) - # Ensure we have a single Polygon (xeniumranger rejects all else) - polygon = _ensure_polygon(polygon) - if polygon is None: - continue - - all_polygons.append(polygon) - patch_indices_list.append(i) - - if not all_polygons: - print("[stitch] No polygons found in any patch") - # Still transform and collect CSVs as noise-only - for i, patch in enumerate(metadata.patches): - csv_result = csv_results[i] - if csv_result is None: - continue - offset_x = patch.global_bounds_um.x_min - offset_y = patch.global_bounds_um.y_min - transformed = _transform_patch_coords(csv_result, offset_x, offset_y) - if transformed.num_rows > 0: - all_tables.append(transformed) - return set() - - # --- Phase 2: Resolve overlapping polygons via sopa --- - patch_idx_array = np.array(patch_indices_list, dtype=np.int64) - input_gdf = gpd.GeoDataFrame(geometry=all_polygons) - resolved_gdf, kept_indices = solve_conflicts( - input_gdf, - threshold=threshold, - patch_indices=patch_idx_array, - return_indices=True, - ) - - # --- Phase 3: Assign global cell IDs to resolved polygons --- - merged_cell_ids: set[str] = set() - kept_arr = np.asarray(kept_indices) - resolved_polys: list = [] - resolved_ids: list[str] = [] - - for rank, orig_idx in enumerate(kept_arr, start=1): - global_id = f"cell-{rank}" - geom = resolved_gdf.geometry.iloc[rank - 1] - - # Ensure single Polygon after solve_conflicts union - geom = _ensure_polygon(geom) - if geom is None: - continue - - if orig_idx < 0: - merged_cell_ids.add(global_id) - - resolved_polys.append(geom) - resolved_ids.append(global_id) - - all_geojson_features.append( - { - "type": "Feature", - "id": global_id, - "geometry": mapping(geom), - "properties": {"cell_id": global_id}, - } - ) - - print( - f"[stitch] Resolved {len(all_polygons)} input polygons to " - f"{len(resolved_polys)} cells ({len(merged_cell_ids)} merged)" - ) - - # --- Phase 4: Spatial transcript assignment via STRtree --- - poly_tree = shapely.STRtree(resolved_polys) - - for i, patch in enumerate(metadata.patches): - csv_result = csv_results[i] - if csv_result is None: - continue - - offset_x = patch.global_bounds_um.x_min - offset_y = patch.global_bounds_um.y_min - core = patch.core_bounds_um - - transformed = _transform_patch_coords(csv_result, offset_x, offset_y) - if transformed.num_rows == 0: - continue - - if not csv_result.has_x_col or not csv_result.has_y_col: - all_tables.append(transformed) - continue - - # Get global coordinates for spatial query - gx = transformed.column("x").cast(pa.float64()).to_numpy(zero_copy_only=False) - gy = transformed.column("y").cast(pa.float64()).to_numpy(zero_copy_only=False) - points = shapely.points(gx, gy) - - # Query STRtree: returns (input_indices, tree_indices) - point_hits, poly_hits = poly_tree.query(points, predicate="intersects") - - # Build point -> cell_id mapping (first hit wins) - point_to_cell: dict[int, str] = {} - for pt_idx, poly_idx in zip(point_hits, poly_hits): - if pt_idx not in point_to_cell: - point_to_cell[pt_idx] = resolved_ids[poly_idx] - - # Build cell and is_noise columns - n_rows = transformed.num_rows - cell_arr = [""] * n_rows - 
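-        # Row-aligned with cell_arr: entries are flipped to "false" below for any
-        # transcript whose point falls inside a resolved polygon (first hit wins).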
is_noise_arr = ["true"] * n_rows - for pt_idx, cell_id in point_to_cell.items(): - cell_arr[pt_idx] = cell_id - is_noise_arr[pt_idx] = "false" - - # Filter noise transcripts to core bounds only - # Assigned transcripts are kept from all patches (dedup later by transcript_id) - in_core = ( - (gx >= core.x_min) - & (gx < core.x_max) - & (gy >= core.y_min) - & (gy < core.y_max) - ) - is_assigned = np.array([c != "" for c in cell_arr]) - keep_mask = pa.array(is_assigned | in_core, type=pa.bool_()) - - filtered = transformed.filter(keep_mask) - cell_arr_filtered = [c for c, k in zip(cell_arr, (is_assigned | in_core)) if k] - is_noise_filtered = [ - n for n, k in zip(is_noise_arr, (is_assigned | in_core)) if k - ] - - if filtered.num_rows == 0: - continue - - # Set cell and is_noise columns - cell_idx = ( - filtered.schema.get_field_index("cell") - if "cell" in filtered.column_names - else None - ) - if cell_idx is not None: - filtered = filtered.set_column( - cell_idx, "cell", pa.array(cell_arr_filtered, type=pa.string()) - ) - else: - filtered = filtered.append_column( - "cell", pa.array(cell_arr_filtered, type=pa.string()) - ) - - noise_idx = ( - filtered.schema.get_field_index("is_noise") - if "is_noise" in filtered.column_names - else None - ) - if noise_idx is not None: - filtered = filtered.set_column( - noise_idx, - "is_noise", - pa.array(is_noise_filtered, type=pa.string()), - ) - else: - filtered = filtered.append_column( - "is_noise", pa.array(is_noise_filtered, type=pa.string()) - ) - - all_tables.append(filtered) - - return merged_cell_ids - - -# --------------------------------------------------------------------------- -# Main orchestrator -# --------------------------------------------------------------------------- - - -def stitch_transcript_assignments( - patches_dir: Path, - output_dir: Path, - csv_filename: str = "segmentation.csv", - geojson_filename: str = "segmentation_polygons.json", - max_workers: int | None = None, -) -> None: - """Stitch per-patch transcript assignments and polygons into unified output. - - For each patch, reads the transcript assignment CSV and polygon GeoJSON. - Cells are deduplicated using sopa's solve_conflicts() which resolves - overlapping cells at patch boundaries based on area overlap ratio. - - Processing is split into a parallel I/O phase (reading GeoJSON and CSV - files via thread pool) and a sequential phase (dedup, global cell ID - assignment, remapping, and concatenation). - - Args: - patches_dir: Directory containing patch subdirectories and patch_grid.json. - output_dir: Output directory for stitched CSV and GeoJSON. - csv_filename: CSV filename within each patch directory. - geojson_filename: GeoJSON filename within each patch directory. - max_workers: Maximum number of threads for parallel I/O. 
- """ - patches_dir = Path(patches_dir) - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - metadata = load_grid_metadata(patches_dir / "patch_grid.json") - - n_patches = len(metadata.patches) - if max_workers is None: - max_workers = min(n_patches, os.cpu_count() or 1) - - # ---- Parallel phase: read GeoJSON and CSV files concurrently ---- - with ThreadPoolExecutor(max_workers=max_workers) as executor: - geo_futures = [ - executor.submit( - _read_and_transform_geojson, p, patches_dir, geojson_filename - ) - for p in metadata.patches - ] - csv_futures = [ - executor.submit(_read_patch_csv, p, patches_dir, csv_filename) - for p in metadata.patches - ] - geo_results = [f.result() for f in geo_futures] - csv_results = [f.result() for f in csv_futures] - - # ---- Sequential phase: assign global cell IDs, remap, concatenate ---- - all_tables: list[pa.Table] = [] - all_geojson_features: list[dict] = [] - - _stitch_sopa_resolve( - metadata, - geo_results, - csv_results, - all_geojson_features, - all_tables, - threshold=0.5, - ) - - # Concatenate all patch tables - if all_tables: - merged = pa.concat_tables(all_tables) - - # Deduplicate by transcript_id: prefer assigned over noise - if "transcript_id" in merged.column_names: - if "cell" in merged.column_names: - is_noise = pc.equal(merged.column("cell"), "").cast(pa.int8()) - row_order = pa.array(np.arange(merged.num_rows), type=pa.int64()) - sort_table = pa.table({"_noise": is_noise, "_row": row_order}) - sort_indices = pc.sort_indices( - sort_table, - sort_keys=[("_noise", "ascending"), ("_row", "ascending")], - ) - merged = merged.take(sort_indices) - - tid_np = merged.column("transcript_id").to_numpy(zero_copy_only=False) - _, first_indices = np.unique(tid_np, return_index=True) - first_indices.sort() - merged = merged.take(first_indices) - - # Log assignment stats - if "cell" in merged.column_names: - cell_vals = merged.column("cell").to_pylist() - n_assigned = sum(1 for c in cell_vals if c) - n_noise = sum(1 for c in cell_vals if not c) - print( - f"[stitch] Final: {merged.num_rows} transcripts, " - f"{n_assigned} assigned, {n_noise} noise" - ) - - # Cast is_noise to integer for xeniumranger compatibility - if "is_noise" in merged.column_names: - noise_col = merged.column("is_noise") - if noise_col.type == pa.string(): - lower = pc.utf8_lower(noise_col) - is_true = pc.or_(pc.equal(lower, "true"), pc.equal(lower, "1")) - idx = merged.column_names.index("is_noise") - merged = merged.set_column(idx, "is_noise", is_true.cast(pa.int8())) - - # Write CSV - if merged.num_rows > 0: - csv_out = output_dir / "xr-transcript-metadata.csv" - pa_csv.write_csv( - merged, - csv_out, - write_options=pa_csv.WriteOptions(quoting_style="needed"), - ) - - # Safety net: remove orphan polygons with zero transcripts - if all_geojson_features and all_tables: - csv_cell_ids: set[str] = set() - if "cell" in merged.column_names: - csv_cell_ids = set(c for c in merged.column("cell").to_pylist() if c) - all_geojson_features = [ - f - for f in all_geojson_features - if str(f.get("id", f.get("properties", {}).get("cell_id", ""))) - in csv_cell_ids - ] - - # Write merged GeoJSON - if all_geojson_features: - merged_geo = {"type": "FeatureCollection", "features": all_geojson_features} - write_geojson(merged_geo, output_dir / "xr-cell-polygons.geojson") - - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - - -def main() -> None: - 
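-    """CLI entry point: parse arguments and delegate to stitch_transcript_assignments()."""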
parser = argparse.ArgumentParser( - description="Stitch per-patch Baysor segmentation results into unified output." - ) - parser.add_argument( - "--patches", - type=Path, - required=True, - help="Directory containing patch subdirectories and patch_grid.json", - ) - parser.add_argument( - "--output", - type=Path, - required=True, - help="Output directory for stitched CSV and GeoJSON", - ) - parser.add_argument( - "--csv-filename", - default="segmentation.csv", - help="CSV filename within each patch (default: segmentation.csv)", - ) - parser.add_argument( - "--geojson-filename", - default="segmentation_polygons.json", - help="GeoJSON filename within each patch (default: segmentation_polygons.json)", - ) - args = parser.parse_args() - - stitch_transcript_assignments( - patches_dir=args.patches, - output_dir=args.output, - csv_filename=args.csv_filename, - geojson_filename=args.geojson_filename, - ) - - -if __name__ == "__main__": - main()
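
For context on the removed stitching step: the transcript-to-cell assignment in _stitch_sopa_resolve boils down to a shapely STRtree query with an "intersects" predicate plus a first-hit-wins mapping. The snippet below is a minimal, self-contained sketch of that technique only, not the removed script itself; it assumes shapely >= 2.0 and numpy, and the polygons, cell IDs, and coordinates are invented purely for illustration.

import numpy as np
import shapely
from shapely.geometry import box

polygons = [box(0, 0, 10, 10), box(20, 0, 30, 10)]   # stand-ins for resolved cell boundaries
cell_ids = ["cell-1", "cell-2"]
x = np.array([5.0, 25.0, 50.0])                      # example transcript x coordinates
y = np.array([5.0, 5.0, 5.0])                        # example transcript y coordinates

tree = shapely.STRtree(polygons)
points = shapely.points(x, y)
# query() returns the matching (point_index, polygon_index) pairs
pt_idx, poly_idx = tree.query(points, predicate="intersects")

assigned = [""] * len(x)                             # "" marks a noise transcript
for p, g in zip(pt_idx, poly_idx):
    if not assigned[p]:                              # first hit wins
        assigned[p] = cell_ids[g]
print(assigned)                                      # ['cell-1', 'cell-2', '']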