Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions omicverse/bulk/_alignment/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def star_align(
# ---------- Counting via featureCounts ----------
def featurecounts(
self,
bam_triples: Sequence[Tuple[str, str | Path, Optional[str]]], # [(srr, bam, index_dir|None)]
bam_triples: Sequence[Tuple[str, str | Path, Optional[str]] | Tuple[str, str | Path, Optional[str], Optional[bool]]], # [(srr, bam, index_dir|None[, is_paired])]
*,
gtf: Optional[str | Path] = None, # Explicit GTF path takes highest priority.
simple: Optional[bool] = None, # None→cfg.featurecounts_simple
Expand All @@ -859,9 +859,10 @@ def featurecounts(
out_root.mkdir(parents=True, exist_ok=True)

# ---------- Built-in GTF inference ----------
def _infer_gtf_from_bams(triples: Sequence[Tuple[str, str | Path, Optional[str]]]) -> Optional[str]:
def _infer_gtf_from_bams(triples: Sequence[Tuple[str, str | Path, Optional[str]] | Tuple[str, str | Path, Optional[str], Optional[bool]]]) -> Optional[str]:
# 1) Prefer GTF discovery from each sample's index_dir when available.
for _srr, _bam, idx_dir in triples:
for rec in triples:
_srr, _bam, idx_dir = rec[:3]
if not idx_dir:
continue
idx = Path(idx_dir)
Expand Down Expand Up @@ -920,21 +921,39 @@ def _table_path_for(srr: str) -> Path:
return out_root / srr / f"{srr}.counts.txt"

# Idempotent shortcut: skip when every output already exists.
outs_by_srr: List[Tuple[str, Path]] = [(str(srr), _table_path_for(str(srr))) for srr, _bam, _ in bam_triples]
outs_by_srr: List[Tuple[str, Path]] = [
(str(rec[0]), _table_path_for(str(rec[0]))) for rec in bam_triples
]
if all(step["validation"]([str(p)]) for _, p in outs_by_srr):
print("[SKIP] featureCounts for all")
tables = [(srr, str(p)) for srr, p in outs_by_srr]
return {"tables": tables, "matrix": None, "failed": []}

# Assemble (srr, bam) pairs and execute.
bam_pairs = [(str(srr), str(bam)) for (srr, bam, _idx) in bam_triples]
# Assemble (srr, bam[, is_paired]) pairs and execute.
layout_hint = None if self.cfg.library_layout == "auto" else (self.cfg.library_layout == "paired")
bam_pairs = []
for rec in bam_triples:
if len(rec) == 4:
srr, bam, _idx, is_paired = rec # type: ignore[misc]
elif len(rec) == 3:
srr, bam, _idx = rec # type: ignore[misc]
is_paired = None
else:
raise ValueError(f"featurecounts expects (srr, bam, idx_dir[, is_paired]) tuples; got {rec}")

# User hint overrides auto detection.
if layout_hint is not None:
is_paired = layout_hint

bam_pairs.append((str(srr), str(bam), is_paired))

ret = step["command"](
bam_pairs,
logger=None,
gtf=str(gtf), # Explicit runtime GTF has top priority.
)

tables = [(srr, str(_table_path_for(str(srr)))) for srr, _ in bam_pairs]
tables = [(rec[0], str(_table_path_for(str(rec[0])))) for rec in bam_pairs]
matrix_path = ret.get("matrix") if isinstance(ret, dict) else None
return {"tables": tables, "matrix": matrix_path, "failed": []}

Expand Down Expand Up @@ -1340,6 +1359,9 @@ def run(self, srr_list: Sequence[str], *, with_align: bool = False, align_index:
bam_triples = self.star_align(fastqs_qc)
# Extract BAM paths from the bam_triples structure [(srr, bam_path, index_dir), ...].
bams = [(srr, Path(bam_path)) for srr, bam_path, _ in bam_triples]
# Pass paired-end hints forward to featureCounts when available.
paired_flags = {srr: bool(fq2) for srr, _c1, fq2, *_rest in fastqs_qc}
bam_triples = [(srr, bam_path, idx_dir, paired_flags.get(srr)) for srr, bam_path, idx_dir in bam_triples]
else:
# Skip the alignment step and return an empty result.
logger.info("Skipping alignment step because with_align=False")
Expand Down
6 changes: 3 additions & 3 deletions omicverse/bulk/_alignment/count_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def make_featurecounts_step(
gtf_path: str | None = None,
):
"""
Input: BAM list [(srr, bam), ...]
Input: BAM list [(srr, bam) | (srr, bam, is_paired), ...]
Output:
- Per-sample counts: work/counts/{SRR}/{SRR}.counts.txt (or .csv)
- Optional aggregate matrix: work/counts/matrix.{by}.csv
Validation: per-sample count files exist and contain rows.
"""
def _cmd(bam_pairs: Sequence[tuple[str, str]], logger=None, gtf: str | None = None):
def _cmd(bam_pairs: Sequence[tuple[str, str] | tuple[str, str, bool]], logger=None, gtf: str | None = None):
"""
bam_pairs: [(srr, bam_path), ...]
bam_pairs: [(srr, bam_path[, is_paired]), ...]
gtf: Optional runtime GTF override (takes highest priority).
"""
os.makedirs(out_root, exist_ok=True)
Expand Down
120 changes: 107 additions & 13 deletions omicverse/bulk/_alignment/count_tools.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,69 @@
# count_tools.py featureCounts batch utilities
from __future__ import annotations
import os, subprocess, sys
import os, subprocess, sys, logging
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import pandas as pd
from datetime import datetime
from typing import Optional

logger = logging.getLogger(__name__)

def _feature_counts_one_with_path(bam_path: str, out_dir: str, gtf: str, threads: int = 8, simple: bool = True, featurecounts_path: str = None):

def _infer_paired_end_from_bam(bam_path: str) -> Optional[bool]:
"""
Best-effort detection of paired-end BAMs.
- Try pysam first (fast, minimal IO).
- Fallback to `samtools view -c -f 1` when pysam is unavailable.
Returns True/False when detected, or None when detection is inconclusive.
"""
try:
import pysam # type: ignore
with pysam.AlignmentFile(bam_path, "rb") as bam:
for rec in bam.fetch(until_eof=True):
return bool(rec.is_paired)
except Exception:
pass

try:
from .tools_check import resolve_tool, merged_env
samtools = resolve_tool("samtools")
env = merged_env()
except Exception:
samtools, env = None, None

if samtools:
try:
proc = subprocess.run(
[samtools, "view", "-c", "-f", "1", bam_path],
capture_output=True,
text=True,
check=True,
env=env,
)
paired_count = int(proc.stdout.strip() or 0)
return paired_count > 0
except Exception as e: # pragma: no cover - defensive
logger.debug(f"[featureCounts] samtools paired detection failed for {bam_path}: {e}")

return None
from typing import Optional

logger = logging.getLogger(__name__)


def _feature_counts_one_with_path(
bam_path: str,
out_dir: str,
gtf: str,
threads: int = 8,
simple: bool = True,
featurecounts_path: str = None,
is_paired: Optional[bool] = None,
):
"""Helper function that accepts a pre-resolved featureCounts path"""
if featurecounts_path is None:
return _feature_counts_one(bam_path, out_dir, gtf, threads, simple)
return _feature_counts_one(bam_path, out_dir, gtf, threads, simple, is_paired=is_paired)

# Use the provided path directly
srr = Path(bam_path).stem.replace(".bam", "")
Expand All @@ -20,13 +73,20 @@ def _feature_counts_one_with_path(bam_path: str, out_dir: str, gtf: str, threads
if os.path.exists(out_file) and os.path.getsize(out_file) > 0:
return srr, out_file

if is_paired is None:
is_paired = _infer_paired_end_from_bam(bam_path)
if is_paired:
logger.info(f"[featureCounts] Detected paired-end BAM for {srr}; enabling -p.")

cmd = [
featurecounts_path,
"-T", str(threads),
"-a", gtf,
"-o", out_file,
bam_path
]
if is_paired:
cmd.extend(["-p", "-B", "-C"])
cmd.append(bam_path)

# Use proper environment
from .tools_check import merged_env
Expand Down Expand Up @@ -64,7 +124,14 @@ def _feature_counts_one_with_path(bam_path: str, out_dir: str, gtf: str, threads

return srr, out_file

def _feature_counts_one(bam_path: str, out_dir: str, gtf: str, threads: int = 8, simple: bool = True):
def _feature_counts_one(
bam_path: str,
out_dir: str,
gtf: str,
threads: int = 8,
simple: bool = True,
is_paired: Optional[bool] = None,
):
# -------------- Safety guard for missing GTF --------------
if gtf is None:
gtf = os.environ.get("FC_GTF_HINT")
Expand All @@ -82,11 +149,14 @@ def _feature_counts_one(bam_path: str, out_dir: str, gtf: str, threads: int = 8,
if os.path.exists(out_file) and os.path.getsize(out_file) > 0:
return srr, out_file

if is_paired is None:
is_paired = _infer_paired_end_from_bam(bam_path)
if is_paired:
logger.info(f"[featureCounts] Detected paired-end BAM for {srr}; enabling -p.")

# -------------- Enhanced featureCounts detection (best-effort) --------------
from .tools_check import resolve_tool, merged_env, check_featurecounts
import shutil, logging

logger = logging.getLogger(__name__)
import shutil

featurecounts_path = resolve_tool("featureCounts")
if not featurecounts_path:
Expand All @@ -113,8 +183,10 @@ def _feature_counts_one(bam_path: str, out_dir: str, gtf: str, threads: int = 8,
"-T", str(threads),
"-a", gtf,
"-o", out_file,
bam_path
]
if is_paired:
cmd.extend(["-p", "-B", "-C"])
cmd.append(bam_path)

# Use the merged environment to ensure featureCounts is discoverable.
env = merged_env()
Expand Down Expand Up @@ -155,7 +227,7 @@ def _feature_counts_one(bam_path: str, out_dir: str, gtf: str, threads: int = 8,


def feature_counts_batch(
bam_items: list[tuple[str, str]], # [(srr, bam_path)]
bam_items: list[tuple[str, str] | tuple[str, str, Optional[bool]]], # [(srr, bam_path[, is_paired])]
out_dir: str,
gtf: str | None = None,
simple: bool = True,
Expand Down Expand Up @@ -198,10 +270,11 @@ def feature_counts_batch(
"skipping counting for all BAMs. "
"You can install it with: conda install -c bioconda subread -y"
)
failed_list = [(item[0], "featureCounts not available") for item in bam_items]
return {
"tables": [],
"matrix": None,
"failed": [(srr, "featureCounts not available") for srr, _ in bam_items],
"failed": failed_list,
}
# -----------------------------------------

Expand All @@ -213,10 +286,31 @@ def feature_counts_batch(
# Ensure each worker has enough CPU resources.
max_workers = min(4, cpu_count // max(1, threads // 4))

# Normalize to (srr, bam, is_paired|None) tuples to propagate layout hints.
normalized_items: list[tuple[str, str, Optional[bool]]] = []
for item in bam_items:
if len(item) == 3:
srr, bam, is_paired = item # type: ignore[misc]
elif len(item) == 2:
srr, bam = item # type: ignore[misc]
is_paired = None
else:
raise ValueError(f"feature_counts_batch expected (srr, bam[, is_paired]) tuples, got: {item}")
normalized_items.append((str(srr), str(bam), is_paired))

with ProcessPoolExecutor(max_workers=max_workers) as ex:
futures = {
ex.submit(_feature_counts_one_with_path, bam, out_dir, gtf, threads, simple, featurecounts_path): srr
for srr, bam in bam_items
ex.submit(
_feature_counts_one_with_path,
bam,
out_dir,
gtf,
threads,
simple,
featurecounts_path,
is_paired,
): srr
for srr, bam, is_paired in normalized_items
}
for fut in as_completed(futures):
srr = futures[fut]
Expand Down
48 changes: 24 additions & 24 deletions omicverse/bulk/_alignment/data_prepare_pipline.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,31 +100,29 @@ def run_pipeline(srr_list_or_csv):

# Step 1: fasterq (batch mode).
fasterq_step = STEPS[1]
outs_by_srr = []
for srr in srr_list:
outs = _render_paths(fasterq_step["outputs"], SRR=srr)
outs_by_srr.append((srr, outs))
# Skip when everything already exists.
if all(fasterq_step["validation"](outs) for _, outs in outs_by_srr):
logger.info("[SKIP] fasterq for all")
fastq_paths = [(srr, outs[0], outs[1]) for srr, outs in outs_by_srr]
else:
ret = fasterq_step["command"]([s for s,_ in outs_by_srr], logger=logger)
# Normalize to [(srr, fq1, fq2)]; you can unpack ret["success"], here we follow the template.
fastq_paths = [(srr, outs[0], outs[1]) for srr, outs in outs_by_srr]
fq_products = fasterq_step["command"](srr_list, logger=logger)
order = {s: i for i, s in enumerate(srr_list)}
fastq_paths = []
for rec in sorted(fq_products, key=lambda x: order.get(x[0], 0)):
if not isinstance(rec, (list, tuple)) or len(rec) < 2:
raise ValueError(f"Unexpected fasterq record: {rec}")
srr, fq1 = rec[0], rec[1]
fq2 = rec[2] if len(rec) > 2 else None
fastq_paths.append((srr, str(fq1), str(fq2) if fq2 else None))

# Step 2: fastp (batch mode).
fastp_step = STEPS[2]
outs_by_srr = []
for srr, fq1, fq2 in fastq_paths:
outs = _render_paths(fastp_step["outputs"], SRR=srr)
outs_by_srr.append((srr, fq1, fq2, outs))
if all(fastp_step["validation"](o) for *_, o in outs_by_srr):
logger.info("[SKIP] fastp for all")
clean_fastqs = [(srr, o[0], o[1]) for srr, _, _, o in outs_by_srr]
else:
ret = fastp_step["command"]([(srr, fq1, fq2) for srr, fq1, fq2, _ in outs_by_srr], logger=logger)
clean_fastqs = [(srr, o[0], o[1]) for srr, _, _, o in outs_by_srr]
fp_products = fastp_step["command"](fastq_paths, logger=logger)
order_fp = {s: i for i, (s, *_rest) in enumerate(fastq_paths)}
clean_fastqs = []
for rec in sorted(fp_products, key=lambda x: order_fp.get(x[0], 0)):
if not isinstance(rec, (list, tuple)) or len(rec) < 2:
raise ValueError(f"Unexpected fastp record: {rec}")
srr, c1 = rec[0], rec[1]
c2 = rec[2] if len(rec) > 2 else None
clean_fastqs.append((srr, str(c1), str(c2) if c2 else None))

paired_flags = {srr: bool(c2) for srr, _c1, c2 in clean_fastqs}

# Step 3: STAR (per sample; parallelizable, shown sequentially for clarity).
star_step = STEPS[3]
Expand Down Expand Up @@ -219,16 +217,18 @@ def _find_gtf_from_index(index_dir: str | os.PathLike) -> str:
# Step 4: featureCounts (batch mode).
fc_step = STEPS[4]
outs_by_srr = []
fc_inputs = []
for srr, bam, _idx in bam_paths: # Note bam_paths now contains triplets.
outs = _render_paths(fc_step["outputs"], SRR=srr)
outs_by_srr.append((srr, bam, outs))

fc_inputs.append((srr, bam, paired_flags.get(srr)))

if all(fc_step["validation"](o) for _, _, o in outs_by_srr):
logger.info("[SKIP] featureCounts for all")
count_tables = [o[0] for _, _, o in outs_by_srr]
else:
# Explicitly pass gtf to the featureCounts command function ← NEW.
ret = fc_step["command"]([(srr, bam) for srr, bam, _ in outs_by_srr],
ret = fc_step["command"](fc_inputs,
logger=logger,
gtf=inferred_gtf)
count_tables = [o[0] for _, _, o in outs_by_srr]
Expand Down
Loading
Loading