diff --git a/bin/baysor_create_dataset.py b/bin/baysor_create_dataset.py new file mode 100755 index 00000000..4e5a263a --- /dev/null +++ b/bin/baysor_create_dataset.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +""" +Create a sampled dataset for Baysor preview mode. + +Reads a CSV transcript file and randomly samples a fraction of rows, +writing the result to a new CSV file. +""" + +import argparse +import csv +import os +import random +from pathlib import Path + + +class BaysorPreview(): + """ + Utility class to generate baysor preview dataset + """ + @staticmethod + def generate_dataset( + transcripts: Path, + sampled_transcripts: Path, + sample_fraction: float = 0.3, + random_state: int = 42, + prefix: str = "" + ) -> None: + """ + Reads a csv file & randomly samples a fraction of rows, + and writes the result to a .csv file. + + Args: + transcripts: unziped transcripts.csv from xenium bundle + sampled_transcripts: randomly subsampled transcripts.csv file + sample_fraction: Fraction of rows to sample + random_state: Seed for reproducibility + prefix: Output directory prefix + """ + + random.seed(random_state) + output_path = f"{prefix}/{sampled_transcripts}" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(transcripts, mode='rt', newline='') as infile, \ + open(output_path, mode='wt', newline='') as outfile: + + reader = csv.reader(infile) + writer = csv.writer(outfile) + + # get the header line + header = next(reader) + writer.writerow(header) + + # randomize csv rows to write + for row in reader: + if random.random() < float(sample_fraction): + writer.writerow(row) + + return None + + +def main() -> None: + """ + Run create dataset as nf module + """ + parser = argparse.ArgumentParser( + description="Create sampled dataset for Baysor preview" + ) + parser.add_argument( + "--transcripts", required=True, + help="Path to transcripts CSV file" + ) + parser.add_argument( + "--sample-fraction", required=True, type=float, + help="Fraction of rows to sample" + ) + parser.add_argument( + "--prefix", required=True, + help="Output directory prefix" + ) + args = parser.parse_args() + + sampled_transcripts = "sampled_transcripts.csv" + + # generate dataset + BaysorPreview.generate_dataset( + transcripts=args.transcripts, + sampled_transcripts=sampled_transcripts, + sample_fraction=args.sample_fraction, + prefix=args.prefix + ) + + return None + + +if __name__ == "__main__": + main() diff --git a/bin/baysor_preprocess_transcripts.py b/bin/baysor_preprocess_transcripts.py new file mode 100755 index 00000000..2662f83c --- /dev/null +++ b/bin/baysor_preprocess_transcripts.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +""" +Preprocess Xenium transcripts for Baysor segmentation. + +Filters transcripts based on quality score and spatial coordinate thresholds, +removes negative control probes, and outputs filtered CSV for Baysor compatibility. +""" + +import argparse +import os + +import pandas as pd + + +def filter_transcripts( + transcripts: str, + min_qv: float = 20.0, + min_x: float = 0.0, + max_x: float = 24000.0, + min_y: float = 0.0, + max_y: float = 24000.0, + prefix: str = "", +) -> None: + """ + Filter transcripts based on the specified thresholds. 
+ + Args: + transcripts: Path to transcripts parquet file + min_qv: Minimum Q-Score to pass filtering + min_x: Minimum x-coordinate threshold + max_x: Maximum x-coordinate threshold + min_y: Minimum y-coordinate threshold + max_y: Maximum y-coordinate threshold + prefix: Output directory prefix + """ + df = pd.read_parquet(transcripts, engine="pyarrow") + + # filter transcripts df with thresholds, ignore negative controls + filtered_df = df[ + (df["qv"] >= min_qv) + & (df["x_location"] >= min_x) + & (df["x_location"] <= max_x) + & (df["y_location"] >= min_y) + & (df["y_location"] <= max_y) + & (~df["feature_name"].str.startswith("NegControlProbe_")) + & (~df["feature_name"].str.startswith("antisense_")) + & (~df["feature_name"].str.startswith("NegControlCodeword_")) + & (~df["feature_name"].str.startswith("BLANK_")) + ] + + # change cell_id of cell-free transcripts to "0" (Baysor's no-cell sentinel). + # Modern Xenium stores cell_id as a string ("UNASSIGNED" for cell-free transcripts); + # legacy Xenium used integer -1. Normalize to string and handle both cases — pandas 3 + # rejects mixing int values into a string-dtype column. + filtered_df["cell_id"] = filtered_df["cell_id"].astype(str) + neg_cell_row = filtered_df["cell_id"].isin(["-1", "UNASSIGNED"]) + filtered_df.loc[neg_cell_row, "cell_id"] = "0" + + # Output filtered transcripts as CSV for Baysor 0.7.1 compatibility. + # Baysor's Julia Parquet.jl cannot read modern pyarrow Parquet files + # (pyarrow 15+ writes size_statistics Thrift field 16 unconditionally, + # which Baysor's old Thrift deserializer doesn't recognize). + os.makedirs(prefix, exist_ok=True) + filtered_df.to_csv(f"{prefix}/filtered_transcripts.csv", index=False) + + return None + + +def main() -> None: + """ + Run preprocess transcripts as nf module. 
+ """ + parser = argparse.ArgumentParser( + description="Preprocess Xenium transcripts for Baysor" + ) + parser.add_argument( + "--transcripts", required=True, help="Path to transcripts parquet file" + ) + parser.add_argument("--prefix", required=True, help="Output directory prefix") + parser.add_argument( + "--min-qv", + type=float, + default=20.0, + help="Minimum Q-Score threshold (default: 20.0)", + ) + parser.add_argument( + "--min-x", + type=float, + default=0.0, + help="Minimum x-coordinate threshold (default: 0.0)", + ) + parser.add_argument( + "--max-x", + type=float, + default=24000.0, + help="Maximum x-coordinate threshold (default: 24000.0)", + ) + parser.add_argument( + "--min-y", + type=float, + default=0.0, + help="Minimum y-coordinate threshold (default: 0.0)", + ) + parser.add_argument( + "--max-y", + type=float, + default=24000.0, + help="Maximum y-coordinate threshold (default: 24000.0)", + ) + args = parser.parse_args() + + filter_transcripts( + transcripts=args.transcripts, + min_qv=args.min_qv, + min_x=args.min_x, + max_x=args.max_x, + min_y=args.min_y, + max_y=args.max_y, + prefix=args.prefix, + ) + + return None + + +if __name__ == "__main__": + main() diff --git a/bin/ficture_preprocess.py b/bin/ficture_preprocess.py new file mode 100755 index 00000000..2e0c687c --- /dev/null +++ b/bin/ficture_preprocess.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Preprocess Xenium transcripts for FICTURE analysis.""" + +import argparse +import gzip +import logging +import os +import re +import sys + +import pandas as pd + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Preprocess Xenium transcripts for FICTURE" + ) + parser.add_argument( + "--transcripts", required=True, help="Path to transcripts file (CSV)" + ) + parser.add_argument( + "--features", default="", help="Path to features file (optional)" + ) + parser.add_argument( + "--negative-control-regex", default="", help="Regex for negative control probes" + ) + return parser.parse_args() + + +def main(): + """Run FICTURE preprocessing.""" + args = parse_args() + print("[START]") + + negctrl_regex = "BLANK|NegCon" + if args.negative_control_regex: + negctrl_regex = args.negative_control_regex + + unit_info = ["X", "Y", "gene", "cell_id", "overlaps_nucleus"] + oheader = unit_info + ["Count"] + + feature = pd.DataFrame() + xmin = sys.maxsize + xmax = 0 + ymin = sys.maxsize + ymax = 0 + + output = "processed_transcripts.tsv.gz" + feature_file = "feature.clean.tsv.gz" + min_phred_score = 15 + + with gzip.open(output, "wt") as wf: + wf.write("\t".join(oheader) + "\n") + + for chunk in pd.read_csv(args.transcripts, header=0, chunksize=500000): + chunk = chunk.loc[(chunk.qv > min_phred_score)] + chunk.rename(columns={"feature_name": "gene"}, inplace=True) + if negctrl_regex != "": + chunk = chunk[ + ~chunk.gene.str.contains(negctrl_regex, flags=re.IGNORECASE, regex=True) + ] + chunk.rename(columns={"x_location": "X", "y_location": "Y"}, inplace=True) + chunk["Count"] = 1 + chunk[oheader].to_csv( + output, sep="\t", mode="a", index=False, header=False, float_format="%.2f" + ) + logging.info(f"{chunk.shape[0]}") + feature = pd.concat( + [feature, chunk.groupby(by="gene").agg({"Count": "sum"}).reset_index()] + ) + x0 = chunk.X.min() + x1 = chunk.X.max() + y0 = chunk.Y.min() + y1 = chunk.Y.max() + xmin = min(int(xmin), int(x0)) + xmax = max(int(xmax), int(x1)) + ymin = min(int(ymin), int(y0)) + ymax = max(int(ymax), int(y1)) + + if os.path.exists(args.features): + 
feature_list = [] + with open(args.features, "r") as ff: + for line in ff: + feature_list.append(line.strip("\n")) + feature = feature.groupby(by="gene").agg({"Count": "sum"}).reset_index() + feature = feature[[x in feature_list for x in feature["gene"]]] + feature.to_csv(feature_file, sep="\t", index=False) + + f = os.path.join(os.path.dirname(output), "coordinate_minmax.tsv") + with open(f, "w") as wf: + wf.write(f"xmin\t{xmin}\n") + wf.write(f"xmax\t{xmax}\n") + wf.write(f"ymin\t{ymin}\n") + wf.write(f"ymax\t{ymax}\n") + + print("[FINISH]") + + +if __name__ == "__main__": + main() diff --git a/bin/segger_create_dataset.py b/bin/segger_create_dataset.py new file mode 100755 index 00000000..c73ab006 --- /dev/null +++ b/bin/segger_create_dataset.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +""" +Run segger create_dataset with spatialxe-specific preprocessing and workarounds. + +Wraps segger's create_dataset_fast.py with: + - bundle_local symlink prep (handles read-only S3/Fusion mounts) + - parquet column statistics (segger needs these) + - WORKAROUND: filter trainable tiles from test_tiles when segger commit 0787167 mis-splits + - WORKAROUND: replace NaN bd.x with zeros after get_polygon_props produces NaN + +Each WORKAROUND should be removable when the upstream segger bug is fixed. +""" + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + +# imports for actual work (used in functions below) +import pyarrow.parquet as pq +import pyarrow.compute as pc +import torch + + +SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/create_dataset_fast.py" + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--bundle-dir", required=True) + p.add_argument("--output-dir", required=True) + p.add_argument("--sample-type", required=True, choices=["xenium"]) + p.add_argument("--tile-width", type=int, required=True) + p.add_argument("--tile-height", type=int, required=True) + p.add_argument("--n-workers", type=int, required=True) + # remaining args forwarded to segger CLI + args, extra = p.parse_known_args() + return args, extra + + +def prepare_bundle(bundle_dir): + """Create local bundle dir with absolute symlinks (S3/Fusion read-only-safe).""" + Path("bundle_local").mkdir(exist_ok=True) + for item in Path(bundle_dir).iterdir(): + try: + abs_path = item.resolve() + except Exception: + abs_path = item + target = Path("bundle_local") / item.name + if target.exists() or target.is_symlink(): + target.unlink() + target.symlink_to(abs_path) + + # Segger expects nucleus_boundaries.parquet but Xenium bundles have cell_boundaries.parquet + nb = Path("bundle_local/nucleus_boundaries.parquet") + cb = Path("bundle_local/cell_boundaries.parquet") + if not nb.exists() and cb.exists(): + print( + "Creating nucleus_boundaries.parquet symlink from cell_boundaries.parquet" + ) + nb.symlink_to(cb.resolve()) + + print("Bundle contents:") + for item in sorted(Path("bundle_local").iterdir()): + print(f" {item.name}") + + +def add_parquet_stats(): + """Rewrite key parquet files with column statistics (segger requires them).""" + Path("bundle_stats").mkdir(exist_ok=True) + for fname in ["transcripts.parquet", "nucleus_boundaries.parquet"]: + src = Path("bundle_local") / fname + dst = Path("bundle_stats") / fname + if not src.exists(): + print(f" Skip {src}") + continue + t = pq.read_table(str(src)) + pq.write_table(t, str(dst), write_statistics=True, compression="snappy") + print(f" Done {fname} ({len(t)} rows)") + + # Symlink everything else from bundle_local 
into bundle_stats + for item in Path("bundle_local").iterdir(): + dst = Path("bundle_stats") / item.name + if not dst.exists(): + dst.symlink_to(item.resolve()) + + # Debug: check overlaps_nucleus column in transcripts + print("\n=== Debugging overlaps_nucleus data ===") + tx = pq.read_table("bundle_stats/transcripts.parquet") + bd = pq.read_table("bundle_stats/nucleus_boundaries.parquet") + if "overlaps_nucleus" in tx.column_names: + col = tx.column("overlaps_nucleus") + print(f"overlaps_nucleus dtype: {col.type}") + unique_vals = pc.unique(col) + print(f"overlaps_nucleus unique values: {unique_vals.to_pylist()[:10]}") + val_counts = pc.value_counts(col) + print(f"overlaps_nucleus value_counts: {val_counts.to_pylist()}") + else: + print("WARNING: overlaps_nucleus column NOT FOUND in transcripts.parquet") + + if "cell_id" in tx.column_names and "cell_id" in bd.column_names: + tx_cells = set(pc.unique(tx.column("cell_id")).to_pylist()) + bd_cells = set(pc.unique(bd.column("cell_id")).to_pylist()) + overlap = tx_cells & bd_cells + print(f"Transcripts unique cell_ids: {len(tx_cells)}") + print(f"Boundaries unique cell_ids: {len(bd_cells)}") + print(f"Overlapping cell_ids: {len(overlap)}") + print("=== End Debug ===\n") + + +def run_segger_cli(args, extra): + cmd = [ + "python3", + SEGGER_CLI, + "--base_dir", + "bundle_stats", + "--data_dir", + args.output_dir, + "--sample_type", + args.sample_type, + "--tile_width", + str(args.tile_width), + "--tile_height", + str(args.tile_height), + "--n_workers", + str(args.n_workers), + *extra, + ] + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd) + if result.returncode != 0: + sys.exit(result.returncode) + + +def filter_trainable_tiles_if_needed(prefix): + """ + WORKAROUND: segger commit 0787167 has a bug where all tiles end up in test_tiles + regardless of test_prob/val_prob settings. Move ONLY trainable tiles (those with + edge_label_index) from test_tiles to train_tiles. + + Remove this function once segger >= 0.1.x is bumped with the upstream fix. 
+ """ + train_dir = Path(prefix) / "train_tiles" / "processed" + test_dir = Path(prefix) / "test_tiles" / "processed" + val_dir = Path(prefix) / "val_tiles" / "processed" + + train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 + test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 + val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 + print( + f"Dataset split (before fix): train={train_count} val={val_count} test={test_count}" + ) + + if train_count == 0 and test_count > 0: + print( + "Applying workaround: filtering trainable tiles from test_tiles (segger split bug)" + ) + moved = 0 + skipped = 0 + for tile_path in list(test_dir.iterdir()): + if not tile_path.name.endswith(".pt"): + continue + try: + tile = torch.load(str(tile_path), weights_only=False) + edge_store = tile["tx", "belongs", "bd"] + if ( + hasattr(edge_store, "edge_label_index") + and edge_store.edge_label_index.numel() > 0 + ): + shutil.move(str(tile_path), str(train_dir / tile_path.name)) + moved += 1 + else: + skipped += 1 + except Exception as e: + print(f"Warning: Could not process {tile_path.name}: {e}") + skipped += 1 + print(f"Moved {moved} trainable tiles to train_tiles") + print(f"Skipped {skipped} test-only tiles (no edge_label_index)") + + train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 + test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 + val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 + print( + f"Dataset split (after fix): train={train_count} val={val_count} test={test_count}" + ) + + if train_count == 0: + print(f"ERROR: No trainable tiles were created in {train_dir}", file=sys.stderr) + print( + "This usually means no transcripts overlap with nucleus boundaries in the dataset.", + file=sys.stderr, + ) + print( + "Check if the Xenium bundle contains valid overlaps_nucleus data in transcripts.parquet.", + file=sys.stderr, + ) + sys.exit(1) + print(f"Successfully created {train_count} trainable tiles") + + +def fix_bd_x_nan(prefix): + """ + WORKAROUND: segger's get_polygon_props() produces NaN boundary features (bd.x) + when polygon geometries have zero area or index misalignment during GeoDataFrame + construction. Replace NaN bd.x with zeros so BCEWithLogitsLoss doesn't propagate NaN. + + Remove this function once segger >= 0.1.x is bumped with the upstream fix. 
+ """ + fixed = 0 + total = 0 + for split in ["train_tiles", "test_tiles", "val_tiles"]: + tile_dir = Path(prefix) / split / "processed" + if not tile_dir.is_dir(): + continue + for tile_path in tile_dir.iterdir(): + if not tile_path.name.endswith(".pt"): + continue + total += 1 + tile = torch.load(str(tile_path), weights_only=False) + bd_x = tile["bd"].x + if bd_x.isnan().any(): + tile["bd"].x = torch.nan_to_num(bd_x, nan=0.0) + torch.save(tile, str(tile_path)) + fixed += 1 + print(f"Fixed NaN bd.x in {fixed}/{total} tiles") + + +def main(): + args, extra = parse_args() + + # Ensure numba cache dir is writable (env var should be set by caller, but belt-and-suspenders) + os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) + os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) + + prepare_bundle(args.bundle_dir) + print("Adding statistics to parquet files...") + add_parquet_stats() + + # Sanity-check bundle_stats + print("bundle_stats contents:") + for item in sorted(Path("bundle_stats").iterdir()): + print(f" {item.name}") + + run_segger_cli(args, extra) + + filter_trainable_tiles_if_needed(args.output_dir) + fix_bd_x_nan(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/bin/segger_predict.py b/bin/segger_predict.py new file mode 100755 index 00000000..56a77ffc --- /dev/null +++ b/bin/segger_predict.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +""" +Run segger predict with spatialxe-specific preprocessing. + +Wraps segger's predict_fast.py with: + - GPU enumeration (replaces inline python3 -c torch check) + - WORKAROUND: patch predict_parquet.py at runtime to add torch.no_grad() for ~30-50% VRAM savings + - WORKAROUND: seed random.choice for deterministic GPU assignment (avoids stochastic OOM) + +Both WORKAROUNDs should be removable once the patches are upstreamed to segger. +""" + +import argparse +import os +import subprocess +import sys + + +SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/predict_fast.py" + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--models-dir", required=True) + p.add_argument("--segger-data-dir", required=True) + p.add_argument("--transcripts-file", required=True) + p.add_argument("--benchmarks-dir", required=True) + p.add_argument("--batch-size", type=int, required=True) + p.add_argument("--use-cc", required=True) + p.add_argument("--knn-method", required=True) + p.add_argument("--num-workers", type=int, required=True) + args, extra = p.parse_known_args() + return args, extra + + +def detect_gpus(): + """Return comma-separated list of available CUDA device ids (or "0" if none).""" + import torch + + print("=== GPU Detection (SEGGER_PREDICT) ===") + print(f"PyTorch CUDA available: {torch.cuda.is_available()}") + n = torch.cuda.device_count() + print(f"CUDA device count: {n}") + print("======================================") + if n > 0: + return ",".join(str(i) for i in range(n)) + return "0" + + +def patch_predict_parquet(): + """ + WORKAROUND: patch segger.prediction.predict_parquet at runtime. + + Avoids rebuilding the segger Docker image. Two patches: + 1. Add torch.no_grad() to disable gradient graphs during inference (~30-50% VRAM savings). + 2. Seed random for deterministic GPU assignment (avoids stochastic OOM). + + Remove this function once the patches are upstreamed to segger. 
+ """ + import segger.prediction.predict_parquet as m + + pred_py = m.__file__ + print(f"Patching {pred_py}: torch.no_grad() + round-robin GPU assignment") + # Use sed via subprocess for in-place edit (matches the original behavior exactly) + subprocess.run( + [ + "sed", + "-i", + "s/with cp.cuda.Device(gpu_id):/with cp.cuda.Device(gpu_id), torch.no_grad():/", + pred_py, + ], + check=True, + ) + subprocess.run( + [ + "sed", + "-i", + "s/gpu_id = random.choice(gpu_ids)/random.seed(0); gpu_id = random.choice(gpu_ids)/", + pred_py, + ], + check=True, + ) + + +def run_segger_cli(args, extra, gpu_ids): + cmd = [ + "python3", + SEGGER_CLI, + "--models_dir", + args.models_dir, + "--segger_data_dir", + args.segger_data_dir, + "--transcripts_file", + args.transcripts_file, + "--benchmarks_dir", + args.benchmarks_dir, + "--batch_size", + str(args.batch_size), + "--use_cc", + str(args.use_cc), + "--knn_method", + args.knn_method, + "--num_workers", + str(args.num_workers), + "--gpu_ids", + gpu_ids, + *extra, + ] + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd) + if result.returncode != 0: + sys.exit(result.returncode) + + +def main(): + args, extra = parse_args() + + # Limit cupy GPU memory to 80% so PyTorch has headroom for graph attention ops + os.environ.setdefault("CUPY_GPU_MEMORY_LIMIT", "80%") + # Belt-and-suspenders: ensure PyTorch uses expandable segments + os.environ.setdefault( + "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512" + ) + # Numba cache directory + os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) + os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) + + gpu_ids = detect_gpus() + print(f"Using GPUs: {gpu_ids}") + + patch_predict_parquet() + + run_segger_cli(args, extra, gpu_ids) + + +if __name__ == "__main__": + main() diff --git a/bin/spatialdata_merge.py b/bin/spatialdata_merge.py new file mode 100755 index 00000000..409d8c00 --- /dev/null +++ b/bin/spatialdata_merge.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +"""Merge two spatialdata bundles to create a layered spatialdata object.""" + +import argparse +import json +import os +import shutil + +import spatialdata + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Merge two spatialdata bundles") + parser.add_argument("--raw-bundle", required=True, help="Path to raw spatialdata bundle") + parser.add_argument("--redefined-bundle", required=True, help="Path to redefined spatialdata bundle") + parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") + parser.add_argument("--output-folder", required=True, help="Output folder name") + return parser.parse_args() + + +def main(): + """Run spatialdata merge.""" + args = parse_args() + print("[START]") + + output_dir = f"spatialdata/{args.prefix}/{args.output_folder}" + + # Ensure the output folder exists + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) + + # Copy the entire reference bundle as is + for root, _, files in os.walk(args.raw_bundle): + rel_path = os.path.relpath(root, args.raw_bundle) + target_path = os.path.join(output_dir, rel_path) + os.makedirs(target_path, exist_ok=True) + for file in files: + shutil.copy(os.path.join(root, file), os.path.join(target_path, file)) + + # Rename folders in Points, Shapes, and Tables to raw_* + for category in ["points", "shapes", "tables"]: + category_path = os.path.join(output_dir, category) + if os.path.exists(category_path): + for 
folder in next(os.walk(category_path))[1]: + old_path = os.path.join(category_path, folder) + print(folder) + new_path = os.path.join(category_path, f"raw_{folder}") + os.rename(old_path, new_path) + + # Copy folders from redefined_bundle and rename them as redefined_* + for category in ["points", "shapes", "tables"]: + add_category_path = os.path.join(args.redefined_bundle, category) + output_category_path = os.path.join(output_dir, category) + os.makedirs(output_category_path, exist_ok=True) + + if os.path.exists(add_category_path): + for folder in next(os.walk(add_category_path))[1]: + src_folder = os.path.join(add_category_path, folder) + dest_folder = os.path.join(output_category_path, f"redefined_{folder}") + shutil.copytree(src_folder, dest_folder) + + # Invalidate consolidated metadata in zarr.json -- the directory renames above + # made the element paths in the metadata stale (e.g., 'points/transcripts' -> + # 'points/raw_transcripts'). Without consolidated metadata, sd.read_zarr() + # discovers elements by scanning the filesystem directly. + zarr_json = os.path.join(output_dir, "zarr.json") + if os.path.exists(zarr_json): + with open(zarr_json) as f: + meta = json.load(f) + if "consolidated_metadata" in meta: + del meta["consolidated_metadata"] + with open(zarr_json, "w") as f: + json.dump(meta, f) + print("[NOTE] Removed stale consolidated metadata from zarr.json") + + print("[FINISH]") + + +if __name__ == "__main__": + main() diff --git a/bin/spatialdata_meta.py b/bin/spatialdata_meta.py new file mode 100755 index 00000000..935f39b2 --- /dev/null +++ b/bin/spatialdata_meta.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +"""Add metadata to SpatialData bundle.""" + +import argparse +import json +import sys + +import pandas as pd +import spatialdata as sd +import zarr + +# Fix zarr v3 + anndata + numcodecs incompatibility: +# anndata's string writer passes numcodecs.VLenUTF8 to zarr.Group.create_array, +# but zarr v3 only accepts ArrayArrayCodec types. OME-Zarr 0.5 requires zarr v3 +# for images, so we can't downgrade the store format. Instead, we intercept +# create_array to strip numcodecs codecs and let zarr v3 handle strings natively. 
+import numcodecs +import zarr.core.group as _zarr_group + +_orig_create_array = _zarr_group.Group.create_array + + +def _v3_compat_create_array(self, *args, **kwargs): + """Strip numcodecs VLenUTF8 from codec params for zarr v3 compatibility.""" + for param in ("filters", "compressor", "object_codec"): + val = kwargs.get(param) + if val is None: + continue + if isinstance(val, numcodecs.vlen.VLenUTF8): + del kwargs[param] + elif isinstance(val, (list, tuple)): + cleaned = [v for v in val if not isinstance(v, numcodecs.vlen.VLenUTF8)] + if len(cleaned) != len(val): + if cleaned: + kwargs[param] = cleaned + else: + del kwargs[param] + return _orig_create_array(self, *args, **kwargs) + + +_zarr_group.Group.create_array = _v3_compat_create_array + + +def _is_arrow_backed(dtype): + """Check if a pandas dtype is backed by PyArrow.""" + return isinstance(dtype, pd.ArrowDtype) or ( + hasattr(dtype, "storage") and getattr(dtype, "storage", None) == "pyarrow" + ) or "pyarrow" in str(dtype) + + +def _convert_df_arrow_to_numpy(df): + """Convert Arrow-backed dtypes in a DataFrame to numpy object dtype.""" + for col in df.columns: + dtype = df[col].dtype + if _is_arrow_backed(dtype): + df[col] = df[col].astype("object") + elif isinstance(dtype, pd.CategoricalDtype): + cats = dtype.categories + if cats is not None and _is_arrow_backed(cats.dtype): + df[col] = df[col].cat.rename_categories(cats.astype("object")) + if _is_arrow_backed(df.index.dtype): + df.index = pd.Index(df.index.astype("object")) + + +def convert_arrow_to_numpy(sdata): + """Convert Arrow-backed dtypes to numpy for anndata zarr write compatibility.""" + for table_key in list(sdata.tables.keys()): + adata = sdata.tables[table_key] + _convert_df_arrow_to_numpy(adata.obs) + _convert_df_arrow_to_numpy(adata.var) + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Add metadata to SpatialData bundle") + parser.add_argument("--spatialdata-bundle", required=True, help="Path to spatialdata bundle") + parser.add_argument("--xenium-bundle", required=True, help="Path to xenium bundle") + parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") + parser.add_argument("--metadata", required=True, help="Metadata string from Nextflow meta map") + parser.add_argument("--output-folder", required=True, help="Output folder name") + return parser.parse_args() + + +def main(): + """Run spatialdata metadata addition.""" + args = parse_args() + print("[START]") + + sdata = sd.read_zarr(args.spatialdata_bundle) + + # Convert metadata into dict + print("[NOTE] Read in provenance ...") + metadata = args.metadata.strip("[]") # Remove square brackets + pairs = metadata.split(", ") # Split by comma and space + metadata = {k: v for k, v in (pair.split(":") for pair in pairs)} # Create dictionary + + for key in metadata: + if key not in sdata['raw_table'].uns['spatialdata_attrs']: + sdata['raw_table'].uns['spatialdata_attrs'][key] = metadata[key] + else: + print(f'[ERROR] {key} already exist in sdata[raw_table].uns[spatialdata_attrs].', file=sys.stderr) + + # Add experimental metadata + print("[NOTE] Read in experiment metadata ...") + sdata['raw_table'].uns['experiment_xenium'] = '' + metadata_experiment = f'{args.xenium_bundle}/experiment.xenium' + with open(metadata_experiment, "r") as f: + metadata_experiment = json.load(f) + sdata['raw_table'].uns['experiment_xenium'] = json.dumps(metadata_experiment) + + # Add gene panel metadata + print("[NOTE] Read in gene panel metadata ...") + 
sdata['raw_table'].uns['gene_panel'] = '' + metadata_gene_panel = f'{args.xenium_bundle}/gene_panel.json' + with open(metadata_gene_panel, "r") as f: + metadata_gene_panel = json.load(f) + sdata['raw_table'].uns['gene_panel'] = json.dumps(metadata_gene_panel) + + convert_arrow_to_numpy(sdata) + sdata.write(f"spatialdata/{args.prefix}/{args.output_folder}", overwrite=True, consolidate_metadata=True, sdata_formats=None) + + print("[FINISH]") + + +if __name__ == "__main__": + main() diff --git a/bin/spatialdata_write.py b/bin/spatialdata_write.py new file mode 100755 index 00000000..421e830f --- /dev/null +++ b/bin/spatialdata_write.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Write spatialdata object from segmentation format.""" + +import argparse +import sys + +import pandas as pd +import spatialdata +from spatialdata_io import xenium + +# Fix zarr v3 + anndata + numcodecs incompatibility: +# anndata's string writer passes numcodecs.VLenUTF8 to zarr.Group.create_array, +# but zarr v3 only accepts ArrayArrayCodec types. OME-Zarr 0.5 requires zarr v3 +# for images, so we can't downgrade the store format. Instead, we intercept +# create_array to strip numcodecs codecs and let zarr v3 handle strings natively. +import numcodecs +import zarr.core.group as _zarr_group + +_orig_create_array = _zarr_group.Group.create_array + + +def _v3_compat_create_array(self, *args, **kwargs): + """Strip numcodecs VLenUTF8 from codec params for zarr v3 compatibility.""" + for param in ("filters", "compressor", "object_codec"): + val = kwargs.get(param) + if val is None: + continue + if isinstance(val, numcodecs.vlen.VLenUTF8): + del kwargs[param] + elif isinstance(val, (list, tuple)): + cleaned = [v for v in val if not isinstance(v, numcodecs.vlen.VLenUTF8)] + if len(cleaned) != len(val): + if cleaned: + kwargs[param] = cleaned + else: + del kwargs[param] + return _orig_create_array(self, *args, **kwargs) + + +_zarr_group.Group.create_array = _v3_compat_create_array + + +def _is_arrow_backed(dtype): + """Check if a pandas dtype is backed by PyArrow.""" + return ( + isinstance(dtype, pd.ArrowDtype) + or (hasattr(dtype, "storage") and getattr(dtype, "storage", None) == "pyarrow") + or "pyarrow" in str(dtype) + ) + + +def _convert_df_arrow_to_numpy(df): + """Convert Arrow-backed dtypes in a DataFrame to numpy object dtype. + + Handles three cases: + 1. Regular columns with Arrow-backed dtypes + 2. Categorical columns whose categories are Arrow-backed + 3. 
Index with Arrow-backed dtype + """ + for col in df.columns: + dtype = df[col].dtype + if _is_arrow_backed(dtype): + df[col] = df[col].astype("object") + elif isinstance(dtype, pd.CategoricalDtype): + cats = dtype.categories + if cats is not None and _is_arrow_backed(cats.dtype): + df[col] = df[col].cat.rename_categories(cats.astype("object")) + if _is_arrow_backed(df.index.dtype): + df.index = pd.Index(df.index.astype("object")) + + +def convert_arrow_to_numpy(sdata): + """Convert Arrow-backed dtypes to numpy for anndata zarr write compatibility.""" + for table_key in list(sdata.tables.keys()): + adata = sdata.tables[table_key] + _convert_df_arrow_to_numpy(adata.obs) + _convert_df_arrow_to_numpy(adata.var) + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Write spatialdata object from segmentation format") + parser.add_argument("--bundle", required=True, help="Path to input bundle") + parser.add_argument("--prefix", required=True, help="Output prefix (sample ID)") + parser.add_argument("--output-folder", required=True, help="Output folder name") + parser.add_argument("--segmented-object", required=True, help="Segmented object type (cells, nuclei, cells_and_nuclei)") + parser.add_argument("--coordinate-space", required=True, help="Coordinate space (pixels, microns)") + parser.add_argument("--format", required=True, help="Input format (xenium)") + return parser.parse_args() + + +def main(): + """Run spatialdata write.""" + args = parse_args() + print("[START]") + + cells_as_circles = False + cells_boundaries = False + nucleus_boundaries = False + cells_labels = False + nucleus_labels = False + + if args.segmented_object == "cells": + cells_boundaries = True + cells_labels = True + elif args.segmented_object == "nuclei": + nucleus_boundaries = True + nucleus_labels = True + elif args.segmented_object == "cells_and_nuclei": + cells_boundaries = True + nucleus_boundaries = True + cells_labels = True + nucleus_labels = True + else: + cells_as_circles = False + + # set sd variables based on the coordinate space + if args.coordinate_space == "pixels": + cells_labels = True + nucleus_labels = True + # Labels are sufficient in pixel space; boundaries can contain + # degenerate polygons (< 4 vertices) from XeniumRanger that + # crash spatialdata_io's shapely LinearRing parser. 
+ cells_boundaries = False + nucleus_boundaries = False + + if args.coordinate_space == "microns": + cells_labels = False + cells_boundaries = True + nucleus_boundaries = False + nucleus_labels = False + cells_as_circles = False + + if args.format == "xenium": + sd_xenium_obj = xenium( + args.bundle, + cells_as_circles=cells_as_circles, + cells_boundaries=cells_boundaries, + nucleus_boundaries=nucleus_boundaries, + cells_labels=cells_labels, + nucleus_labels=nucleus_labels, + transcripts=True, + morphology_mip=True, + morphology_focus=True, + ) + print(sd_xenium_obj) + convert_arrow_to_numpy(sd_xenium_obj) + sd_xenium_obj.write(f"spatialdata/{args.prefix}/{args.output_folder}") + else: + sys.exit("[ERROR] Format not found") + + print("[FINISH]") + + +if __name__ == "__main__": + main() diff --git a/modules/local/utility/convert_mask_uint32/resources/usr/bin/convert_mask_uint32.py b/bin/utility_convert_mask_uint32.py similarity index 100% rename from modules/local/utility/convert_mask_uint32/resources/usr/bin/convert_mask_uint32.py rename to bin/utility_convert_mask_uint32.py diff --git a/modules/local/utility/downscale_morphology/resources/usr/bin/downscale_morphology.py b/bin/utility_downscale_morphology.py similarity index 100% rename from modules/local/utility/downscale_morphology/resources/usr/bin/downscale_morphology.py rename to bin/utility_downscale_morphology.py diff --git a/modules/local/utility/extract_dapi/resources/usr/bin/extract_dapi.py b/bin/utility_extract_dapi.py similarity index 100% rename from modules/local/utility/extract_dapi/resources/usr/bin/extract_dapi.py rename to bin/utility_extract_dapi.py diff --git a/bin/utility_extract_data.py b/bin/utility_extract_data.py new file mode 100755 index 00000000..0ea737c2 --- /dev/null +++ b/bin/utility_extract_data.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +""" +Extract preview data from Baysor preview HTML reports. + +Parses embedded Vega-Lite spec variables and base64 PNG images from the +Baysor preview.html file, writing MultiQC-compatible TSV and PNG files. +""" + +import argparse +import base64 +import html +import json +import re +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import pandas as pd +from bs4 import BeautifulSoup + + +def get_png_files(soup: BeautifulSoup, outdir: Path) -> None: + """Get png base64 images following specific h1 tags in preview.html""" + target_ids = ["Transcript_Plots", "Noise_Level"] + outdir.mkdir(parents=True, exist_ok=True) + + for h1_id in target_ids: + h1_tag = soup.find("h1", id=h1_id) + if not h1_tag: + print(f"[WARN] No
<h1>
with id {h1_id} found") + continue + + # Look for the first after the h1 in the DOM + img_tag = h1_tag.find_next("img") + if not img_tag or not img_tag.get("src"): + print(f"[WARN] No found after h1#{h1_id}") + continue + + img_src = img_tag["src"] + if img_src.startswith("data:image/png;base64,"): + base64_data = img_src.split(",", 1)[1] + data = base64.b64decode(base64_data) + else: + print(f"[WARN] img src is not base64 PNG for h1#{h1_id}") + continue + + # save png files with _mqc suffix for MultiQC integration + img_name = f"{h1_id}_mqc.png".lower() + out_path = outdir / img_name + with open(out_path, "wb") as f: + f.write(data) + + print(f"[INFO] Saved {img_name}") + + return None + + +def extract_js_object(text: str, start_idx: int) -> Tuple[Optional[str], int]: + """Extract json-like object starting at start_idx.""" + if start_idx >= len(text) or text[start_idx] != "{": + return None, start_idx + + stack, in_str, escape, quote = [], False, False, None + for i in range(start_idx, len(text)): + ch = text[i] + if in_str: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == quote: + in_str = False + else: + if ch in ('"', "'"): + in_str, quote = True, ch + elif ch == "{": + stack.append("{") + elif ch == "}": + stack.pop() + if not stack: + return text[start_idx : i + 1], i + 1 + elif ch == "/" and i + 1 < len(text): + # skip js comments + nxt = text[i + 1] + if nxt == "/": + end = text.find("\n", i + 2) + i = len(text) - 1 if end == -1 else end + elif nxt == "*": + end = text.find("*/", i + 2) + if end == -1: + break + i = end + 1 + + return None, start_idx + + +def js_to_json(js: str) -> str: + """Convert a JS object string to valid JSON.""" + # Remove comments + js = re.sub(r"/\*.*?\*/", "", js, flags=re.S) + js = re.sub(r"//[^\n]*", "", js) + + # Convert single-quoted strings to double-quoted strings + js = re.sub( + r"'((?:\\.|[^'\\])*)'", + lambda m: '"' + m.group(1).replace('"', '\\"') + '"', + js, + ) + + # Remove trailing commas + js = re.sub(r",\s*(?=[}\]])", "", js) + js = re.sub(r",\s*,+", ",", js) + + return js.strip() + + +def find_variables(script_text: str) -> Dict[str, str]: + """Find all 'var|let|const specN =' declarations and extract their objects.""" + specs: Dict[str, str] = {} + script_text = html.unescape(script_text) + pattern = re.compile(r"(?:var|let|const)\s+(spec\d+)\s*=\s*{", re.I) + + for match in pattern.finditer(script_text): + var = match.group(1) + obj, _ = extract_js_object(script_text, match.end() - 1) + if obj: + specs[var] = obj + else: + print(f"[WARN] Could not extract object for {var}") + return specs + + +def write_tsvs(specs: Dict[str, str], outdir: Path) -> List[Path]: + """Convert extracted json to tsv.""" + outdir.mkdir(parents=True, exist_ok=True) + written: List[Path] = [] + + for var, js_obj in specs.items(): + try: + data = json.loads(js_to_json(js_obj)) + values = data.get("data", {}).get("values", []) + if not values: + print(f"[WARN] No data.values found in {var}") + continue + + df = pd.DataFrame(values) + outpath = outdir / f"{var}_mqc.tsv" + + with open(outpath, "w") as f: + f.write("# plot_type: linegraph\n") + f.write(f"# section_name: {var}\n") + f.write("# description: Extracted preview data\n") + df.to_csv(f, sep="\t", index=False) + + written.append(outpath) + print(f"[INFO] Wrote {outpath} ({len(df)} rows × {len(df.columns)} cols)") + except Exception as e: + print(f"[ERROR] Failed to process {var}: {e}") + + return written + + +def parse_args() -> argparse.Namespace: + """Parse command-line 
arguments.""" + parser = argparse.ArgumentParser( + description="Extract preview data from Baysor preview HTML reports." + ) + parser.add_argument( + "--preview-html", + required=True, + help="Path to Baysor preview HTML file", + ) + parser.add_argument( + "--prefix", + required=True, + help="Output directory prefix (sample ID)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + input_path: Path = Path(args.preview_html) + outdir: Path = Path(args.prefix) + + text = input_path.read_text(encoding="utf-8", errors="ignore") + soup = BeautifulSoup(text, "html.parser") + + # get the script section + if " argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Get transcript coordinate bounds from a Parquet file." + ) + parser.add_argument( + "--transcripts", + required=True, + help="Path to transcripts parquet file", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + result = get_coordinates(args.transcripts) + print(",".join(str(v) for v in result)) diff --git a/bin/utility_parquet_to_csv.py b/bin/utility_parquet_to_csv.py new file mode 100755 index 00000000..bfa19c40 --- /dev/null +++ b/bin/utility_parquet_to_csv.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +Convert a Parquet file to CSV format. + +Reads a Parquet file and writes it as CSV, optionally gzip-compressed. +""" + +import argparse +from pathlib import Path + +import pandas as pd + + +def convert_parquet( + transcripts: str, + extension: str = ".csv", + prefix: str = "", +) -> None: + """ + Convert a Parquet file to CSV or CSV.GZ format. + + Args: + transcripts: Filename of the input parquet file + extension: Output extension ('.csv' or '.gz' for gzip) + prefix: Output directory prefix + """ + df = pd.read_parquet(transcripts, engine="pyarrow") + + Path(prefix).mkdir(parents=True, exist_ok=True) + + if extension == ".gz": + output = transcripts.replace(".parquet", ".csv.gz") + df.to_csv(f"{prefix}/{output}", compression="gzip", index=False) + else: + output = transcripts.replace(".parquet", ".csv") + df.to_csv(f"{prefix}/{output}", index=False) + + return None + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Convert a Parquet file to CSV format." + ) + parser.add_argument( + "--transcripts", + required=True, + help="Input parquet filename", + ) + parser.add_argument( + "--extension", + default=".csv", + help="Output extension: '.csv' or '.gz' (default: .csv)", + ) + parser.add_argument( + "--prefix", + required=True, + help="Output directory prefix (sample ID)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + convert_parquet( + transcripts=args.transcripts, + extension=args.extension, + prefix=args.prefix, + ) diff --git a/bin/utility_resize_tif.py b/bin/utility_resize_tif.py new file mode 100755 index 00000000..6cca640d --- /dev/null +++ b/bin/utility_resize_tif.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +""" +Resize a segmentation TIFF mask to match transcript coordinates. + +This script rescales a segmentation mask image to match the coordinate +space of Xenium transcript data using microns-per-pixel metadata. 
+""" + +import argparse +import json +import os +from typing import Tuple + +import numpy as np +import pandas as pd +import tifffile +from skimage.transform import resize + + +def read_mask(mask_path: str) -> np.ndarray: + """Read the segmentation mask from a TIFF file.""" + print(f"Reading mask: {mask_path}") + mask = tifffile.imread(mask_path) + print(f"Mask shape: {mask.shape}, dtype: {mask.dtype}") + return mask + + +def read_transcript_bounds(transcript_path: str) -> Tuple[float, float, float, float]: + """Read transcript coordinates and return their bounding box.""" + print(f"Reading transcripts: {transcript_path}") + if transcript_path.endswith(".parquet"): + transcripts = pd.read_parquet(transcript_path, columns=["x_location", "y_location"]) + else: + transcripts = pd.read_csv(transcript_path) + + if "x_location" not in transcripts.columns or "y_location" not in transcripts.columns: + raise ValueError("Transcript file must contain 'x_location' and 'y_location' columns.") + + x_min, x_max = transcripts["x_location"].min(), transcripts["x_location"].max() + y_min, y_max = transcripts["y_location"].min(), transcripts["y_location"].max() + + print(f"Transcript bounds: X=({x_min:.2f}, {x_max:.2f}), Y=({y_min:.2f}, {y_max:.2f})") + return x_min, x_max, y_min, y_max + + +def read_microns_per_pixel(metadata_path: str) -> float: + """Extract microns_per_pixel or pixel_size from metadata JSON.""" + print(f"Reading metadata: {metadata_path}") + with open(metadata_path, "r") as f: + metadata = json.load(f) + + mpp = metadata.get("microns_per_pixel") or metadata.get("pixel_size") + if mpp is None: + raise KeyError("Metadata JSON must contain 'microns_per_pixel' or 'pixel_size'.") + + print(f"Microns per pixel: {mpp}") + return float(mpp) + + +def compute_target_size( + x_min: float, x_max: float, y_min: float, y_max: float, microns_per_pixel: float +) -> Tuple[int, int]: + """Compute new image size (in pixels) to cover given coordinates.""" + new_width = int(round((x_max - x_min) / microns_per_pixel)) + new_height = int(round((y_max - y_min) / microns_per_pixel)) + print(f"Target image size: {new_width} x {new_height} pixels") + return new_height, new_width + + +def resize_mask(mask: np.ndarray, new_shape: Tuple[int, int]) -> np.ndarray: + """Resize mask using nearest-neighbor interpolation (preserve labels).""" + print("Resizing mask...") + resized = resize( + mask, + new_shape, + order=0, # nearest neighbor to preserve segmentation labels + preserve_range=True, + anti_aliasing=False, + ).astype(mask.dtype) + print(f"Resized shape: {resized.shape}") + return resized + + +def main(mask_path: str, transcripts_path: str, metadata_path: str, output_path: str) -> None: + """Resize segmentation mask to match Xenium coordinate space.""" + # Validate input files + for path in [mask_path, transcripts_path, metadata_path]: + if not os.path.exists(path): + raise FileNotFoundError(f"File not found: {path}") + + # Load data + mask = read_mask(mask_path) + x_min, x_max, y_min, y_max = read_transcript_bounds(transcripts_path) + microns_per_pixel = read_microns_per_pixel(metadata_path) + + # Compute physical mask size + height, width = mask.shape + print(f"Original mask size: {width * microns_per_pixel:.2f} x {height * microns_per_pixel:.2f} um") + + # Compute target size + new_height, new_width = compute_target_size(x_min, x_max, y_min, y_max, microns_per_pixel) + + # Resize and save + resized_mask = resize_mask(mask, (new_height, new_width)) + tifffile.imwrite(output_path, resized_mask) + + print(f"Saved 
resized mask -> {output_path}") + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Resize a segmentation TIFF mask to match transcript coordinates." + ) + parser.add_argument("--mask", required=True, help="Path to segmentation mask TIFF") + parser.add_argument("--transcripts", required=True, help="Path to transcripts file") + parser.add_argument("--metadata", required=True, help="Path to metadata JSON") + parser.add_argument("--prefix", required=True, help="Output directory prefix") + parser.add_argument("--mask-filename", required=True, help="Original mask filename for output naming") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + os.makedirs(args.prefix, exist_ok=True) + output_mask: str = os.path.join(args.prefix, f"resized_{args.mask_filename}.tif") + + main( + mask_path=args.mask, + transcripts_path=args.transcripts, + metadata_path=args.metadata, + output_path=output_mask, + ) diff --git a/bin/utility_segger2xr.py b/bin/utility_segger2xr.py new file mode 100755 index 00000000..22889e82 --- /dev/null +++ b/bin/utility_segger2xr.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" +Convert Segger prediction output to XeniumRanger-compatible format. + +Reads Segger PREDICT output (transcripts.parquet with segger_cell_id), +produces Baysor-format segmentation CSV, refined transcripts parquet, +and GeoJSON cell boundary polygons for xeniumranger import-segmentation. +""" + +import argparse +import json +from pathlib import Path +from typing import List + +import pandas as pd +from scipy.spatial import ConvexHull + +# Expected columns in transcripts.parquet +REQUIRED_COLUMNS: List[str] = [ + "transcript_id", + "cell_id", + "overlaps_nucleus", + "feature_name", + "x_location", + "y_location", + "z_location", + "qv", +] + +# Column name for segger cell assignment (varies by segger version) +SEGGER_ID_CANDIDATES: List[str] = ["segger_cell_id", "segger_id"] + + +def refine_transcripts(parquet_path: str) -> pd.DataFrame: + """ + Read segger PREDICT output and extract cell assignments. + Supports both 'segger_cell_id' (newer) and 'segger_id' (older) column names. + """ + parquet_file = Path(parquet_path) + if not parquet_file.exists(): + raise FileNotFoundError(f"File not found: {parquet_path}") + + df = pd.read_parquet(parquet_file, engine="pyarrow") + + missing_cols = [col for col in REQUIRED_COLUMNS if col not in df.columns] + if missing_cols: + raise ValueError(f"Missing required columns: {missing_cols}") + + # Find segger cell assignment column + segger_col = None + for candidate in SEGGER_ID_CANDIDATES: + if candidate in df.columns: + segger_col = candidate + break + if segger_col is None: + raise ValueError( + f"No segger cell assignment column found. " + f"Expected one of {SEGGER_ID_CANDIDATES}, got columns: {list(df.columns)}" + ) + + # Replace cell_id with segger assignment + cell_id_index = df.columns.get_loc("cell_id") + df = df.drop(columns=["cell_id"]) + segger_series = df.pop(segger_col) + df.insert(cell_id_index, "cell_id", segger_series) + + return df + + +def build_cell_map(df: pd.DataFrame, min_transcripts: int = 3) -> dict: + """ + Build a mapping from raw segger cell IDs to non-numeric string IDs. + + Only includes cells that have: + - >= min_transcripts assigned transcripts + - At least one transcript with valid (non-NaN) x/y coordinates + + Cell IDs use "cell-N" format (hyphen + integer) as required by + xeniumranger's cell ID parser. 
Non-numeric to avoid polars Int64 inference. + """ + cell_ids = df["cell_id"].fillna("UNASSIGNED").astype(str) + is_unassigned = (cell_ids == "UNASSIGNED") | (cell_ids == "") | (cell_ids == "0") + assigned = cell_ids[~is_unassigned] + counts = assigned.value_counts() + enough_tx = set(counts[counts >= min_transcripts].index) + + # Exclude cells with all-NaN coordinates (no spatial info = useless) + has_coords = df.dropna(subset=["x_location", "y_location"]) + has_coords_ids = set(has_coords["cell_id"].fillna("UNASSIGNED").astype(str)) + valid_cells = sorted(enough_tx & has_coords_ids) + + return {cell: f"cell-{i + 1}" for i, cell in enumerate(valid_cells)} + + +def to_baysor_csv(df: pd.DataFrame, output_path: str, cell_map: dict) -> None: + """ + Convert transcript DataFrame to Baysor-compatible CSV format. + + xeniumranger 4.0 import-segmentation --transcript-assignment expects a + Baysor segmentation CSV with at minimum: transcript_id, cell, is_noise, + x, y columns. This function maps Xenium/Segger columns to Baysor format. + """ + baysor_df = pd.DataFrame() + baysor_df["transcript_id"] = df["transcript_id"] + baysor_df["x"] = df["x_location"] + baysor_df["y"] = df["y_location"] + baysor_df["z"] = df["z_location"] + baysor_df["gene"] = df["feature_name"] + + cell_ids = df["cell_id"].fillna("UNASSIGNED").astype(str) + is_unassigned = (cell_ids == "UNASSIGNED") | (cell_ids == "") | (cell_ids == "0") + baysor_df["cell"] = cell_ids.map(cell_map).fillna("") + baysor_df["is_noise"] = is_unassigned.astype(int) + + baysor_df.to_csv(output_path, index=False) + + n_assigned = (~is_unassigned).sum() + n_noise = is_unassigned.sum() + n_cells = len(cell_map) + print( + f"Baysor CSV: {n_assigned} assigned, {n_noise} noise, {n_cells} cells -> {output_path}" + ) + + +def _make_buffer_polygon(cx: float, cy: float, radius: float = 0.5) -> list: + """Create a small square polygon around a centroid as fallback.""" + return [ + [cx - radius, cy - radius], + [cx + radius, cy - radius], + [cx + radius, cy + radius], + [cx - radius, cy + radius], + [cx - radius, cy - radius], # close ring + ] + + +def generate_viz_polygons(df: pd.DataFrame, output_path: str, cell_map: dict) -> None: + """ + Generate a GeoJSON file with cell boundary polygons. + + Uses ConvexHull when possible; falls back to a small buffer polygon around + the centroid for cells with < 3 unique points or collinear points. + + Required by xeniumranger import-segmentation when using --transcript-assignment. + Each feature MUST have a top-level "id" field (xeniumranger reads item["id"]). + Cell IDs must match those in the Baysor CSV. 
+ """ + assigned = df[ + df["cell_id"].notna() + & (df["cell_id"].astype(str) != "UNASSIGNED") + & (df["cell_id"].astype(str) != "") + ].copy() + + features = [] + grouped = assigned.groupby("cell_id") + + for cell_id, group in grouped: + mapped_id = cell_map.get(str(cell_id)) + if mapped_id is None: + continue + + coords = group[["x_location", "y_location"]].dropna().values + + polygon_coords = None + if len(coords) >= 3: + try: + hull = ConvexHull(coords) + hull_points = coords[hull.vertices].tolist() + hull_points.append(hull_points[0]) # close polygon ring + polygon_coords = hull_points + except Exception: + pass + + # Fallback: buffer polygon around centroid + if polygon_coords is None: + cx, cy = coords.mean(axis=0).astype(float) + polygon_coords = _make_buffer_polygon(cx, cy) + + features.append( + { + "type": "Feature", + "id": mapped_id, + "geometry": { + "type": "Polygon", + "coordinates": [polygon_coords], + }, + "properties": {"cell_id": mapped_id}, + } + ) + + geojson = {"type": "FeatureCollection", "features": features} + + with open(output_path, "w") as f: + json.dump(geojson, f) + + print(f"Generated {len(features)} cell polygons in {output_path}") + + +def main(input_file: str, prefix: str, min_transcripts: int = 3) -> None: + """Run the full segger-to-xeniumranger conversion pipeline.""" + Path(prefix).mkdir(parents=True, exist_ok=True) + transcripts = refine_transcripts(input_file) + + # Build cell ID mapping, filtering cells with < min_transcripts + cell_map = build_cell_map(transcripts, min_transcripts=min_transcripts) + + # xeniumranger 4.0 expects Baysor-format CSV (not parquet) with is_noise column + to_baysor_csv(transcripts, f"{prefix}/segmentation.csv", cell_map) + + # Also save the refined parquet for downstream use + transcripts.to_parquet(f"{prefix}/transcripts.parquet", engine="pyarrow") + + # Generate cell boundary polygons (required companion to --transcript-assignment) + # Uses ConvexHull when possible; falls back to buffer polygon for edge cases + generate_viz_polygons(transcripts, f"{prefix}/segmentation_polygons.json", cell_map) + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Convert Segger prediction output to XeniumRanger-compatible format." + ) + parser.add_argument( + "--transcripts", + required=True, + help="Path to Segger output transcripts parquet file", + ) + parser.add_argument( + "--prefix", + required=True, + help="Output directory prefix (sample ID)", + ) + parser.add_argument( + "--min-transcripts", + type=int, + default=3, + help="Minimum transcripts per cell (default: 3)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main( + input_file=args.transcripts, + prefix=args.prefix, + min_transcripts=args.min_transcripts, + ) diff --git a/bin/utility_split_transcripts.py b/bin/utility_split_transcripts.py new file mode 100755 index 00000000..275fbab1 --- /dev/null +++ b/bin/utility_split_transcripts.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" +Split transcript coordinates into spatial tiles. + +Reads a Xenium transcripts.parquet file and computes quantile-based spatial +tiles, writing a splits.csv with tile boundaries. +""" + +import argparse +import os +from typing import List + +import pandas as pd + + +def compute_quantile_ranges(df: pd.DataFrame, col: str, n_bins: int) -> List: + """ + Compute the bin edges for `df[col]` such that each of the n_bins + has ~equal count of points. 
Returns a list of (min, max) tuples. + """ + _, bins = pd.qcut(df[col], q=n_bins, retbins=True, duplicates="drop") + + ranges = [(bins[i], bins[i + 1]) for i in range(len(bins) - 1)] + + return ranges + + +def make_tiles(df: pd.DataFrame, x_bins: int, y_bins: int) -> pd.DataFrame: + """ + Produce a DataFrame with one row per tile: + tile_id, x_min, x_max, y_min, y_max + """ + x_ranges = compute_quantile_ranges(df, "x_location", x_bins) + y_ranges = compute_quantile_ranges(df, "y_location", y_bins) + + tiles = [] + for ix, (x_min, x_max) in enumerate(x_ranges, start=1): + for iy, (y_min, y_max) in enumerate(y_ranges, start=1): + tiles.append( + { + "tile_id": f"{ix}_{iy}", + "x_min": x_min, + "x_max": x_max, + "y_min": y_min, + "y_max": y_max, + } + ) + + return pd.DataFrame(tiles) + + +def main( + transcripts: str, + x_bins: int = 10, + y_bins: int = 10, + prefix: str = "", +) -> None: + """Generate spatial tile splits from transcript coordinates.""" + # read parquet file + df = pd.read_parquet(transcripts, engine="fastparquet") + + # compute tiles + tiles_df = make_tiles(df, x_bins, y_bins) + + # save csv file + os.makedirs(prefix, exist_ok=True) + tiles_df.to_csv(f"{prefix}/splits.csv", index=False) + + return None + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Split transcript coordinates into spatial tiles." + ) + parser.add_argument( + "--transcripts", + required=True, + help="Path to transcripts parquet file", + ) + parser.add_argument( + "--x-bins", + type=int, + required=True, + help="Number of bins along X axis", + ) + parser.add_argument( + "--y-bins", + type=int, + required=True, + help="Number of bins along Y axis", + ) + parser.add_argument( + "--prefix", + required=True, + help="Output directory prefix", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main( + transcripts=args.transcripts, + x_bins=args.x_bins, + y_bins=args.y_bins, + prefix=args.prefix, + ) diff --git a/modules/local/utility/upscale_mask/resources/usr/bin/upscale_mask.py b/bin/utility_upscale_mask.py similarity index 100% rename from modules/local/utility/upscale_mask/resources/usr/bin/upscale_mask.py rename to bin/utility_upscale_mask.py diff --git a/modules/local/xenium_patch/stitch/resources/usr/bin/stitch_postprocess.py b/bin/xenium_patch_stitch_postprocess.py similarity index 100% rename from modules/local/xenium_patch/stitch/resources/usr/bin/stitch_postprocess.py rename to bin/xenium_patch_stitch_postprocess.py diff --git a/bin/xenium_patch_stitch_transcripts.py b/bin/xenium_patch_stitch_transcripts.py new file mode 100755 index 00000000..d9fb8d41 --- /dev/null +++ b/bin/xenium_patch_stitch_transcripts.py @@ -0,0 +1,808 @@ +#!/usr/bin/env python3 +"""Stitch per-patch Baysor segmentation results into unified output. + +Standalone script that replaces the xenium_patch CLI package's stitch +functionality. Uses sopa's solve_conflicts() for overlap resolution. 
+""" + +from __future__ import annotations + +import argparse +import json +import os +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path + +import geopandas as gpd +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.csv as pa_csv +import shapely +from shapely.affinity import translate +from shapely.geometry import mapping, shape +from sopa.segmentation.resolve import solve_conflicts + +# --------------------------------------------------------------------------- +# Geometry helpers +# --------------------------------------------------------------------------- + + +def _ensure_polygon(geom) -> "shapely.Polygon | None": + """Extract a single Polygon from any geometry, or return None. + + XeniumRanger only accepts Polygon. make_valid() and solve_conflicts + can produce MultiPolygon, GeometryCollection, MultiLineString, etc. + """ + if geom is None or geom.is_empty: + return None + if geom.geom_type == "Polygon": + return geom + if geom.geom_type == "MultiPolygon": + return max(geom.geoms, key=lambda g: g.area) + if geom.geom_type == "GeometryCollection": + polys = [g for g in geom.geoms if g.geom_type == "Polygon"] + return max(polys, key=lambda g: g.area) if polys else None + # LineString, MultiLineString, Point, etc. — not a polygon + return None + + +# --------------------------------------------------------------------------- +# Inline types (from _types.py) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class Bounds: + """Axis-aligned bounding box in either pixel or micron coordinates.""" + + x_min: float + x_max: float + y_min: float + y_max: float + + +@dataclass(frozen=True) +class PatchInfo: + """Metadata for a single patch in the grid.""" + + patch_id: str + row: int + col: int + global_bounds_px: Bounds + global_bounds_um: Bounds + core_bounds_px: Bounds + core_bounds_um: Bounds + + +@dataclass +class PatchGridMetadata: + """Full grid metadata, serializable to JSON.""" + + version: str + bundle_path: str + image_height_px: int + image_width_px: int + pixel_size_um: float + transcript_extent_um: Bounds + grid_rows: int + grid_cols: int + overlap_um: float + overlap_px: int + patches: list[PatchInfo] + grid_type: str = "uniform" + + +# --------------------------------------------------------------------------- +# Internal result containers +# --------------------------------------------------------------------------- + + +@dataclass +class _PatchGeoResult: + """Result of parallel GeoJSON processing for a single patch.""" + + features: list[dict] + cell_ids: list[str] + + +@dataclass +class _PatchCsvResult: + """Result of parallel CSV reading for a single patch.""" + + table: pa.Table + has_cell_col: bool + has_x_col: bool + has_y_col: bool + has_gene_col: bool = False + has_feature_name_col: bool = False + + +# --------------------------------------------------------------------------- +# Grid metadata I/O (from grid.py) +# --------------------------------------------------------------------------- + + +def _dict_to_bounds(d: dict) -> Bounds: + return Bounds(d["x_min"], d["x_max"], d["y_min"], d["y_max"]) + + +def load_grid_metadata(input_path: Path) -> PatchGridMetadata: + """Deserialize PatchGridMetadata from JSON. + + Args: + input_path: Path to JSON file to read. + + Returns: + Reconstructed PatchGridMetadata. 
+ """ + with open(input_path) as f: + data = json.load(f) + + patches = [ + PatchInfo( + patch_id=p["patch_id"], + row=p["row"], + col=p["col"], + global_bounds_px=_dict_to_bounds(p["global_bounds_px"]), + global_bounds_um=_dict_to_bounds(p["global_bounds_um"]), + core_bounds_px=_dict_to_bounds(p["core_bounds_px"]), + core_bounds_um=_dict_to_bounds(p["core_bounds_um"]), + ) + for p in data["patches"] + ] + + return PatchGridMetadata( + version=data["version"], + bundle_path=data["bundle_path"], + image_height_px=data["image_height_px"], + image_width_px=data["image_width_px"], + pixel_size_um=data["pixel_size_um"], + transcript_extent_um=_dict_to_bounds(data["transcript_extent_um"]), + grid_rows=data["grid_rows"], + grid_cols=data["grid_cols"], + overlap_um=data["overlap_um"], + overlap_px=data["overlap_px"], + grid_type=data.get("grid_type", "uniform"), + patches=patches, + ) + + +# --------------------------------------------------------------------------- +# GeoJSON I/O (from polygon_io.py) +# --------------------------------------------------------------------------- + + +def _normalize_geometry_collection(geojson: dict) -> dict: + """Convert a GeometryCollection to a FeatureCollection. + + proseg-to-baysor produces a non-standard GeoJSON GeometryCollection where + each geometry object has a custom ``cell`` key (bare integer) instead of + using Feature wrappers. This normalises it to a standard FeatureCollection + with ``id`` and ``properties.cell_id`` on each feature, using the + ``"cell-{N}"`` format that matches the companion CSV. + + Args: + geojson: Parsed GeoJSON dict with type GeometryCollection. + + Returns: + Standard FeatureCollection dict. + """ + features = [] + for geom in geojson.get("geometries", []): + cell_raw = geom.get("cell", "") + cell_id = str(cell_raw) + clean_geom = {k: v for k, v in geom.items() if k != "cell"} + feature = { + "type": "Feature", + "id": cell_id, + "geometry": clean_geom, + "properties": {"cell_id": cell_id}, + } + features.append(feature) + return {"type": "FeatureCollection", "features": features} + + +def read_geojson(geojson_path: Path) -> dict: + """Read a GeoJSON file and normalise to FeatureCollection. + + Handles both standard FeatureCollections and the GeometryCollection + format produced by proseg-to-baysor. + + Args: + geojson_path: Path to the GeoJSON file. + + Returns: + Parsed GeoJSON dict (always a FeatureCollection). + """ + with open(geojson_path) as f: + data = json.load(f) + if data.get("type") == "GeometryCollection": + return _normalize_geometry_collection(data) + return data + + +def transform_polygons(geojson: dict, offset_x: float, offset_y: float) -> dict: + """Shift all polygon coordinates by (offset_x, offset_y). + + Args: + geojson: Input FeatureCollection. + offset_x: Translation in x. + offset_y: Translation in y. + + Returns: + New FeatureCollection with shifted geometries. + """ + features = [] + for feat in geojson.get("features", []): + geom = shape(feat["geometry"]) + shifted = translate(geom, xoff=offset_x, yoff=offset_y) + new_feat = {**feat, "geometry": mapping(shifted)} + features.append(new_feat) + return {"type": "FeatureCollection", "features": features} + + +def write_geojson(geojson: dict, output_path: Path) -> None: + """Write a GeoJSON FeatureCollection. + + Args: + geojson: GeoJSON dict to write. + output_path: Destination path (parent dirs created automatically). 
+ """ + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(geojson, f) + + +# --------------------------------------------------------------------------- +# Arrow utilities (from _arrow_utils.py) +# --------------------------------------------------------------------------- + + +def float_str_array(f64_array: pa.Array) -> pa.Array: + """Convert a float64 pyarrow array to string using Python's str(float) format. + + pyarrow's built-in cast omits trailing '.0' for whole numbers. This + function ensures output matches str(float(...)) for CSV compatibility. + + Args: + f64_array: Float64 pyarrow array to convert. + + Returns: + String pyarrow array with Python-formatted float values. + """ + return pa.array( + [str(v) if v is not None else None for v in f64_array.to_pylist()], + type=pa.string(), + ) + + +# --------------------------------------------------------------------------- +# Parallel I/O +# --------------------------------------------------------------------------- + + +def _read_and_transform_geojson( + patch: PatchInfo, + patches_dir: Path, + geojson_filename: str, +) -> _PatchGeoResult | None: + """Read, transform GeoJSON for a single patch (no core clipping). + + Args: + patch: Patch metadata. + patches_dir: Root patches directory. + geojson_filename: GeoJSON filename within each patch directory. + + Returns: + _PatchGeoResult with features and cell IDs, or None if no GeoJSON. + """ + geojson_path = patches_dir / patch.patch_id / geojson_filename + if not geojson_path.exists(): + return None + + geojson = read_geojson(geojson_path) + + offset_x = patch.global_bounds_um.x_min + offset_y = patch.global_bounds_um.y_min + geojson = transform_polygons(geojson, offset_x, offset_y) + + features = geojson.get("features", []) + seen: set[str] = set() + cell_ids: list[str] = [] + for feat in features: + old_id = str(feat.get("id", feat.get("properties", {}).get("cell_id", ""))) + if old_id not in seen: + seen.add(old_id) + cell_ids.append(old_id) + + return _PatchGeoResult(features=features, cell_ids=cell_ids) + + +def _read_patch_csv( + patch: PatchInfo, + patches_dir: Path, + csv_filename: str, +) -> _PatchCsvResult | None: + """Read a patch CSV into a pyarrow Table. + + All columns are read as strings to preserve exact formatting. + + Args: + patch: Patch metadata. + patches_dir: Root patches directory. + csv_filename: CSV filename within each patch directory. + + Returns: + _PatchCsvResult with the table and column presence flags, or None. 
+ """ + csv_path = patches_dir / patch.patch_id / csv_filename + if not csv_path.exists(): + return None + + with open(csv_path) as fh: + header_line = fh.readline().strip() + col_names = header_line.split(",") + all_string_types = {name: pa.string() for name in col_names} + + table = pa_csv.read_csv( + csv_path, + convert_options=pa_csv.ConvertOptions( + column_types=all_string_types, + strings_can_be_null=False, + ), + read_options=pa_csv.ReadOptions(use_threads=True), + ) + + return _PatchCsvResult( + table=table, + has_cell_col="cell" in table.column_names, + has_x_col="x" in table.column_names, + has_y_col="y" in table.column_names, + has_gene_col="gene" in table.column_names, + has_feature_name_col="feature_name" in table.column_names, + ) + + +# --------------------------------------------------------------------------- +# CSV processing +# --------------------------------------------------------------------------- + + +def _transform_patch_coords( + csv_result: _PatchCsvResult, + offset_x: float, + offset_y: float, +) -> pa.Table: + """Shift transcript coordinates from local patch space to global space. + + Args: + csv_result: The raw CSV table and column flags. + offset_x: X offset for coordinate transform (microns). + offset_y: Y offset for coordinate transform (microns). + + Returns: + Table with x, y columns shifted to global coordinates. + """ + table = csv_result.table + + if table.num_rows == 0: + return table + + if csv_result.has_x_col: + x_f64 = pc.add( + table.column("x").cast(pa.float64()), + pa.scalar(offset_x, type=pa.float64()), + ) + table = table.set_column( + table.schema.get_field_index("x"), + "x", + float_str_array(x_f64), + ) + if csv_result.has_y_col: + y_f64 = pc.add( + table.column("y").cast(pa.float64()), + pa.scalar(offset_y, type=pa.float64()), + ) + table = table.set_column( + table.schema.get_field_index("y"), + "y", + float_str_array(y_f64), + ) + + return table + + +# --------------------------------------------------------------------------- +# Sopa conflict resolution +# --------------------------------------------------------------------------- + + +def _stitch_sopa_resolve( + metadata: PatchGridMetadata, + geo_results: list[_PatchGeoResult | None], + csv_results: list[_PatchCsvResult | None], + all_geojson_features: list[dict], + all_tables: list[pa.Table], + threshold: float = 0.5, +) -> set[str]: + """Stitch per-patch segmentation using spatial containment assignment. + + 1. Collect ALL non-empty polygons from all patches (no transcript filtering). + 2. Resolve overlapping polygons via sopa's solve_conflicts(). + 3. Assign sequential global cell IDs (cell-1, cell-2, ...). + 4. Spatially assign transcripts to resolved polygons using STRtree. + 5. Noise transcripts (outside all polygons) kept only from their core patch. + + This approach works regardless of whether Baysor's CSV ``cell`` column + matches GeoJSON cell IDs -- all assignment is done by spatial containment. + + Args: + metadata: Grid metadata with patch list. + geo_results: Per-patch GeoJSON results (already in global coords). + csv_results: Per-patch CSV results. + all_geojson_features: Output list to append resolved GeoJSON features. + all_tables: Output list to append processed CSV tables. + threshold: Overlap threshold for sopa's solve_conflicts (0-1). + + Returns: + Set of global cell IDs created by merging overlapping cells. 
+ """ + # --- Phase 1: Collect all polygons from all patches --- + all_polygons: list = [] + patch_indices_list: list[int] = [] + + for i, patch in enumerate(metadata.patches): + geo_result = geo_results[i] + if geo_result is None: + continue + + for feat in geo_result.features: + polygon = shape(feat["geometry"]) + if polygon.is_empty: + continue + if not polygon.is_valid: + polygon = shapely.make_valid(polygon) + # Ensure we have a single Polygon (xeniumranger rejects all else) + polygon = _ensure_polygon(polygon) + if polygon is None: + continue + + all_polygons.append(polygon) + patch_indices_list.append(i) + + if not all_polygons: + print("[stitch] No polygons found in any patch") + # Still transform and collect CSVs as noise-only + for i, patch in enumerate(metadata.patches): + csv_result = csv_results[i] + if csv_result is None: + continue + offset_x = patch.global_bounds_um.x_min + offset_y = patch.global_bounds_um.y_min + transformed = _transform_patch_coords(csv_result, offset_x, offset_y) + if transformed.num_rows > 0: + all_tables.append(transformed) + return set() + + # --- Phase 2: Resolve overlapping polygons via sopa --- + patch_idx_array = np.array(patch_indices_list, dtype=np.int64) + input_gdf = gpd.GeoDataFrame(geometry=all_polygons) + resolved_gdf, kept_indices = solve_conflicts( + input_gdf, + threshold=threshold, + patch_indices=patch_idx_array, + return_indices=True, + ) + + # --- Phase 3: Assign global cell IDs to resolved polygons --- + merged_cell_ids: set[str] = set() + kept_arr = np.asarray(kept_indices) + resolved_polys: list = [] + resolved_ids: list[str] = [] + + for rank, orig_idx in enumerate(kept_arr, start=1): + global_id = f"cell-{rank}" + geom = resolved_gdf.geometry.iloc[rank - 1] + + # Ensure single Polygon after solve_conflicts union + geom = _ensure_polygon(geom) + if geom is None: + continue + + if orig_idx < 0: + merged_cell_ids.add(global_id) + + resolved_polys.append(geom) + resolved_ids.append(global_id) + + all_geojson_features.append( + { + "type": "Feature", + "id": global_id, + "geometry": mapping(geom), + "properties": {"cell_id": global_id}, + } + ) + + print( + f"[stitch] Resolved {len(all_polygons)} input polygons to " + f"{len(resolved_polys)} cells ({len(merged_cell_ids)} merged)" + ) + + # --- Phase 4: Spatial transcript assignment via STRtree --- + poly_tree = shapely.STRtree(resolved_polys) + + for i, patch in enumerate(metadata.patches): + csv_result = csv_results[i] + if csv_result is None: + continue + + offset_x = patch.global_bounds_um.x_min + offset_y = patch.global_bounds_um.y_min + core = patch.core_bounds_um + + transformed = _transform_patch_coords(csv_result, offset_x, offset_y) + if transformed.num_rows == 0: + continue + + if not csv_result.has_x_col or not csv_result.has_y_col: + all_tables.append(transformed) + continue + + # Get global coordinates for spatial query + gx = transformed.column("x").cast(pa.float64()).to_numpy(zero_copy_only=False) + gy = transformed.column("y").cast(pa.float64()).to_numpy(zero_copy_only=False) + points = shapely.points(gx, gy) + + # Query STRtree: returns (input_indices, tree_indices) + point_hits, poly_hits = poly_tree.query(points, predicate="intersects") + + # Build point -> cell_id mapping (first hit wins) + point_to_cell: dict[int, str] = {} + for pt_idx, poly_idx in zip(point_hits, poly_hits): + if pt_idx not in point_to_cell: + point_to_cell[pt_idx] = resolved_ids[poly_idx] + + # Build cell and is_noise columns + n_rows = transformed.num_rows + cell_arr = [""] * n_rows + 
is_noise_arr = ["true"] * n_rows + for pt_idx, cell_id in point_to_cell.items(): + cell_arr[pt_idx] = cell_id + is_noise_arr[pt_idx] = "false" + + # Filter noise transcripts to core bounds only + # Assigned transcripts are kept from all patches (dedup later by transcript_id) + in_core = ( + (gx >= core.x_min) + & (gx < core.x_max) + & (gy >= core.y_min) + & (gy < core.y_max) + ) + is_assigned = np.array([c != "" for c in cell_arr]) + keep_mask = pa.array(is_assigned | in_core, type=pa.bool_()) + + filtered = transformed.filter(keep_mask) + cell_arr_filtered = [c for c, k in zip(cell_arr, (is_assigned | in_core)) if k] + is_noise_filtered = [ + n for n, k in zip(is_noise_arr, (is_assigned | in_core)) if k + ] + + if filtered.num_rows == 0: + continue + + # Set cell and is_noise columns + cell_idx = ( + filtered.schema.get_field_index("cell") + if "cell" in filtered.column_names + else None + ) + if cell_idx is not None: + filtered = filtered.set_column( + cell_idx, "cell", pa.array(cell_arr_filtered, type=pa.string()) + ) + else: + filtered = filtered.append_column( + "cell", pa.array(cell_arr_filtered, type=pa.string()) + ) + + noise_idx = ( + filtered.schema.get_field_index("is_noise") + if "is_noise" in filtered.column_names + else None + ) + if noise_idx is not None: + filtered = filtered.set_column( + noise_idx, + "is_noise", + pa.array(is_noise_filtered, type=pa.string()), + ) + else: + filtered = filtered.append_column( + "is_noise", pa.array(is_noise_filtered, type=pa.string()) + ) + + all_tables.append(filtered) + + return merged_cell_ids + + +# --------------------------------------------------------------------------- +# Main orchestrator +# --------------------------------------------------------------------------- + + +def stitch_transcript_assignments( + patches_dir: Path, + output_dir: Path, + csv_filename: str = "segmentation.csv", + geojson_filename: str = "segmentation_polygons.json", + max_workers: int | None = None, +) -> None: + """Stitch per-patch transcript assignments and polygons into unified output. + + For each patch, reads the transcript assignment CSV and polygon GeoJSON. + Cells are deduplicated using sopa's solve_conflicts() which resolves + overlapping cells at patch boundaries based on area overlap ratio. + + Processing is split into a parallel I/O phase (reading GeoJSON and CSV + files via thread pool) and a sequential phase (dedup, global cell ID + assignment, remapping, and concatenation). + + Args: + patches_dir: Directory containing patch subdirectories and patch_grid.json. + output_dir: Output directory for stitched CSV and GeoJSON. + csv_filename: CSV filename within each patch directory. + geojson_filename: GeoJSON filename within each patch directory. + max_workers: Maximum number of threads for parallel I/O. 
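+
+    Writes:
+        output_dir/xr-transcript-metadata.csv and
+        output_dir/xr-cell-polygons.geojson (each skipped when empty).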
+ """ + patches_dir = Path(patches_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + metadata = load_grid_metadata(patches_dir / "patch_grid.json") + + n_patches = len(metadata.patches) + if max_workers is None: + max_workers = min(n_patches, os.cpu_count() or 1) + + # ---- Parallel phase: read GeoJSON and CSV files concurrently ---- + with ThreadPoolExecutor(max_workers=max_workers) as executor: + geo_futures = [ + executor.submit( + _read_and_transform_geojson, p, patches_dir, geojson_filename + ) + for p in metadata.patches + ] + csv_futures = [ + executor.submit(_read_patch_csv, p, patches_dir, csv_filename) + for p in metadata.patches + ] + geo_results = [f.result() for f in geo_futures] + csv_results = [f.result() for f in csv_futures] + + # ---- Sequential phase: assign global cell IDs, remap, concatenate ---- + all_tables: list[pa.Table] = [] + all_geojson_features: list[dict] = [] + + _stitch_sopa_resolve( + metadata, + geo_results, + csv_results, + all_geojson_features, + all_tables, + threshold=0.5, + ) + + # Concatenate all patch tables + if all_tables: + merged = pa.concat_tables(all_tables) + + # Deduplicate by transcript_id: prefer assigned over noise + if "transcript_id" in merged.column_names: + if "cell" in merged.column_names: + is_noise = pc.equal(merged.column("cell"), "").cast(pa.int8()) + row_order = pa.array(np.arange(merged.num_rows), type=pa.int64()) + sort_table = pa.table({"_noise": is_noise, "_row": row_order}) + sort_indices = pc.sort_indices( + sort_table, + sort_keys=[("_noise", "ascending"), ("_row", "ascending")], + ) + merged = merged.take(sort_indices) + + tid_np = merged.column("transcript_id").to_numpy(zero_copy_only=False) + _, first_indices = np.unique(tid_np, return_index=True) + first_indices.sort() + merged = merged.take(first_indices) + + # Log assignment stats + if "cell" in merged.column_names: + cell_vals = merged.column("cell").to_pylist() + n_assigned = sum(1 for c in cell_vals if c) + n_noise = sum(1 for c in cell_vals if not c) + print( + f"[stitch] Final: {merged.num_rows} transcripts, " + f"{n_assigned} assigned, {n_noise} noise" + ) + + # Cast is_noise to integer for xeniumranger compatibility + if "is_noise" in merged.column_names: + noise_col = merged.column("is_noise") + if noise_col.type == pa.string(): + lower = pc.utf8_lower(noise_col) + is_true = pc.or_(pc.equal(lower, "true"), pc.equal(lower, "1")) + idx = merged.column_names.index("is_noise") + merged = merged.set_column(idx, "is_noise", is_true.cast(pa.int8())) + + # Write CSV + if merged.num_rows > 0: + csv_out = output_dir / "xr-transcript-metadata.csv" + pa_csv.write_csv( + merged, + csv_out, + write_options=pa_csv.WriteOptions(quoting_style="needed"), + ) + + # Safety net: remove orphan polygons with zero transcripts + if all_geojson_features and all_tables: + csv_cell_ids: set[str] = set() + if "cell" in merged.column_names: + csv_cell_ids = set(c for c in merged.column("cell").to_pylist() if c) + all_geojson_features = [ + f + for f in all_geojson_features + if str(f.get("id", f.get("properties", {}).get("cell_id", ""))) + in csv_cell_ids + ] + + # Write merged GeoJSON + if all_geojson_features: + merged_geo = {"type": "FeatureCollection", "features": all_geojson_features} + write_geojson(merged_geo, output_dir / "xr-cell-polygons.geojson") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> None: + 
parser = argparse.ArgumentParser( + description="Stitch per-patch Baysor segmentation results into unified output." + ) + parser.add_argument( + "--patches", + type=Path, + required=True, + help="Directory containing patch subdirectories and patch_grid.json", + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output directory for stitched CSV and GeoJSON", + ) + parser.add_argument( + "--csv-filename", + default="segmentation.csv", + help="CSV filename within each patch (default: segmentation.csv)", + ) + parser.add_argument( + "--geojson-filename", + default="segmentation_polygons.json", + help="GeoJSON filename within each patch (default: segmentation_polygons.json)", + ) + args = parser.parse_args() + + stitch_transcript_assignments( + patches_dir=args.patches, + output_dir=args.output, + csv_filename=args.csv_filename, + geojson_filename=args.geojson_filename, + ) + + +if __name__ == "__main__": + main() diff --git a/modules/local/baysor/create_dataset/main.nf b/modules/local/baysor/create_dataset/main.nf index d4e0043e..98046161 100644 --- a/modules/local/baysor/create_dataset/main.nf +++ b/modules/local/baysor/create_dataset/main.nf @@ -24,7 +24,7 @@ process BAYSOR_CREATE_DATASET { prefix = task.ext.prefix ?: "${meta.id}" """ - create_dataset.py \\ + baysor_create_dataset.py \\ --transcripts ${transcripts} \\ --sample-fraction ${sample_fraction} \\ --prefix ${prefix} diff --git a/modules/local/baysor/preprocess/main.nf b/modules/local/baysor/preprocess/main.nf index 7b3c6ac8..cfe6fe3b 100644 --- a/modules/local/baysor/preprocess/main.nf +++ b/modules/local/baysor/preprocess/main.nf @@ -30,7 +30,7 @@ process BAYSOR_PREPROCESS_TRANSCRIPTS { prefix = task.ext.prefix ?: "${meta.id}" """ - preprocess_transcripts.py \\ + baysor_preprocess_transcripts.py \\ --transcripts ${transcripts} \\ --prefix ${prefix} \\ --min-qv ${min_qv} \\ diff --git a/modules/local/segger/create_dataset/main.nf b/modules/local/segger/create_dataset/main.nf index 81320eff..ce008ec3 100644 --- a/modules/local/segger/create_dataset/main.nf +++ b/modules/local/segger/create_dataset/main.nf @@ -33,7 +33,7 @@ process SEGGER_CREATE_DATASET { export NUMBA_CACHE_DIR=\$PWD/.numba_cache mkdir -p \$NUMBA_CACHE_DIR - run_create_dataset.py \\ + segger_create_dataset.py \\ --bundle-dir ${base_dir} \\ --output-dir ${prefix} \\ --sample-type ${params.format} \\ diff --git a/modules/local/segger/predict/main.nf b/modules/local/segger/predict/main.nf index 0da7a594..d4180394 100644 --- a/modules/local/segger/predict/main.nf +++ b/modules/local/segger/predict/main.nf @@ -27,7 +27,7 @@ process SEGGER_PREDICT { def args = task.ext.args ?: '' prefix = task.ext.prefix ?: "${meta.id}" """ - run_predict.py \\ + segger_predict.py \\ --models-dir ${models_dir} \\ --segger-data-dir ${segger_dataset} \\ --transcripts-file ${transcripts} \\ diff --git a/modules/local/utility/convert_mask_uint32/main.nf b/modules/local/utility/convert_mask_uint32/main.nf index 3f0333a7..40e5c35c 100644 --- a/modules/local/utility/convert_mask_uint32/main.nf +++ b/modules/local/utility/convert_mask_uint32/main.nf @@ -35,7 +35,7 @@ process CONVERT_MASK_UINT32 { script: prefix = task.ext.prefix ?: "${meta.id}" """ - convert_mask_uint32.py \\ + utility_convert_mask_uint32.py \\ --input ${mask} \\ --output ${prefix}_uint32_mask.tif """ diff --git a/modules/local/utility/downscale_morphology/main.nf b/modules/local/utility/downscale_morphology/main.nf index edaf3d67..ab4f478a 100644 --- a/modules/local/utility/downscale_morphology/main.nf 
+++ b/modules/local/utility/downscale_morphology/main.nf @@ -41,7 +41,7 @@ process DOWNSCALE_MORPHOLOGY { def diam_mean = 30 prefix = task.ext.prefix ?: "${meta.id}" """ - downscale_morphology.py \\ + utility_downscale_morphology.py \\ --image ${image} \\ --diameter ${diameter} \\ --diam-mean ${diam_mean} \\ diff --git a/modules/local/utility/extract_dapi/main.nf b/modules/local/utility/extract_dapi/main.nf index 79cce91f..ef9a88bd 100644 --- a/modules/local/utility/extract_dapi/main.nf +++ b/modules/local/utility/extract_dapi/main.nf @@ -36,7 +36,7 @@ process EXTRACT_DAPI { prefix = task.ext.prefix ?: "${meta.id}" def channel_index = task.ext.channel_index ?: 0 """ - extract_dapi.py \\ + utility_extract_dapi.py \\ --input ${image} \\ --output ${prefix}_dapi.tif \\ --channel-index ${channel_index} diff --git a/modules/local/utility/extract_preview_data/main.nf b/modules/local/utility/extract_preview_data/main.nf index fb07df29..1240ddbf 100644 --- a/modules/local/utility/extract_preview_data/main.nf +++ b/modules/local/utility/extract_preview_data/main.nf @@ -26,7 +26,7 @@ process EXTRACT_PREVIEW_DATA { prefix = task.ext.prefix ?: "${meta.id}" """ - extract_data.py \\ + utility_extract_data.py \\ --preview-html ${preview_html} \\ --prefix ${prefix} """ diff --git a/modules/local/utility/get_coordinates/main.nf b/modules/local/utility/get_coordinates/main.nf index 3fdd7862..2b672239 100644 --- a/modules/local/utility/get_coordinates/main.nf +++ b/modules/local/utility/get_coordinates/main.nf @@ -25,7 +25,7 @@ process GET_TRANSCRIPTS_COORDINATES { prefix = task.ext.prefix ?: "${meta.id}" """ - get_coordinates.py \\ + utility_get_coordinates.py \\ --transcripts ${transcripts} """ diff --git a/modules/local/utility/parquet_to_csv/main.nf b/modules/local/utility/parquet_to_csv/main.nf index 033ed00a..9c31fe41 100644 --- a/modules/local/utility/parquet_to_csv/main.nf +++ b/modules/local/utility/parquet_to_csv/main.nf @@ -25,7 +25,7 @@ process PARQUET_TO_CSV { prefix = task.ext.prefix ?: "${meta.id}" """ - parquet_to_csv.py \\ + utility_parquet_to_csv.py \\ --transcripts ${transcripts} \\ --extension ${extension} \\ --prefix ${prefix} diff --git a/modules/local/utility/resize_tif/main.nf b/modules/local/utility/resize_tif/main.nf index 6877af27..35685b7c 100644 --- a/modules/local/utility/resize_tif/main.nf +++ b/modules/local/utility/resize_tif/main.nf @@ -26,7 +26,7 @@ process RESIZE_TIF { prefix = task.ext.prefix ?: "${meta.id}" """ - resize_tif.py \\ + utility_resize_tif.py \\ --mask ${mask} \\ --transcripts ${transcripts} \\ --metadata ${metadata} \\ diff --git a/modules/local/utility/segger2xr/main.nf b/modules/local/utility/segger2xr/main.nf index 1964469a..073748d7 100644 --- a/modules/local/utility/segger2xr/main.nf +++ b/modules/local/utility/segger2xr/main.nf @@ -27,7 +27,7 @@ process SEGGER2XR { def min_transcripts = task.ext.min_transcripts_per_cell ?: 3 """ - segger2xr.py \\ + utility_segger2xr.py \\ --transcripts ${transcripts} \\ --prefix ${meta.id} \\ --min-transcripts ${min_transcripts} diff --git a/modules/local/utility/split_transcripts/main.nf b/modules/local/utility/split_transcripts/main.nf index 5cfa0b65..f7057e31 100644 --- a/modules/local/utility/split_transcripts/main.nf +++ b/modules/local/utility/split_transcripts/main.nf @@ -26,7 +26,7 @@ process SPLIT_TRANSCRIPTS { def prefix = task.ext.prefix ?: "${meta.id}" """ - split_transcripts.py \\ + utility_split_transcripts.py \\ --transcripts ${transcripts} \\ --x-bins ${x_bins} \\ --y-bins ${y_bins} \\ diff --git 
a/modules/local/utility/upscale_mask/main.nf b/modules/local/utility/upscale_mask/main.nf index a201abf1..246290fc 100644 --- a/modules/local/utility/upscale_mask/main.nf +++ b/modules/local/utility/upscale_mask/main.nf @@ -35,7 +35,7 @@ process UPSCALE_MASK { script: prefix = task.ext.prefix ?: "${meta.id}" """ - upscale_mask.py \\ + utility_upscale_mask.py \\ --mask ${mask} \\ --scale-info ${scale_info} \\ --prefix ${prefix} diff --git a/modules/local/xenium_patch/stitch/main.nf b/modules/local/xenium_patch/stitch/main.nf index d805a0f5..c674a409 100644 --- a/modules/local/xenium_patch/stitch/main.nf +++ b/modules/local/xenium_patch/stitch/main.nf @@ -35,14 +35,14 @@ process XENIUM_PATCH_STITCH { script: def args = task.ext.args ?: '' """ - stitch_transcripts.py \\ + xenium_patch_stitch_transcripts.py \\ --patches ${patches} \\ --output output \\ ${args} # Post-process: ensure all GeoJSON geometries are Polygon and # reconcile dropped cells in the transcript CSV. - stitch_postprocess.py \\ + xenium_patch_stitch_postprocess.py \\ --geojson output/xr-cell-polygons.geojson \\ --csv output/xr-transcript-metadata.csv """