diff --git a/README.md b/README.md index 513178ea..ada73bbb 100644 --- a/README.md +++ b/README.md @@ -119,4 +119,3 @@ You can leave your comments and bug reports at our [GitHub repository tracker](h --data_type (assembly|pacbio|nanopore) -o OUTPUT_FOLDER * If multiple files are provided, IsoQuant will create a single output annotation and a single set of gene/transcript expression tables. - diff --git a/docs/visualization.md b/docs/visualization.md index 5bf1cd82..550b3395 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -4,35 +4,41 @@ IsoQuant provides a visualization tool to help interpret and explore the output ## Running the visualization tool -To run the visualization tool, use the following command: +To run the visualization tool, use one of the following commands: ```bash - +# Visualize a predefined list of genes python visualize.py --gene_list [options] +# Automatically find the top N most differentially expressed genes +python visualize.py --find_genes [N] [options] ``` ## Command line options * `output_directory` (required): Directory containing IsoQuant output files. -* * `--gene_list` (required): Path to a .txt file containing a list of genes, each on its own line. -* `--viz_output`: Optional directory to save visualization output files. Defaults to the main output directory if not specified. +* `--gene_list`: Path to a .txt file containing a list of genes, each on its own line. Mutually exclusive with `--find_genes`. +* `--find_genes [N]`: Automatically select the top **N** genes with the highest combined differential-expression rank between chosen conditions (default 100 if *N* is omitted). +* `--viz_output`: Optional directory to save visualization output files. Defaults to a `visualization` subdirectory inside the output directory. +* `--gtf`: Optional path to a GTF file if it cannot be extracted from the IsoQuant log. -* `--counts`: Use counts instead of TPM files for visualization. * `--ref_only`: Use only reference transcript quantification instead of transcript model quantification. -* `--filter_transcripts`: Filter transcripts by minimum value occurring in at least one condition. +* `--filter_transcripts`: Minimum expression value a transcript must reach in at least one condition to be included in plots (default 1.0). +* `--gsea`: Perform Gene Set Enrichment Analysis on differential expression results (requires `--find_genes`). +* `--technical_replicates`: Specify technical replicate groupings as a file (`sample,group`) or inline (`sample1:group1,sample2:group1`). ## Output -The visualization tool generates the following plots based on the IsoQuant output: +The visualization tool can generate the following outputs: -1. Transcript usage profiles: For each gene specified in the gene list, a plot showing the relative usage of different transcripts across conditions or samples. +1. Transcript usage profiles: For each gene, a plot showing the relative usage of different transcripts across conditions or samples. 2. Gene-specific transcript maps: Visual representation of the different splicing patterns of transcripts for each gene, allowing easy comparison of exon usage and alternative splicing events. -3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed. +3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed (enabled interactively). + +4.
Global transcript alignment classifications: A chart representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. -4. Global transcript alignment classifications: A chart or plot representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. +5. Differential expression tables and volcano plots when `--find_genes` is used, with optional GSEA pathway visualizations if `--gsea` is supplied. -These visualizations provide valuable insights into transcript diversity, splicing patterns, and the overall quality of the IsoQuant analysis. +These visualizations and reports provide valuable insights into transcript diversity, splicing patterns, differential expression, and the overall quality of the IsoQuant analysis. diff --git a/install_r_packages.py b/install_r_packages.py new file mode 100644 index 00000000..f346bb60 --- /dev/null +++ b/install_r_packages.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 + +import rpy2.robjects.packages as rpackages +from rpy2.robjects.vectors import StrVector + +# List of R packages to install +r_package_names = ('DESeq2', 'ggplot2', 'ggrepel', 'RColorBrewer', 'clusterProfiler', 'org.Hs.eg.db') + +# Get R's utility package +utils = rpackages.importr('utils') + +# Select CRAN mirror (optional, but recommended for reproducibility) +utils.chooseCRANmirror(ind=1) # Select the first mirror in the list + +# Function to check if R package is installed +def is_installed(package_name): + return package_name in rpackages.packages() + +# Install R packages if not already installed +packages_to_install = [pkg for pkg in r_package_names if not is_installed(pkg)] + +if packages_to_install: + print(f"Installing R packages: {', '.join(packages_to_install)}") + utils.install_packages(StrVector(packages_to_install)) +else: + print("All required R packages are already installed.") \ No newline at end of file diff --git a/isoquant.py b/isoquant.py index 6aef4661..ee90bdb8 100755 --- a/isoquant.py +++ b/isoquant.py @@ -31,253 +31,537 @@ ASSEMBLY, PACBIO_CCS_DATA, NANOPORE_DATA, - DataSetReadMapper + DataSetReadMapper, ) from src.dataset_processor import DatasetProcessor, PolyAUsageStrategies from src.graph_based_model_construction import StrandnessReportingLevel from src.long_read_assigner import AmbiguityResolvingMethod -from src.long_read_counter import COUNTING_STRATEGIES, CountingStrategy, NormalizationMethod, GroupedOutputFormat +from src.long_read_counter import ( + COUNTING_STRATEGIES, + CountingStrategy, + NormalizationMethod, + GroupedOutputFormat, +) from src.input_data_storage import InputDataStorage from src.multimap_resolver import MultimapResolvingStrategy from src.stats import combine_counts -logger = logging.getLogger('IsoQuant') +logger = logging.getLogger("IsoQuant") def bool_str(s): s = s.lower() - if s not in {'false', 'true', '0', '1'}: - raise ValueError('Not a valid boolean string') - return s == 'true' or s == '1' + if s not in {"false", "true", "0", "1"}: + raise ValueError("Not a valid boolean string") + return s == "true" or s == "1" def parse_args(cmd_args=None, namespace=None): - parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter) - ref_args_group = parser.add_argument_group('Reference data') - input_args_group = parser.add_argument_group('Input data') - output_args_group = parser.add_argument_group('Output naming') 
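For the `visualize.py` options documented in `docs/visualization.md` above, a combined invocation may help make the new flags concrete. The command below is a hypothetical sketch, not taken from the repository: the output path, gene count, threshold, and sample/group names are placeholders, and it assumes `output_directory` is the positional argument listed first among the options.

```bash
# Hypothetical example: rank the top 50 differentially expressed genes,
# run GSEA on the result, keep only transcripts reaching an expression
# value of at least 2 in some condition, and declare technical replicates inline.
python visualize.py isoquant_output \
    --find_genes 50 \
    --gsea \
    --technical_replicates sampleA_rep1:groupA,sampleA_rep2:groupA,sampleB_rep1:groupB \
    --filter_transcripts 2 \
    --viz_output isoquant_output/visualization
```

`--gene_list` is deliberately omitted here, since the documentation marks it as mutually exclusive with `--find_genes`.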
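A note on the `install_r_packages.py` helper added above: its `is_installed()` function relies on `rpackages.packages()`, which does not appear in rpy2's documented API, whereas `rpy2.robjects.packages.isinstalled()` is the documented way to test whether an R package is present. The sketch below shows the same install-if-missing logic built on that call; it is an illustration under that assumption, not the PR's code, and mirrors the package list above.

```python
#!/usr/bin/env python3
# Sketch: install-if-missing via rpy2's documented isinstalled() helper.
import rpy2.robjects.packages as rpackages
from rpy2.robjects.vectors import StrVector

# Same package list as install_r_packages.py above.
r_package_names = ('DESeq2', 'ggplot2', 'ggrepel', 'RColorBrewer',
                   'clusterProfiler', 'org.Hs.eg.db')

utils = rpackages.importr('utils')
utils.chooseCRANmirror(ind=1)  # pin the first CRAN mirror for reproducibility

# isinstalled() asks R whether a library of that name is already available.
missing = [pkg for pkg in r_package_names if not rpackages.isinstalled(pkg)]

if missing:
    print(f"Installing R packages: {', '.join(missing)}")
    # Caveat: DESeq2, clusterProfiler and org.Hs.eg.db are Bioconductor
    # packages; install.packages() against a CRAN mirror may not find them,
    # so BiocManager::install() is the usual route for those three.
    utils.install_packages(StrVector(missing))
else:
    print("All required R packages are already installed.")
```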
- pipeline_args_group = parser.add_argument_group('Pipeline options') - algo_args_group = parser.add_argument_group('Algorithm settings') + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter + ) + ref_args_group = parser.add_argument_group("Reference data") + input_args_group = parser.add_argument_group("Input data") + output_args_group = parser.add_argument_group("Output naming") + pipeline_args_group = parser.add_argument_group("Pipeline options") + algo_args_group = parser.add_argument_group("Algorithm settings") other_options = parser.add_argument_group("Additional options:") - show_full_help = '--full_help' in cmd_args + show_full_help = "--full_help" in cmd_args def add_additional_option(*args, **kwargs): # show command only with --full-help if not show_full_help: - kwargs['help'] = argparse.SUPPRESS + kwargs["help"] = argparse.SUPPRESS other_options.add_argument(*args, **kwargs) - def add_additional_option_to_group(opt_group, *args, **kwargs): # show command only with --full-help + def add_additional_option_to_group( + opt_group, *args, **kwargs + ): # show command only with --full-help if not show_full_help: - kwargs['help'] = argparse.SUPPRESS + kwargs["help"] = argparse.SUPPRESS opt_group.add_argument(*args, **kwargs) def add_hidden_option(*args, **kwargs): # show command only with --full-help - kwargs['help'] = argparse.SUPPRESS + kwargs["help"] = argparse.SUPPRESS parser.add_argument(*args, **kwargs) - parser.add_argument("--full_help", action='help', help="show full list of options") - add_hidden_option('--debug', action='store_true', default=False, - help='Debug log output.') - - output_args_group.add_argument("--output", "-o", help="output folder, will be created automatically " - "[default=isoquant_output]", - type=str, default="isoquant_output") - output_args_group.add_argument('--prefix', '-p', type=str, - help='experiment name; to be used for folder and file naming; default is OUT', - default="OUT") - output_args_group.add_argument('--labels', '-l', nargs='+', type=str, - help='sample/replica labels to be used as column names; input file names are used ' - 'if not set; must be equal to the number of input files given via --fastq/--bam') + parser.add_argument("--full_help", action="help", help="show full list of options") + add_hidden_option( + "--debug", action="store_true", default=False, help="Debug log output." + ) + + output_args_group.add_argument( + "--output", + "-o", + help="output folder, will be created automatically " + "[default=isoquant_output]", + type=str, + default="isoquant_output", + ) + output_args_group.add_argument( + "--prefix", + "-p", + type=str, + help="experiment name; to be used for folder and file naming; default is OUT", + default="OUT", + ) + output_args_group.add_argument( + "--labels", + "-l", + nargs="+", + type=str, + help="sample/replica labels to be used as column names; input file names are used " + "if not set; must be equal to the number of input files given via --fastq/--bam", + ) # REFERENCE - ref_args_group.add_argument("--reference", "-r", help="reference genome in FASTA format (can be gzipped)", - type=str) - ref_args_group.add_argument("--genedb", "-g", help="gene database in gffutils DB format or GTF/GFF " - "format (optional)", type=str) - ref_args_group.add_argument('--complete_genedb', action='store_true', default=False, - help="use this flag if gene annotation contains transcript and gene metafeatures, " - "e.g. 
with official annotations, such as GENCODE; " - "speeds up gene database conversion") - add_additional_option_to_group(ref_args_group, "--index", help="genome index for specified aligner (optional)", - type=str) + ref_args_group.add_argument( + "--reference", + "-r", + help="reference genome in FASTA format (can be gzipped)", + type=str, + ) + ref_args_group.add_argument( + "--genedb", + "-g", + help="gene database in gffutils DB format or GTF/GFF " "format (optional)", + type=str, + ) + ref_args_group.add_argument( + "--complete_genedb", + action="store_true", + default=False, + help="use this flag if gene annotation contains transcript and gene metafeatures, " + "e.g. with official annotations, such as GENCODE; " + "speeds up gene database conversion", + ) + add_additional_option_to_group( + ref_args_group, + "--index", + help="genome index for specified aligner (optional)", + type=str, + ) # INPUT READS input_args = input_args_group.add_mutually_exclusive_group() - input_args.add_argument('--bam', nargs='+', type=str, - help='sorted and indexed BAM file(s), each file will be treated as a separate sample') - input_args.add_argument('--fastq', nargs='+', type=str, - help='input FASTQ file(s), each file will be treated as a separate sample; ' - 'reference genome should be provided when using reads as input') - add_additional_option_to_group(input_args,'--bam_list', type=str, - help='text file with list of BAM files, one file per line, ' - 'leave empty line between samples') - add_additional_option_to_group(input_args,'--fastq_list', type=str, - help='text file with list of FASTQ files, one file per line, ' - 'leave empty line between samples') - input_args.add_argument('--yaml', type=str, help='yaml file containing all input files, one entry per sample' - ', check readme for format info') - - input_args_group.add_argument('--illumina_bam', nargs='+', type=str, - help='sorted and indexed file(s) with Illumina reads from the same sample') - - input_args_group.add_argument("--read_group", help="a way to group feature counts (no grouping by default): " - "by BAM file tag (tag:TAG); " - "using additional file (file:FILE:READ_COL:GROUP_COL:DELIM); " - "using read id (read_id:DELIM); " - "by original file name (file_name)", type=str) + input_args.add_argument( + "--bam", + nargs="+", + type=str, + help="sorted and indexed BAM file(s), each file will be treated as a separate sample", + ) + input_args.add_argument( + "--fastq", + nargs="+", + type=str, + help="input FASTQ file(s), each file will be treated as a separate sample; " + "reference genome should be provided when using reads as input", + ) + add_additional_option_to_group( + input_args, + "--bam_list", + type=str, + help="text file with list of BAM files, one file per line, " + "leave empty line between samples", + ) + add_additional_option_to_group( + input_args, + "--fastq_list", + type=str, + help="text file with list of FASTQ files, one file per line, " + "leave empty line between samples", + ) + input_args.add_argument( + "--yaml", + type=str, + help="yaml file containing all input files, one entry per sample" + ", check readme for format info", + ) + + input_args_group.add_argument( + "--illumina_bam", + nargs="+", + type=str, + help="sorted and indexed file(s) with Illumina reads from the same sample", + ) + + input_args_group.add_argument( + "--read_group", + help="a way to group feature counts (no grouping by default): " + "by BAM file tag (tag:TAG); " + "using additional file (file:FILE:READ_COL:GROUP_COL:DELIM); " + "using read 
id (read_id:DELIM); " + "by original file name (file_name)", + type=str, + ) # INPUT PROPERTIES - input_args_group.add_argument("--data_type", "-d", type=str, choices=DATA_TYPE_ALIASES.keys(), - help="type of data to process, supported types are: " + ", ".join(DATA_TYPE_ALIASES.keys())) - input_args_group.add_argument('--stranded', type=str, help="reads strandness type, supported values are: " + - ", ".join(SUPPORTED_STRANDEDNESS), default="none") - input_args_group.add_argument('--fl_data', action='store_true', default=False, - help="reads represent FL transcripts; both ends of the read are considered to be reliable") + input_args_group.add_argument( + "--data_type", + "-d", + type=str, + choices=DATA_TYPE_ALIASES.keys(), + help="type of data to process, supported types are: " + + ", ".join(DATA_TYPE_ALIASES.keys()), + ) + input_args_group.add_argument( + "--stranded", + type=str, + help="reads strandness type, supported values are: " + + ", ".join(SUPPORTED_STRANDEDNESS), + default="none", + ) + input_args_group.add_argument( + "--fl_data", + action="store_true", + default=False, + help="reads represent FL transcripts; both ends of the read are considered to be reliable", + ) # ALGORITHM - add_additional_option_to_group(algo_args_group, "--report_novel_unspliced", "-u", type=bool_str, - help="report novel monoexonic transcripts (true/false), " - "default: false for ONT, true for other data types") - add_additional_option_to_group(algo_args_group, "--report_canonical", type=str, - choices=[e.name for e in StrandnessReportingLevel], - help="reporting level for novel transcripts based on canonical splice sites;" - " default: " + StrandnessReportingLevel.auto.name, - default=StrandnessReportingLevel.only_stranded.name) - add_additional_option_to_group(algo_args_group, "--polya_requirement", type=str, - choices=[e.name for e in PolyAUsageStrategies], - help="require polyA tails to be present when reporting transcripts; " - "default: auto (requires polyA only when polyA percentage is >= 70%%)", - default=PolyAUsageStrategies.auto.name) - - add_additional_option_to_group(algo_args_group, "--transcript_quantification", choices=COUNTING_STRATEGIES, - help="transcript quantification strategy", type=str, - default=CountingStrategy.unique_only.name) - add_additional_option_to_group(algo_args_group, "--gene_quantification", choices=COUNTING_STRATEGIES, - help="gene quantification strategy", type=str, - default=CountingStrategy.unique_splicing_consistent.name) - - add_additional_option_to_group(algo_args_group, "--matching_strategy", - choices=["exact", "precise", "default", "loose"], - help="read-to-isoform matching strategy from the most strict to least", - type=str, default=None) - add_additional_option_to_group(algo_args_group, "--splice_correction_strategy", - choices=["none", "default_pacbio", "default_ont", - "conservative_ont", "all", "assembly"], - help="read alignment correction strategy to use", type=str, default=None) - add_additional_option_to_group(algo_args_group, "--model_construction_strategy", - choices=["reliable", "default_pacbio", "sensitive_pacbio", "fl_pacbio", - "default_ont", "sensitive_ont", "all", "assembly"], - help="transcript model construction strategy to use", type=str, default=None) + add_additional_option_to_group( + algo_args_group, + "--report_novel_unspliced", + "-u", + type=bool_str, + help="report novel monoexonic transcripts (true/false), " + "default: false for ONT, true for other data types", + ) + add_additional_option_to_group( + algo_args_group, + 
"--report_canonical", + type=str, + choices=[e.name for e in StrandnessReportingLevel], + help="reporting level for novel transcripts based on canonical splice sites;" + " default: " + StrandnessReportingLevel.auto.name, + default=StrandnessReportingLevel.only_stranded.name, + ) + add_additional_option_to_group( + algo_args_group, + "--polya_requirement", + type=str, + choices=[e.name for e in PolyAUsageStrategies], + help="require polyA tails to be present when reporting transcripts; " + "default: auto (requires polyA only when polyA percentage is >= 70%%)", + default=PolyAUsageStrategies.auto.name, + ) + + add_additional_option_to_group( + algo_args_group, + "--transcript_quantification", + choices=COUNTING_STRATEGIES, + help="transcript quantification strategy", + type=str, + default=CountingStrategy.unique_only.name, + ) + add_additional_option_to_group( + algo_args_group, + "--gene_quantification", + choices=COUNTING_STRATEGIES, + help="gene quantification strategy", + type=str, + default=CountingStrategy.unique_splicing_consistent.name, + ) + + add_additional_option_to_group( + algo_args_group, + "--matching_strategy", + choices=["exact", "precise", "default", "loose"], + help="read-to-isoform matching strategy from the most strict to least", + type=str, + default=None, + ) + add_additional_option_to_group( + algo_args_group, + "--splice_correction_strategy", + choices=[ + "none", + "default_pacbio", + "default_ont", + "conservative_ont", + "all", + "assembly", + ], + help="read alignment correction strategy to use", + type=str, + default=None, + ) + add_additional_option_to_group( + algo_args_group, + "--model_construction_strategy", + choices=[ + "reliable", + "default_pacbio", + "sensitive_pacbio", + "fl_pacbio", + "default_ont", + "sensitive_ont", + "all", + "assembly", + ], + help="transcript model construction strategy to use", + type=str, + default=None, + ) # OUTPUT PROPERTIES - pipeline_args_group.add_argument("--threads", "-t", help="number of threads to use", type=int, - default="16") - pipeline_args_group.add_argument('--check_canonical', action='store_true', default=False, - help="report whether splice junctions are canonical") - pipeline_args_group.add_argument("--sqanti_output", help="produce SQANTI-like TSV output", - action='store_true', default=False) - pipeline_args_group.add_argument("--count_exons", help="perform exon and intron counting", - action='store_true', default=False) - add_additional_option_to_group(pipeline_args_group,"--bam_tags", - help="comma separated list of BAM tags to be imported to read_assignments.tsv", - type=str) + pipeline_args_group.add_argument( + "--threads", "-t", help="number of threads to use", type=int, default="16" + ) + pipeline_args_group.add_argument( + "--check_canonical", + action="store_true", + default=False, + help="report whether splice junctions are canonical", + ) + pipeline_args_group.add_argument( + "--sqanti_output", + help="produce SQANTI-like TSV output", + action="store_true", + default=False, + ) + pipeline_args_group.add_argument( + "--count_exons", + help="perform exon and intron counting", + action="store_true", + default=False, + ) + add_additional_option_to_group( + pipeline_args_group, + "--bam_tags", + help="comma separated list of BAM tags to be imported to read_assignments.tsv", + type=str, + ) # PIPELINE STEPS resume_args = pipeline_args_group.add_mutually_exclusive_group() - resume_args.add_argument("--resume", action="store_true", default=False, - help="resume failed run, specify output folder, input 
options are not allowed") - resume_args.add_argument("--force", action="store_true", default=False, - help="force to overwrite the previous run") - add_additional_option_to_group(pipeline_args_group, '--clean_start', action='store_true', default=False, - help='Do not use previously generated index, feature db or alignments.') - - add_additional_option_to_group(pipeline_args_group, "--no_model_construction", action="store_true", - default=False, help="run only read assignment and quantification") - add_additional_option_to_group(pipeline_args_group, "--run_aligner_only", action="store_true", default=False, - help="align reads to reference without running further analysis") + resume_args.add_argument( + "--resume", + action="store_true", + default=False, + help="resume failed run, specify output folder, input options are not allowed", + ) + resume_args.add_argument( + "--force", + action="store_true", + default=False, + help="force to overwrite the previous run", + ) + add_additional_option_to_group( + pipeline_args_group, + "--clean_start", + action="store_true", + default=False, + help="Do not use previously generated index, feature db or alignments.", + ) + + add_additional_option_to_group( + pipeline_args_group, + "--no_model_construction", + action="store_true", + default=False, + help="run only read assignment and quantification", + ) + add_additional_option_to_group( + pipeline_args_group, + "--run_aligner_only", + action="store_true", + default=False, + help="align reads to reference without running further analysis", + ) # ADDITIONAL - add_additional_option("--delta", type=int, default=None, - help="delta for inexact splice junction comparison, chosen automatically based on data type") - add_hidden_option("--graph_clustering_distance", type=int, default=None, - help="intron graph clustering distance, " - "splice junctions less that this number of bp apart will not be differentiated") - add_additional_option("--no_gzip", help="do not gzip large output files", dest="gzipped", - action='store_false', default=True) - add_additional_option("--no_gtf_check", help="do not perform GTF checks", dest="gtf_check", - action='store_false', default=True) - add_additional_option("--high_memory", help="increase RAM consumption (store alignment and the genome in RAM)", - action='store_true', default=False) - add_additional_option("--no_junc_bed", action="store_true", default=False, - help="do NOT use annotation for read mapping") - add_additional_option("--junc_bed_file", type=str, - help="annotation in BED format produced by minimap's paftools.js gff2bed " - "(will be created automatically if not given)") - add_additional_option("--no_secondary", help="ignore secondary alignments (not recommended)", action='store_true', - default=False) - add_additional_option("--min_mapq", help="ignore alignments with MAPQ < this" - "(also filters out secondary alignments, default: None)", type=int) - add_additional_option("--inconsistent_mapq_cutoff", help="ignore inconsistent alignments with MAPQ < this " - "(works only with the reference annotation, default=5)", - type=int, default=5) - add_additional_option("--simple_alignments_mapq_cutoff", help="ignore alignments with 1 or 2 exons and " - "MAPQ < this (works only in annotation-free mode, " - "default=1)", type=int, default=1) - add_additional_option("--normalization_method", type=str, choices=[e.name for e in NormalizationMethod], - help="TPM normalization method: simple - conventional normalization using all counted reads;" - "usable_reads - includes all 
assigned reads.", - default=NormalizationMethod.simple.name) - add_additional_option("--counts_format", type=str, choices=[e.name for e in GroupedOutputFormat], - help="output format for grouped counts", - default=GroupedOutputFormat.both.name) - - add_additional_option_to_group(pipeline_args_group, "--keep_tmp", help="do not remove temporary files " - "in the end", action='store_true', - default=False) - add_additional_option_to_group(input_args_group, "--read_assignments", nargs='+', type=str, - help="reuse read assignments (binary format)", default=None) - add_hidden_option("--aligner", help="force to use this alignment method, can be " + ", ".join(SUPPORTED_ALIGNERS) - + "; chosen based on data type if not set", type=str) - add_additional_option_to_group(output_args_group, "--genedb_output", help="output folder for converted gene " - "database, will be created automatically " - " (same as output by default)", type=str) + add_additional_option( + "--delta", + type=int, + default=None, + help="delta for inexact splice junction comparison, chosen automatically based on data type", + ) + add_hidden_option( + "--graph_clustering_distance", + type=int, + default=None, + help="intron graph clustering distance, " + "splice junctions less that this number of bp apart will not be differentiated", + ) + add_additional_option( + "--no_gzip", + help="do not gzip large output files", + dest="gzipped", + action="store_false", + default=True, + ) + add_additional_option( + "--no_gtf_check", + help="do not perform GTF checks", + dest="gtf_check", + action="store_false", + default=True, + ) + add_additional_option( + "--high_memory", + help="increase RAM consumption (store alignment and the genome in RAM)", + action="store_true", + default=False, + ) + add_additional_option( + "--no_junc_bed", + action="store_true", + default=False, + help="do NOT use annotation for read mapping", + ) + add_additional_option( + "--junc_bed_file", + type=str, + help="annotation in BED format produced by minimap's paftools.js gff2bed " + "(will be created automatically if not given)", + ) + add_additional_option( + "--no_secondary", + help="ignore secondary alignments (not recommended)", + action="store_true", + default=False, + ) + add_additional_option( + "--min_mapq", + help="ignore alignments with MAPQ < this" + "(also filters out secondary alignments, default: None)", + type=int, + ) + add_additional_option( + "--inconsistent_mapq_cutoff", + help="ignore inconsistent alignments with MAPQ < this " + "(works only with the reference annotation, default=5)", + type=int, + default=5, + ) + add_additional_option( + "--simple_alignments_mapq_cutoff", + help="ignore alignments with 1 or 2 exons and " + "MAPQ < this (works only in annotation-free mode, " + "default=1)", + type=int, + default=1, + ) + add_additional_option( + "--normalization_method", + type=str, + choices=[e.name for e in NormalizationMethod], + help="TPM normalization method: simple - conventional normalization using all counted reads;" + "usable_reads - includes all assigned reads.", + default=NormalizationMethod.simple.name, + ) + add_additional_option( + "--counts_format", + type=str, + choices=[e.name for e in GroupedOutputFormat], + help="output format for grouped counts", + default=GroupedOutputFormat.both.name, + ) + + add_additional_option_to_group( + pipeline_args_group, + "--keep_tmp", + help="do not remove temporary files " "in the end", + action="store_true", + default=False, + ) + add_additional_option_to_group( + input_args_group, + 
"--read_assignments", + nargs="+", + type=str, + help="reuse read assignments (binary format)", + default=None, + ) + add_hidden_option( + "--aligner", + help="force to use this alignment method, can be " + + ", ".join(SUPPORTED_ALIGNERS) + + "; chosen based on data type if not set", + type=str, + ) + add_additional_option_to_group( + output_args_group, + "--genedb_output", + help="output folder for converted gene " + "database, will be created automatically " + " (same as output by default)", + type=str, + ) add_hidden_option("--cage", help="bed file with CAGE peaks", type=str, default=None) - add_hidden_option("--cage-shift", type=int, default=50, help="interval before read start to look for CAGE peak") - parser.add_argument("--test", action=TestMode, nargs=0, help="run IsoQuant on toy dataset") + add_hidden_option( + "--cage-shift", + type=int, + default=50, + help="interval before read start to look for CAGE peak", + ) + parser.add_argument( + "--test", action=TestMode, nargs=0, help="run IsoQuant on toy dataset" + ) isoquant_version = "3.4.0" try: - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION")) as version_f: + with open( + os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION") + ) as version_f: isoquant_version = version_f.readline().strip() except FileNotFoundError: pass - parser.add_argument('--version', '-v', action='version', version='IsoQuant ' + isoquant_version) + parser.add_argument( + "--version", "-v", action="version", version="IsoQuant " + isoquant_version + ) args = parser.parse_args(cmd_args, namespace) if args.resume: resume_parser = argparse.ArgumentParser(add_help=False) - resume_parser.add_argument("--resume", action="store_true", default=False, - help="resume failed run, specify only output folder, " - "input options are not allowed") - resume_parser.add_argument("--output", "-o", - help="output folder, will be created automatically [default=isoquant_output]", - type=str, required=True) - resume_parser.add_argument('--debug', action='store_true', default=argparse.SUPPRESS, - help='Debug log output.') - resume_parser.add_argument("--threads", "-t", help="number of threads to use", - type=int, default=argparse.SUPPRESS) - resume_parser.add_argument("--high_memory", - help="increase RAM consumption (store alignment and the genome in RAM)", - action='store_true', default=False) - resume_parser.add_argument("--keep_tmp", help="do not remove temporary files in the end", - action='store_true', default=argparse.SUPPRESS) + resume_parser.add_argument( + "--resume", + action="store_true", + default=False, + help="resume failed run, specify only output folder, " + "input options are not allowed", + ) + resume_parser.add_argument( + "--output", + "-o", + help="output folder, will be created automatically [default=isoquant_output]", + type=str, + required=True, + ) + resume_parser.add_argument( + "--debug", + action="store_true", + default=argparse.SUPPRESS, + help="Debug log output.", + ) + resume_parser.add_argument( + "--threads", + "-t", + help="number of threads to use", + type=int, + default=argparse.SUPPRESS, + ) + resume_parser.add_argument( + "--high_memory", + help="increase RAM consumption (store alignment and the genome in RAM)", + action="store_true", + default=False, + ) + resume_parser.add_argument( + "--keep_tmp", + help="do not remove temporary files in the end", + action="store_true", + default=argparse.SUPPRESS, + ) args, unknown_args = resume_parser.parse_known_args(cmd_args) if unknown_args: - logger.error("You 
cannot specify options other than --output/--threads/--debug/--high_memory " - "with --resume option") + logger.error( + "You cannot specify options other than --output/--threads/--debug/--high_memory " + "with --resume option" + ) parser.print_usage() exit(-2) @@ -296,27 +580,38 @@ def check_and_load_args(args, parser): if args.resume: if not os.path.exists(args.output) or not os.path.exists(args.param_file): # logger is not defined yet - logger.error("Previous run config was not detected, cannot resume. " - "Check that output folder is correctly specified.") + logger.error( + "Previous run config was not detected, cannot resume. " + "Check that output folder is correctly specified." + ) exit(-3) args = load_previous_run(args) elif args.output_exists: if os.path.exists(args.param_file): if args.force: - logger.warning("Output folder already contains a previous run, will be overwritten.") + logger.warning( + "Output folder already contains a previous run, will be overwritten." + ) else: - logger.warning("Output folder already contains a previous run, some files may be overwritten. " - "Use --resume to resume a failed run. Use --force to avoid this message.") + logger.warning( + "Output folder already contains a previous run, some files may be overwritten. " + "Use --resume to resume a failed run. Use --force to avoid this message." + ) logger.warning("Press Ctrl+C to interrupt the run now.") delay = 9 for i in range(delay): countdown = delay - i - sys.stdout.write("Resuming the run in %d second%s\r" % (countdown, "s" if countdown > 1 else "")) + sys.stdout.write( + "Resuming the run in %d second%s\r" + % (countdown, "s" if countdown > 1 else "") + ) time.sleep(1) logger.info("Overwriting the previous run") time.sleep(1) else: - logger.warning("Output folder already exists, some files may be overwritten.") + logger.warning( + "Output folder already exists, some files may be overwritten." 
+ ) if args.genedb_output is None: args.genedb_output = args.output @@ -327,7 +622,15 @@ def check_and_load_args(args, parser): elif args.genedb.lower().endswith("db"): args.genedb_filename = args.genedb else: - args.genedb_filename = os.path.join(args.output, os.path.splitext(os.path.basename(args.genedb))[0] + ".db") + args.genedb_filename = os.path.join( + args.output, os.path.splitext(os.path.basename(args.genedb))[0] + ".db" + ) + if args.genedb.lower().endswith("db"): + args.genedb_filename = args.genedb + else: + args.genedb_filename = os.path.join( + args.output, os.path.splitext(os.path.basename(args.genedb))[0] + ".db" + ) if not check_input_params(args): parser.print_usage() @@ -353,21 +656,34 @@ def load_previous_run(args): def save_params(args): - for file_opt in ["genedb", "reference", "index", "bam", "fastq", "bam_list", "fastq_list", "junc_bed_file", - "cage", "genedb_output", "read_assignments"]: + for file_opt in [ + "genedb", + "reference", + "index", + "bam", + "fastq", + "bam_list", + "fastq_list", + "junc_bed_file", + "cage", + "genedb_output", + "read_assignments", + ]: if file_opt in args.__dict__ and args.__dict__[file_opt]: if isinstance(args.__dict__[file_opt], list): - args.__dict__[file_opt] = list(map(os.path.abspath, args.__dict__[file_opt])) + args.__dict__[file_opt] = list( + map(os.path.abspath, args.__dict__[file_opt]) + ) else: args.__dict__[file_opt] = os.path.abspath(args.__dict__[file_opt]) if "read_group" in args.__dict__ and args.__dict__["read_group"]: vals = args.read_group.split(":") - if len(vals) > 1 and vals[0] == 'file': + if len(vals) > 1 and vals[0] == "file": vals[1] = os.path.abspath(vals[1]) args.read_group = ":".join(vals) - pickler = pickle.Pickler(open(args.param_file, "wb"), -1) + pickler = pickle.Pickler(open(args.param_file, "wb"), -1) pickler.dump(args) pass @@ -378,36 +694,65 @@ def check_input_params(args): logger.error("Reference genome was not provided") return False if not args.data_type: - logger.error("Data type is not provided, choose one of " + " ".join(DATA_TYPE_ALIASES.keys())) + logger.error( + "Data type is not provided, choose one of " + + " ".join(DATA_TYPE_ALIASES.keys()) + ) return False elif args.data_type not in DATA_TYPE_ALIASES.keys(): - logger.error("Unsupported data type " + args.data_type + ", choose one of: " + " ".join(DATA_TYPE_ALIASES.keys())) + logger.error( + "Unsupported data type " + + args.data_type + + ", choose one of: " + + " ".join(DATA_TYPE_ALIASES.keys()) + ) return False args.data_type = DATA_TYPE_ALIASES[args.data_type] - if not args.fastq and not args.fastq_list and not args.bam and not args.bam_list and not args.read_assignments and not args.yaml: + if ( + not args.fastq + and not args.fastq_list + and not args.bam + and not args.bam_list + and not args.read_assignments + and not args.yaml + ): logger.error("No input data was provided") return False - + if args.yaml and args.illumina_bam: - logger.error("When providing a yaml file it should include all input files, including the illumina bam file.") + logger.error( + "When providing a yaml file it should include all input files, including the illumina bam file." + ) return False - + if args.illumina_bam and (args.fastq_list or args.bam_list): - logger.error("Unsupported combination of list of input files and Illumina bam file." - "To combine multiple experiments with short read correction please use yaml input.") + logger.error( + "Unsupported combination of list of input files and Illumina bam file." 
+ "To combine multiple experiments with short read correction please use yaml input." + ) return False args.input_data = InputDataStorage(args) if args.aligner is not None and args.aligner not in SUPPORTED_ALIGNERS: - logger.error(" Unsupported aligner " + args.aligner + ", choose one of: " + " ".join(SUPPORTED_ALIGNERS)) + logger.error( + " Unsupported aligner " + + args.aligner + + ", choose one of: " + + " ".join(SUPPORTED_ALIGNERS) + ) return False if args.run_aligner_only and args.input_data.input_type == "bam": logger.error("Do not use BAM files with --run_aligner_only option.") return False if args.stranded not in SUPPORTED_STRANDEDNESS: - logger.error("Unsupported strandness " + args.stranded + ", choose one of: " + " ".join(SUPPORTED_STRANDEDNESS)) + logger.error( + "Unsupported strandness " + + args.stranded + + ", choose one of: " + + " ".join(SUPPORTED_STRANDEDNESS) + ) return False if not args.genedb: @@ -415,15 +760,21 @@ def check_input_params(args): logger.warning("--count_exons option has no effect without gene annotation") if args.sqanti_output: args.sqanti_output = False - logger.warning("--sqanti_output option has no effect without gene annotation") + logger.warning( + "--sqanti_output option has no effect without gene annotation" + ) if args.no_model_construction: - logger.warning("Setting --no_model_construction without providing a gene " - "annotation will not produce any meaningful results") + logger.warning( + "Setting --no_model_construction without providing a gene " + "annotation will not produce any meaningful results" + ) if args.no_model_construction and args.sqanti_output: args.sqanti_output = False - logger.warning("--sqanti_output option has no effect without model construction") - + logger.warning( + "--sqanti_output option has no effect without model construction" + ) + check_input_files(args) return True @@ -443,14 +794,22 @@ def check_input_files(args): if args.input_data.input_type == "bam": bamfile_in = pysam.AlignmentFile(in_file, "rb") if not bamfile_in.has_index(): - logger.critical("BAM file " + in_file + " is not indexed, run samtools sort and samtools index") + logger.critical( + "BAM file " + + in_file + + " is not indexed, run samtools sort and samtools index" + ) exit(-1) bamfile_in.close() if sample.illumina_bam is not None: for illumina in sample.illumina_bam: bamfile_in = pysam.AlignmentFile(illumina, "rb") if not bamfile_in.has_index(): - logger.critical("BAM file " + illumina + " is not indexed, run samtools sort and samtools index") + logger.critical( + "BAM file " + + illumina + + " is not indexed, run samtools sort and samtools index" + ) exit(-1) bamfile_in.close() @@ -480,13 +839,18 @@ def create_output_dirs(args): sample_dir = sample.out_dir if os.path.exists(sample_dir): if not args.resume: - logger.warning(sample_dir + " folder already exists, some files may be overwritten") + logger.warning( + sample_dir + " folder already exists, some files may be overwritten" + ) else: os.makedirs(sample_dir) sample_aux_dir = sample.aux_dir if os.path.exists(sample_aux_dir): if not args.resume: - logger.warning(sample_aux_dir + " folder already exists, some files may be overwritten") + logger.warning( + sample_aux_dir + + " folder already exists, some files may be overwritten" + ) else: os.makedirs(sample_aux_dir) @@ -506,7 +870,7 @@ def set_logger(args, logger_instance): shutil.copyfileobj(open(log_file, "r"), olf) f = open(log_file, "w") - f.write("Command line: " + args._cmd_line + '\n') + f.write("Command line: " + args._cmd_line + "\n") 
f.close() fh = logging.FileHandler(log_file) fh.set_name("isoquant_file_log") @@ -515,7 +879,7 @@ def set_logger(args, logger_instance): ch.set_name("isoquant_screen_log") ch.setLevel(logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") fh.setFormatter(formatter) ch.setFormatter(formatter) if all(fh.get_name() != h.get_name() for h in logger_instance.handlers): @@ -527,21 +891,33 @@ def set_logger(args, logger_instance): def set_data_dependent_options(args): - matching_strategies = {ASSEMBLY: "precise", PACBIO_CCS_DATA: "precise", NANOPORE_DATA: "default"} + matching_strategies = { + ASSEMBLY: "precise", + PACBIO_CCS_DATA: "precise", + NANOPORE_DATA: "default", + } if args.matching_strategy is None: args.matching_strategy = matching_strategies[args.data_type] - model_construction_strategies = {ASSEMBLY: "assembly", PACBIO_CCS_DATA: "default_pacbio", NANOPORE_DATA: "default_ont"} + model_construction_strategies = { + ASSEMBLY: "assembly", + PACBIO_CCS_DATA: "default_pacbio", + NANOPORE_DATA: "default_ont", + } if args.model_construction_strategy is None: args.model_construction_strategy = model_construction_strategies[args.data_type] if args.fl_data and args.model_construction_strategy == "default_pacbio": args.model_construction_strategy = "fl_pacbio" - splice_correction_strategies = {ASSEMBLY: "assembly", PACBIO_CCS_DATA: "default_pacbio", NANOPORE_DATA: "default_ont"} + splice_correction_strategies = { + ASSEMBLY: "assembly", + PACBIO_CCS_DATA: "default_pacbio", + NANOPORE_DATA: "default_ont", + } if args.splice_correction_strategy is None: args.splice_correction_strategy = splice_correction_strategies[args.data_type] - args.resolve_ambiguous = 'monoexon_and_fsm' if args.fl_data else 'default' + args.resolve_ambiguous = "monoexon_and_fsm" if args.fl_data else "default" args.requires_polya_for_construction = False if args.read_group is None and args.input_data.has_replicas(): args.read_group = "file_name" @@ -549,16 +925,25 @@ def set_data_dependent_options(args): def set_matching_options(args): - MatchingStrategy = namedtuple('MatchingStrategy', - ('delta', 'max_intron_shift', 'max_missed_exon_len', 'max_fake_terminal_exon_len', - 'max_suspicious_intron_abs_len', 'max_suspicious_intron_rel_len', - 'resolve_ambiguous', 'correct_minor_errors')) + MatchingStrategy = namedtuple( + "MatchingStrategy", + ( + "delta", + "max_intron_shift", + "max_missed_exon_len", + "max_fake_terminal_exon_len", + "max_suspicious_intron_abs_len", + "max_suspicious_intron_rel_len", + "resolve_ambiguous", + "correct_minor_errors", + ), + ) strategies = { - 'exact': MatchingStrategy(0, 0, 0, 0, 0, 0.0, 'monoexon_only', False), - 'precise': MatchingStrategy(4, 30, 50, 20, 0, 0.0, 'monoexon_and_fsm', True), - 'default': MatchingStrategy(6, 60, 100, 40, 60, 1.0, 'monoexon_and_fsm', True), - 'loose': MatchingStrategy(12, 60, 100, 40, 60, 1.0, 'all', True), + "exact": MatchingStrategy(0, 0, 0, 0, 0, 0.0, "monoexon_only", False), + "precise": MatchingStrategy(4, 30, 50, 20, 0, 0.0, "monoexon_and_fsm", True), + "default": MatchingStrategy(6, 60, 100, 40, 60, 1.0, "monoexon_and_fsm", True), + "loose": MatchingStrategy(12, 60, 100, 40, 60, 1.0, "all", True), } strategy = strategies[args.matching_strategy] @@ -586,32 +971,61 @@ def set_matching_options(args): args.minimal_intron_absence_overlap = 20 args.polya_window = 16 args.polya_fraction = 0.75 - if args.resolve_ambiguous == 'default': + if 
args.resolve_ambiguous == "default": args.resolve_ambiguous = strategy.resolve_ambiguous if args.resolve_ambiguous not in AmbiguityResolvingMethod.__dict__: - logger.error("Incorrect resolving ambiguity method: " + args.resolve_ambiguous + ", default will be used") + logger.error( + "Incorrect resolving ambiguity method: " + + args.resolve_ambiguous + + ", default will be used" + ) args.resolve_ambiguous = strategy.resolve_ambiguous args.resolve_ambiguous = AmbiguityResolvingMethod[args.resolve_ambiguous] args.correct_minor_errors = strategy.correct_minor_errors - updated_strategy = MatchingStrategy(args.delta, args.max_intron_shift, args.max_missed_exon_len, - args.max_fake_terminal_exon_len, - args.max_suspicious_intron_abs_len, args.max_suspicious_intron_rel_len, - args.resolve_ambiguous, args.correct_minor_errors) - logger.debug('Using %s strategy. Updated strategy: %s.' % (args.matching_strategy, updated_strategy)) + updated_strategy = MatchingStrategy( + args.delta, + args.max_intron_shift, + args.max_missed_exon_len, + args.max_fake_terminal_exon_len, + args.max_suspicious_intron_abs_len, + args.max_suspicious_intron_rel_len, + args.resolve_ambiguous, + args.correct_minor_errors, + ) + logger.debug( + "Using %s strategy. Updated strategy: %s." + % (args.matching_strategy, updated_strategy) + ) def set_splice_correction_options(args): - SplicSiteCorrectionStrategy = namedtuple('SplicSiteCorrectionStrategy', - ('fuzzy_junctions', 'intron_shifts', 'skipped_exons', - 'terminal_exons', 'fake_terminal_exons', 'microintron_retention')) + SplicSiteCorrectionStrategy = namedtuple( + "SplicSiteCorrectionStrategy", + ( + "fuzzy_junctions", + "intron_shifts", + "skipped_exons", + "terminal_exons", + "fake_terminal_exons", + "microintron_retention", + ), + ) strategies = { - 'none': SplicSiteCorrectionStrategy(False, False, False, False, False, False), - 'default_pacbio': SplicSiteCorrectionStrategy(True, False, True, False, False, True), - 'conservative_ont': SplicSiteCorrectionStrategy(True, False, True, False, False, False), - 'default_ont': SplicSiteCorrectionStrategy(True, False, True, False, True, True), - 'all': SplicSiteCorrectionStrategy(True, True, True, True, True, True), - 'assembly': SplicSiteCorrectionStrategy(False, False, True, False, False, False) + "none": SplicSiteCorrectionStrategy(False, False, False, False, False, False), + "default_pacbio": SplicSiteCorrectionStrategy( + True, False, True, False, False, True + ), + "conservative_ont": SplicSiteCorrectionStrategy( + True, False, True, False, False, False + ), + "default_ont": SplicSiteCorrectionStrategy( + True, False, True, False, True, True + ), + "all": SplicSiteCorrectionStrategy(True, True, True, True, True, True), + "assembly": SplicSiteCorrectionStrategy( + False, False, True, False, False, False + ), } strategy = strategies[args.splice_correction_strategy] args.correct_fuzzy_junctions = strategy.fuzzy_junctions @@ -623,35 +1037,199 @@ def set_splice_correction_options(args): def set_model_construction_options(args): - ModelConstructionStrategy = namedtuple('ModelConstructionStrategy', - ('min_novel_intron_count', - 'graph_clustering_ratio', 'graph_clustering_distance', - 'min_novel_isolated_intron_abs', 'min_novel_isolated_intron_rel', - 'terminal_position_abs', 'terminal_position_rel', - 'terminal_internal_position_rel', - 'min_known_count', 'min_nonfl_count', - 'min_novel_count', 'min_novel_count_rel', - 'min_mono_count_rel', 'singleton_adjacent_cov', - 'fl_only', 'novel_monoexonic', - 
'require_monointronic_polya', 'require_monoexonic_polya', - 'report_canonical')) + ModelConstructionStrategy = namedtuple( + "ModelConstructionStrategy", + ( + "min_novel_intron_count", + "graph_clustering_ratio", + "graph_clustering_distance", + "min_novel_isolated_intron_abs", + "min_novel_isolated_intron_rel", + "terminal_position_abs", + "terminal_position_rel", + "terminal_internal_position_rel", + "min_known_count", + "min_nonfl_count", + "min_novel_count", + "min_novel_count_rel", + "min_mono_count_rel", + "singleton_adjacent_cov", + "fl_only", + "novel_monoexonic", + "require_monointronic_polya", + "require_monoexonic_polya", + "report_canonical", + ), + ) strategies = { - 'reliable': ModelConstructionStrategy(2, 0.5, 20, 5, 0.05, 1, 0.1, 0.1, 2, 4, 8, 0.05, 0.05, 50, - True, False, True, True, StrandnessReportingLevel.only_canonical), - 'default_pacbio': ModelConstructionStrategy(1, 0.5, 10, 2, 0.02, 1, 0.05, 0.05, 1, 2, 2, 0.02, 0.005, 100, - False, True, False, True, StrandnessReportingLevel.only_canonical), - 'sensitive_pacbio':ModelConstructionStrategy(1, 0.5, 5, 2, 0.005, 1, 0.01, 0.02, 1, 2, 2, 0.005, 0.001, 100, - False, True, False, False, StrandnessReportingLevel.only_stranded), - 'default_ont': ModelConstructionStrategy(1, 0.5, 20, 3, 0.02, 1, 0.05, 0.05, 1, 3, 3, 0.02, 0.02, 10, - False, False, True, True, StrandnessReportingLevel.only_canonical), - 'sensitive_ont': ModelConstructionStrategy(1, 0.5, 20, 3, 0.005, 1, 0.01, 0.02, 1, 2, 3, 0.005, 0.005, 10, - False, True, False, False, StrandnessReportingLevel.only_stranded), - 'fl_pacbio': ModelConstructionStrategy(1, 0.5, 10, 2, 0.02, 1, 0.05, 0.01, 1, 2, 3, 0.02, 0.005, 100, - True, True, False, False, StrandnessReportingLevel.only_canonical), - 'all': ModelConstructionStrategy(0, 0.3, 5, 1, 0.002, 1, 0.01, 0.01, 1, 1, 1, 0.002, 0.001, 500, - False, True, False, False, StrandnessReportingLevel.all), - 'assembly': ModelConstructionStrategy(0, 0.3, 5, 1, 0.05, 1, 0.01, 0.02, 1, 1, 1, 0.05, 0.01, 50, - False, True, False, False, StrandnessReportingLevel.only_stranded) + "reliable": ModelConstructionStrategy( + 2, + 0.5, + 20, + 5, + 0.05, + 1, + 0.1, + 0.1, + 2, + 4, + 8, + 0.05, + 0.05, + 50, + True, + False, + True, + True, + StrandnessReportingLevel.only_canonical, + ), + "default_pacbio": ModelConstructionStrategy( + 1, + 0.5, + 10, + 2, + 0.02, + 1, + 0.05, + 0.05, + 1, + 2, + 2, + 0.02, + 0.005, + 100, + False, + True, + False, + True, + StrandnessReportingLevel.only_canonical, + ), + "sensitive_pacbio": ModelConstructionStrategy( + 1, + 0.5, + 5, + 2, + 0.005, + 1, + 0.01, + 0.02, + 1, + 2, + 2, + 0.005, + 0.001, + 100, + False, + True, + False, + False, + StrandnessReportingLevel.only_stranded, + ), + "default_ont": ModelConstructionStrategy( + 1, + 0.5, + 20, + 3, + 0.02, + 1, + 0.05, + 0.05, + 1, + 3, + 3, + 0.02, + 0.02, + 10, + False, + False, + True, + True, + StrandnessReportingLevel.only_canonical, + ), + "sensitive_ont": ModelConstructionStrategy( + 1, + 0.5, + 20, + 3, + 0.005, + 1, + 0.01, + 0.02, + 1, + 2, + 3, + 0.005, + 0.005, + 10, + False, + True, + False, + False, + StrandnessReportingLevel.only_stranded, + ), + "fl_pacbio": ModelConstructionStrategy( + 1, + 0.5, + 10, + 2, + 0.02, + 1, + 0.05, + 0.01, + 1, + 2, + 3, + 0.02, + 0.005, + 100, + True, + True, + False, + False, + StrandnessReportingLevel.only_canonical, + ), + "all": ModelConstructionStrategy( + 0, + 0.3, + 5, + 1, + 0.002, + 1, + 0.01, + 0.01, + 1, + 1, + 1, + 0.002, + 0.001, + 500, + False, + True, + False, + False, + 
StrandnessReportingLevel.all, + ), + "assembly": ModelConstructionStrategy( + 0, + 0.3, + 5, + 1, + 0.05, + 1, + 0.01, + 0.02, + 1, + 1, + 1, + 0.05, + 0.01, + 50, + False, + True, + False, + False, + StrandnessReportingLevel.only_stranded, + ), } strategy = strategies[args.model_construction_strategy] @@ -681,8 +1259,10 @@ def set_model_construction_options(args): args.report_novel_unspliced = strategy.novel_monoexonic if not args.report_novel_unspliced and not args.no_model_construction: - logger.info("Novel unspliced transcripts will not be reported, " - "set --report_novel_unspliced true to discover them") + logger.info( + "Novel unspliced transcripts will not be reported, " + "set --report_novel_unspliced true to discover them" + ) args.require_monointronic_polya = strategy.require_monointronic_polya args.require_monoexonic_polya = strategy.require_monoexonic_polya @@ -693,16 +1273,21 @@ def set_model_construction_options(args): def set_configs_directory(args): - config_dir = os.path.join(os.environ['HOME'], '.config', 'IsoQuant') + config_dir = os.path.join(os.environ["HOME"], ".config", "IsoQuant") os.makedirs(config_dir, exist_ok=True) - args.db_config_path = os.path.join(config_dir, 'db_config.json') - args.index_config_path = os.path.join(config_dir, 'index_config.json') - args.bed_config_path = os.path.join(config_dir, 'bed_config.json') - args.alignment_config_path = os.path.join(config_dir, 'alignment_config.json') - for config_path in (args.db_config_path, args.index_config_path, args.bed_config_path, args.alignment_config_path): + args.db_config_path = os.path.join(config_dir, "db_config.json") + args.index_config_path = os.path.join(config_dir, "index_config.json") + args.bed_config_path = os.path.join(config_dir, "bed_config.json") + args.alignment_config_path = os.path.join(config_dir, "alignment_config.json") + for config_path in ( + args.db_config_path, + args.index_config_path, + args.bed_config_path, + args.alignment_config_path, + ): if not os.path.exists(config_path): - with open(config_path, 'w') as f_out: + with open(config_path, "w") as f_out: json.dump({}, f_out) @@ -721,11 +1306,15 @@ def set_additional_params(args): multimap_strategies = {} for e in MultimapResolvingStrategy: multimap_strategies[e.name] = e.value - args.multimap_strategy = MultimapResolvingStrategy(multimap_strategies[args.multimap_strategy]) + args.multimap_strategy = MultimapResolvingStrategy( + multimap_strategies[args.multimap_strategy] + ) args.needs_reference = True if args.needs_reference and not args.reference: - logger.warning("Reference genome is not provided! This may affect quality of the results!") + logger.warning( + "Reference genome is not provided! This may affect quality of the results!" 
+ ) args.needs_reference = False args.simple_models_mapq_cutoff = 30 @@ -745,7 +1334,7 @@ def run_pipeline(args): logger.info("pyfaidx version: %s" % pyfaidx.__version__) # convert GTF/GFF if needed - if args.genedb and not args.genedb.lower().endswith('db'): + if args.genedb and not args.genedb.lower().endswith("db"): args.genedb = convert_gtf_to_db(args) # map reads if fastqs are provided @@ -756,7 +1345,9 @@ def run_pipeline(args): args.input_data = dataset_mapper.map_reads(args) if args.run_aligner_only: - logger.info("Isoform assignment step is skipped because --run-aligner-only option was used") + logger.info( + "Isoform assignment step is skipped because --run-aligner-only option was used" + ) else: # run isoform assignment dataset_processor = DatasetProcessor(args) @@ -769,35 +1360,54 @@ def run_pipeline(args): logger.info(" === IsoQuant pipeline finished === ") - # Test mode is triggered by --test option class TestMode(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): - out_dir = 'isoquant_test' + out_dir = "isoquant_test" if os.path.exists(out_dir): shutil.rmtree(out_dir) source_dir = os.path.dirname(os.path.realpath(__file__)) - options = ['--output', out_dir, '--threads', '2', - '--fastq', os.path.join(source_dir, 'tests/simple_data/chr9.4M.ont.sim.fq.gz'), - '--reference', os.path.join(source_dir, 'tests/simple_data/chr9.4M.fa.gz'), - '--genedb', os.path.join(source_dir, 'tests/simple_data/chr9.4M.gtf.gz'), - '--clean_start', '--data_type', 'nanopore', '--complete_genedb', '--force', '-p', 'TEST_DATA'] - print('=== Running in test mode === ') - print('Any other option is ignored ') + options = [ + "--output", + out_dir, + "--threads", + "2", + "--fastq", + os.path.join(source_dir, "tests/simple_data/chr9.4M.ont.sim.fq.gz"), + "--reference", + os.path.join(source_dir, "tests/simple_data/chr9.4M.fa.gz"), + "--genedb", + os.path.join(source_dir, "tests/simple_data/chr9.4M.gtf.gz"), + "--clean_start", + "--data_type", + "nanopore", + "--complete_genedb", + "--force", + "-p", + "TEST_DATA", + ] + print("=== Running in test mode === ") + print("Any other option is ignored ") main(options) if self._check_log(): - logger.info(' === TEST PASSED CORRECTLY === ') + logger.info(" === TEST PASSED CORRECTLY === ") else: - logger.error(' === TEST FAILED ===') + logger.error(" === TEST FAILED ===") exit(-1) parser.exit() @staticmethod def _check_log(): - with open('isoquant_test/isoquant.log', 'r') as f: + with open("isoquant_test/isoquant.log", "r") as f: log = f.read() - correct_results = ['total assignments 4', 'polyA tail detected in 2', 'unique: 1', 'known: 2', 'Processed 1 experiment'] + correct_results = [ + "total assignments 4", + "polyA tail detected in 2", + "unique: 1", + "known: 2", + "Processed 1 experiment", + ] return all([result in log for result in correct_results]) @@ -827,12 +1437,16 @@ def main(cmd_args): print_exc(file=strout) s = strout.getvalue() if s: - logger.critical("IsoQuant failed with the following error, please, submit this issue to " - "https://github.com/ablab/IsoQuant/issues" + s) + logger.critical( + "IsoQuant failed with the following error, please, submit this issue to " + "https://github.com/ablab/IsoQuant/issues" + s + ) else: print_exc() else: - sys.stderr.write("IsoQuant failed with the following error, please, submit this issue to " - "https://github.com/ablab/IsoQuant/issues") + sys.stderr.write( + "IsoQuant failed with the following error, please, submit this issue to " + "https://github.com/ablab/IsoQuant/issues" 
+ ) print_exc() sys.exit(-1) diff --git a/requirements.txt b/requirements.txt index 52037078..3bd76857 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,11 @@ pyfaidx>=0.7 pyyaml>=5.4 matplotlib>=3.1.3 numpy>=1.18.1 -scipy>=1.4.1 +scipy>=1.10.0 seaborn>=0.10.0 +scikit-learn>=1.5 +rpy2>=3.5.1 +mygene>=3.2.0 + + diff --git a/src/gene_info.py b/src/gene_info.py index 469bceae..a3f1e067 100644 --- a/src/gene_info.py +++ b/src/gene_info.py @@ -183,7 +183,7 @@ def __init__(self, gene_db_list, db, delta=0, prepare_profiles=True): self.set_sources() self.gene_id_map = {} self.set_gene_ids() - self.gene_attributes = {} + self.feature_attributes = {} self.set_gene_attributes() if prepare_profiles: self.exon_property_map = self.set_feature_properties(self.all_isoforms_exons, self.exon_profiles) @@ -208,7 +208,7 @@ def from_models(cls, transcript_model_storage, delta=0): gene_info.sources = {} gene_info.other_features = {} gene_info.gene_id_map = {} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} introns = set() exons = set() @@ -294,7 +294,7 @@ def from_model(cls, transcript_model, delta=0): transcript_model.gene_id: transcript_model.source} gene_info.other_features = {transcript_model.transcript_id: transcript_model.other_features} gene_info.gene_id_map = {transcript_model.transcript_id: transcript_model.gene_id} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.regions_for_bam_fetch = [(gene_info.start, gene_info.end)] gene_info.exon_property_map = None @@ -332,7 +332,7 @@ def from_region(cls, chr_id, start, end, delta=0, chr_record=None): gene_info.other_features = {} gene_info.sources = {} gene_info.gene_id_map = {} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.regions_for_bam_fetch = [(start, end)] gene_info.exon_property_map = None gene_info.intron_property_map = None @@ -390,7 +390,7 @@ def deserialize(cls, infile, genedb): gene_info.set_sources() gene_info.gene_id_map = {} gene_info.set_gene_ids() - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.set_gene_attributes() gene_info.exon_property_map = gene_info.set_feature_properties(gene_info.all_isoforms_exons, gene_info.exon_profiles) gene_info.intron_property_map = gene_info.set_feature_properties(gene_info.all_isoforms_introns, gene_info.intron_profiles) @@ -475,19 +475,26 @@ def set_gene_ids(self): self.gene_id_map[t.id] = gene_db.id def set_gene_attributes(self): - self.gene_attributes = defaultdict(str) + self.feature_attributes = defaultdict(str) for gene_db in self.gene_db_list: for attr in gene_db.attributes.keys(): - if attr in ['gene_id', 'ID', 'level']: + if attr in ['gene_id', 'ID', 'level', 'Parent']: continue if gene_db.attributes[attr]: - self.gene_attributes[gene_db.id] += '%s "%s"; ' % (attr, gene_db.attributes[attr][0]) - for t in self.db.children(gene_db, featuretype=('transcript', 'mRNA'), order_by='start'): + self.feature_attributes[gene_db.id] += '%s "%s"; ' % (attr, gene_db.attributes[attr][0]) + for t in self.db.children(gene_db, featuretype=('transcript', 'mRNA')): for attr in t.attributes.keys(): - if attr in ['transcript_id', 'gene_id', 'ID', 'level', 'exons']: + if attr in ['transcript_id', 'gene_id', 'ID', 'level', 'exons', 'Parent']: continue if t.attributes[attr]: - self.gene_attributes[t.id] += '%s "%s"; ' % (attr, t.attributes[attr][0]) + self.feature_attributes[t.id] += '%s "%s"; ' % (attr, t.attributes[attr][0]) + for e in self.db.children(gene_db, featuretype=('exon')): + exon_id = 
t.id + "_%d_%d_%s" % (e.start, e.end, e.strand) + for attr in t.attributes.keys(): + if attr in ['transcript_id', 'gene_id', 'ID', 'Parent', 'level', 'exon_id', 'exon', 'exon_number']: + continue + if t.attributes[attr]: + self.feature_attributes[exon_id] += '%s "%s"; ' % (attr, t.attributes[attr][0]) # assigns an ordered list of all known exons and introns to self.exons and self.introns # returns 2 maps, isoform id -> intron / exon list diff --git a/src/gene_model.py b/src/gene_model.py deleted file mode 100644 index 0a7e893b..00000000 --- a/src/gene_model.py +++ /dev/null @@ -1,298 +0,0 @@ -import json -import os -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -from scipy.spatial.distance import euclidean - - -def parse_data(data): - genes = {} - for condition, condition_data in data.items(): - for gene, gene_data in condition_data.items(): - if gene not in genes: - genes[gene] = { - "chromosome": gene_data["chromosome"], - "start": gene_data["start"], - "end": gene_data["end"], - "strand": gene_data["strand"], - "biotype": gene_data["biotype"], - "transcripts": {}, - } - genes[gene]["transcripts"][condition] = gene_data["transcripts"] - genes[gene][condition] = gene_data["value"] - return genes - - -def calculate_deviance(wt_transcripts, condition_transcripts): - all_transcripts = set(wt_transcripts.keys()).union( - set(condition_transcripts.keys()) - ) - - wt_proportions = [wt_transcripts.get(t, 0) for t in all_transcripts] - condition_proportions = [condition_transcripts.get(t, 0) for t in all_transcripts] - - total_wt = sum(wt_proportions) - total_condition = sum(condition_proportions) - - if total_wt > 0: - wt_proportions = [p / total_wt for p in wt_proportions] - if total_condition > 0: - condition_proportions = [p / total_condition for p in condition_proportions] - - distance = euclidean(wt_proportions, condition_proportions) - - # Reduce distance if total unique transcripts are 1 - if len(all_transcripts) == 1: - distance *= 0.7 - - return distance - - -def calculate_metrics(genes): - metrics = [] - for gene, gene_data in genes.items(): - wt_transcripts = gene_data["transcripts"].get("wild_type", {}) - - for condition in gene_data: - if condition in [ - "chromosome", - "start", - "end", - "strand", - "biotype", - "transcripts", - "wild_type", - ]: - continue - condition_transcripts = gene_data["transcripts"].get(condition, {}) - deviance = calculate_deviance(wt_transcripts, condition_transcripts) - metrics.append({"gene": gene, "condition": condition, "deviance": deviance}) - - value = gene_data.get(condition, 0) - wt_value = gene_data.get("wild_type", 0) - abs_diff = abs(value - wt_value) - metrics.append( - { - "gene": gene, - "condition": condition, - "value": value, - "abs_diff": abs_diff, - } - ) - - return pd.DataFrame(metrics) - - -def check_known_target(gene, known_targets): - for target in known_targets: - if "|" in target: - if any(part in gene for part in target.split("|")): - return 1 - elif target == gene: - return 1 - return 0 - - -def rank_genes(df, known_genes_path=None): - if known_genes_path: - target_genes_df = pd.read_csv(known_genes_path, header=None, names=["gene"]) - known_targets = target_genes_df["gene"].tolist() - df["known_target"] = df["gene"].apply( - lambda x: check_known_target(x, known_targets) - ) - else: - df["known_target"] = 0 - - value_ranking = df.groupby("gene")["value"].mean().reset_index() - abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() - deviance_ranking = 
df.groupby("gene")["deviance"].mean().reset_index() - - value_ranking["rank_value"] = value_ranking["value"].rank(ascending=False) - abs_diff_ranking["rank_abs_diff"] = abs_diff_ranking["abs_diff"].rank( - ascending=False - ) - deviance_ranking["rank_deviance"] = deviance_ranking["deviance"].rank( - ascending=False - ) - - merged_df = value_ranking[["gene", "rank_value"]].merge( - abs_diff_ranking[["gene", "rank_abs_diff"]], on="gene" - ) - merged_df = merged_df.merge(deviance_ranking[["gene", "rank_deviance"]], on="gene") - merged_df = merged_df.merge(df[["gene", "known_target"]], on="gene") - # Devalue the importance of overall expression by reducing its weight - merged_df["combined_rank"] = ( - merged_df["rank_value"] # Reduced weight for rank_value - + merged_df["rank_abs_diff"] - + merged_df["rank_deviance"] - ) - - top_combined_ranking = merged_df.sort_values(by="combined_rank").head(10) - top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(10) - top_100_combined_ranking = merged_df.sort_values(by="combined_rank").head(100) - - return ( - top_combined_ranking, - top_deviance_ranking, - top_100_combined_ranking, - merged_df, - ) - - -def visualize_ranking( - top_combined_ranking, top_deviance_ranking, merged_df, output_dir -): - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # Bar plot for combined rank - plt.figure(figsize=(12, 8)) - sns.barplot( - x="combined_rank", y="gene", data=top_combined_ranking, palette="viridis" - ) - plt.title("Top 10 Genes by Combined Ranking") - plt.xlabel("Combined Rank") - plt.ylabel("Gene") - plt.savefig(os.path.join(output_dir, "top_genes_combined_ranking.png"), dpi=300) - plt.close() - - # Heatmap for metric ranks - top_genes = top_combined_ranking["gene"].tolist() - heatmap_data = merged_df[merged_df["gene"].isin(top_genes)] - heatmap_data = heatmap_data.set_index("gene")[ - ["rank_value", "rank_abs_diff", "rank_deviance"] - ] - - plt.figure(figsize=(12, 8)) - sns.heatmap( - heatmap_data, - annot=True, - cmap="RdBu_r", - linewidths=0.5, - cbar_kws={"label": "Rank"}, - ) - plt.title("Metric Ranks for Top 10 Genes") - plt.savefig(os.path.join(output_dir, "metric_ranks_heatmap.png"), dpi=300) - plt.close() - - # Diverging bar plot for deviance - plt.figure(figsize=(12, 8)) - sns.barplot( - x="rank_deviance", - y="gene", - data=top_deviance_ranking, - palette="coolwarm", - orient="h", - ) - plt.title("Top 10 Genes by Transcript Deviance from Wild Type") - plt.xlabel("Rank of Deviance from Wild Type") - plt.ylabel("Gene") - plt.axvline(x=0, color="grey", linestyle="--") - plt.savefig(os.path.join(output_dir, "deviance_from_wild_type.png"), dpi=300) - plt.close() - - # Scatter plot for rank_value vs rank_abs_diff - plt.figure(figsize=(12, 8)) - sns.scatterplot( - x="rank_value", - y="rank_abs_diff", - hue="gene", - data=top_deviance_ranking, - palette="deep", - s=100, - ) - plt.title("Rank Value vs Rank Absolute Difference") - plt.xlabel("Rank Value") - plt.ylabel("Rank Absolute Difference") - plt.savefig(os.path.join(output_dir, "rank_value_vs_rank_abs_diff.png"), dpi=300) - plt.close() - - # Combined multi-metric visualization - fig, axes = plt.subplots(2, 2, figsize=(20, 16)) - sns.barplot( - x="combined_rank", - y="gene", - data=top_combined_ranking, - palette="viridis", - ax=axes[0, 0], - ) - axes[0, 0].set_title("Combined Rank") - axes[0, 0].set_xlabel("Combined Rank") - axes[0, 0].set_ylabel("Gene") - - sns.heatmap( - heatmap_data, - annot=True, - cmap="RdBu_r", - linewidths=0.5, - cbar_kws={"label": "Rank"}, - 
ax=axes[0, 1], - ) - axes[0, 1].set_title("Metric Ranks") - - sns.barplot( - x="rank_deviance", - y="gene", - data=top_deviance_ranking, - palette="coolwarm", - orient="h", - ax=axes[1, 0], - ) - axes[1, 0].set_title("Transcript Deviance from Wild Type") - axes[1, 0].set_xlabel("Rank of Deviance from Wild Type") - axes[1, 0].set_ylabel("Gene") - axes[1, 0].axvline(x=0, color="grey", linestyle="--") - - sns.scatterplot( - x="rank_value", - y="rank_abs_diff", - hue="gene", - data=top_deviance_ranking, - palette="deep", - s=100, - ax=axes[1, 1], - ) - axes[1, 1].set_title("Rank Value vs Rank Absolute Difference") - axes[1, 1].set_xlabel("Rank Value") - axes[1, 1].set_ylabel("Rank Absolute Difference") - - plt.tight_layout() - plt.savefig(os.path.join(output_dir, "combined_visualization.png"), dpi=300) - plt.close() - - -def save_top_genes(top_combined_ranking, output_dir, num_genes): - top_combined_ranking.head(num_genes)[["gene"]].to_csv( - os.path.join(output_dir, f"top_{num_genes}_genes.txt"), - index=False, - header=False, - sep="\t", - ) - return os.path.join(output_dir, f"top_{num_genes}_genes.txt") - - -def rank_and_visualize_genes( - input_data, output_dir, num_genes=100, known_genes_path=None -): - genes = parse_data(input_data) - metrics_df = calculate_metrics(genes) - top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( - rank_genes(metrics_df, known_genes_path) - ) - merged_df = merged_df.drop_duplicates(subset="gene", keep="first") - top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) - top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(num_genes) - - visualize_ranking(top_combined_ranking, top_deviance_ranking, merged_df, output_dir) - path = save_top_genes(top_combined_ranking, output_dir, num_genes) - - print(f"\nTop {num_genes} Genes by Combined Ranking:") - print(top_combined_ranking[["gene", "combined_rank"]]) - print(f"\nDetailed Metrics for Top {num_genes} Genes by Combined Ranking:") - print(top_combined_ranking) - - merged_df.to_csv(os.path.join(output_dir, "gene_metrics.csv"), index=False) - - return path diff --git a/src/gtf2db.py b/src/gtf2db.py index ef6fc4fa..ee82a43d 100755 --- a/src/gtf2db.py +++ b/src/gtf2db.py @@ -59,6 +59,8 @@ def get_color(transcript_kind): gene_name = record["gene_name"][0] elif "gene_id" in record.attributes: gene_name = record["gene_id"][0] + elif "Parent" in record.attributes: + gene_name = record["Parent"][0] else: gene_name = "unknown_gene" transcript_name = record.id + "|" + transcript_type + "|" + gene_name @@ -97,12 +99,13 @@ def check_input_gtf(gtf, db, complete_db): gtf_is_correct, corrected_gtf, out_fname, has_meta_features = check_gtf_duplicates(gtf) if not gtf_is_correct: outdir = os.path.dirname(db) - new_gtf_path = os.path.join(outdir, out_fname) - with open(new_gtf_path, "w") as out_gtf: - out_gtf.write(corrected_gtf) logger.error("Input GTF seems to be corrupted (see warnings above).") - logger.error("An attempt to correct this GTF was made, the result is written to %s" % new_gtf_path) - logger.error("NB! some transcript / gene ids in the corrected annotation are modified.") + if out_fname and corrected_gtf: + new_gtf_path = os.path.join(outdir, out_fname) + with open(new_gtf_path, "w") as out_gtf: + out_gtf.write(corrected_gtf) + logger.error("An attempt to correct this GTF was made, the result is written to %s" % new_gtf_path) + logger.error("NB! 
some transcript / gene ids in the corrected annotation are modified.") logger.error("Provide a correct GTF by fixing the original input GTF or checking the corrected one.") exit(-3) else: @@ -167,6 +170,10 @@ def check_gtf_duplicates(gtf): handle = open(gtf, "rt") inner_ext = outer_ext + if inner_ext.lower() == 'gff3': + return check_gff3_duplicates(handle) + + gff3_checked = False for l in handle.readlines(): line_count += 1 if l.startswith("#"): @@ -177,8 +184,14 @@ def check_gtf_duplicates(gtf): corrected_gtf += l continue - feature_type = v[2] - attrs = v[8].split(" ") + attribute_column = v[8] + if not gff3_checked: + gff3_checked = True + if attribute_column.find("ID=") != -1: + handle.seek(0) + return check_gff3_duplicates(handle) + + attrs = attribute_column.split(" ") gene_id_pos = -1 for i in range(len(attrs)): @@ -190,6 +203,7 @@ def check_gtf_duplicates(gtf): gtf_correct = False continue + feature_type = v[2] gene_str = attrs[gene_id_pos + 1] start_pos = gene_str.find('"') end_pos = gene_str.rfind('"') @@ -259,6 +273,53 @@ def check_gtf_duplicates(gtf): return gtf_correct, corrected_gtf, gtf_name + ".corrected" + inner_ext.lower(), complete_genedb +def check_gff3_duplicates(handle): + gtf_correct = True + gene_count = 0 + transcript_count = 0 + line_count = 0 + feature_ids = {} + + for l in handle.readlines(): + line_count += 1 + if l.startswith("#"): + continue + v = l.strip().split("\t") + if len(v) < 9: + continue + + feature_type = v[2] + if feature_type == 'gene': + gene_count += 1 + elif feature_type in ["transcript", "mRNA"]: + transcript_count += 1 + + attrs = v[8].split(";") + id_pos = -1 + for i in range(len(attrs)): + if attrs[i].startswith('ID'): + id_pos = i + if id_pos == -1: + if feature_type in ["gene", "transcript", "mRNA"]: + logger.warning("Malformed GTF line %d (ID attribute value cannot be found)" % line_count) + logger.warning(l.strip()) + gtf_correct = False + continue + + id_str = attrs[id_pos] + id_value = id_str.split("=")[1] + if id_value in feature_ids: + logger.warning("Duplicated ID %s on line %d" % (id_value, line_count)) + gtf_correct = False + feature_ids[id_value] += 1 + + complete_genedb = 1 + if transcript_count == 0 or gene_count == 0: + complete_genedb = -1 + + return gtf_correct, None, None, complete_genedb + + def find_converted_db(converted_gtfs, gtf_filename, complete_genedb): gtf_mtime = converted_gtfs.get(gtf_filename, {}).get('gtf_mtime') db_mtime = converted_gtfs.get(gtf_filename, {}).get('db_mtime') diff --git a/src/intron_graph.py b/src/intron_graph.py index 7a459b82..64d12789 100644 --- a/src/intron_graph.py +++ b/src/intron_graph.py @@ -209,6 +209,8 @@ def signleton_dead_start(self, v): def get_outgoing(self, intron, v_type=None): res = [] + if intron not in self.outgoing_edges: + return res if v_type is None: for v in self.outgoing_edges[intron]: if v[0] >= 0: @@ -221,6 +223,8 @@ def get_outgoing(self, intron, v_type=None): def get_incoming(self, intron, v_type=None): res = [] + if intron not in self.incoming_edges: + return res if v_type is None: for v in self.incoming_edges[intron]: if v[0] >= 0: diff --git a/src/plot_output.py b/src/plot_output.py deleted file mode 100644 index 9dd754e0..00000000 --- a/src/plot_output.py +++ /dev/null @@ -1,256 +0,0 @@ -import os -import matplotlib.pyplot as plt -import matplotlib.ticker as ticker -import numpy as np -import pprint - - -class PlotOutput: - def __init__( - self, - updated_gene_dict, - gene_names, - output_directory, - create_visualization_subdir=False, - reads_and_class=None, 
- filter_transcripts=None, - conditions=False, - use_counts=False, - ): - self.updated_gene_dict = updated_gene_dict - self.gene_names = gene_names - self.output_directory = output_directory - self.reads_and_class = reads_and_class - self.filter_transcripts = filter_transcripts - self.conditions = conditions - self.use_counts = use_counts - - # Create visualization subdirectory if specified - if create_visualization_subdir: - self.visualization_dir = os.path.join( - self.output_directory, "visualization" - ) - os.makedirs(self.visualization_dir, exist_ok=True) - else: - self.visualization_dir = self.output_directory - - def plot_transcript_map(self): - # Get the first condition's gene dictionary - first_condition = next(iter(self.updated_gene_dict)) - gene_dict = self.updated_gene_dict[first_condition] - - for gene_name in self.gene_names: - if gene_name in gene_dict: - gene_data = gene_dict[gene_name] - num_transcripts = len(gene_data["transcripts"]) - plot_height = max( - 3, num_transcripts * 0.3 - ) # Adjust the height dynamically - - fig, ax = plt.subplots( - figsize=(12, plot_height) - ) # Adjust height dynamically - - if self.filter_transcripts is not None: - ax.set_title( - f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']} with value over {self.filter_transcripts}" - ) - else: - ax.set_title( - f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']}" - ) - - ax.set_xlabel("Chromosomal position") - ax.set_ylabel("Transcripts") - ax.set_yticks(range(num_transcripts)) - ax.set_yticklabels( - [ - f"{transcript_id}" - for transcript_id in gene_data["transcripts"].keys() - ] - ) - - ax.xaxis.set_major_locator( - ticker.MaxNLocator(integer=True) - ) # Ensure genomic positions are integers - ax.xaxis.set_major_formatter( - ticker.FuncFormatter(lambda x, pos: f"{int(x)}") - ) # Format x-axis ticks as integers - - # Plot each transcript - for i, (transcript_id, transcript_info) in enumerate( - gene_data["transcripts"].items() - ): - # Determine the direction based on the gene's strand information - direction_marker = ">" if gene_data["strand"] == "+" else "<" - marker_pos = ( - transcript_info["end"] + 100 - if gene_data["strand"] == "+" - else transcript_info["start"] - 100 - ) - ax.plot( - marker_pos, - i, - marker=direction_marker, - markersize=5, - color="blue", - ) - - # Draw the line for the whole transcript - ax.plot( - [transcript_info["start"], transcript_info["end"]], - [i, i], - color="grey", - linewidth=2, - ) - - # Exon blocks - for exon in transcript_info["exons"]: - exon_length = exon["end"] - exon["start"] - ax.add_patch( - plt.Rectangle( - (exon["start"], i - 0.4), - exon_length, - 0.8, - color="skyblue", - ) - ) - - ax.set_xlim(gene_data["start"], gene_data["end"]) - ax.invert_yaxis() # First transcript at the top - - plt.tight_layout() - plot_path = os.path.join( - self.visualization_dir, f"{gene_name}_splicing.png" - ) - plt.savefig(plot_path) # Saving plot by gene name - plt.close(fig) - - def plot_transcript_usage(self): - """ - Visualize transcript usage for each gene in gene_names across different conditions. 
- """ - - for gene_name in self.gene_names: - gene_data = {} - for condition, genes in self.updated_gene_dict.items(): - if gene_name in genes: - gene_data[condition] = genes[gene_name]["transcripts"] - - if not gene_data: - print(f"Gene {gene_name} not found in the data.") - continue - - conditions = list(gene_data.keys()) - n_bars = len(conditions) - - fig, ax = plt.subplots(figsize=(12, 8)) - index = np.arange(n_bars) - bar_width = 0.35 - opacity = 0.8 - - # for sample_type, transcripts in gene_data.items(): - # print(f"Sample Type: {sample_type}") - # for transcript_id, transcript_info in transcripts.items(): - # print( - # f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}" - # ) - # Adjusting the colors for better within-bar comparison - max_transcripts = max(len(gene_data[condition]) for condition in conditions) - colors = plt.cm.plasma( - np.linspace(0, 1, num=max_transcripts) - ) # Using plasma for better color gradation - - bottom_val = np.zeros(n_bars) - for i, condition in enumerate(conditions): - transcripts = gene_data[condition] - for j, (transcript_id, transcript_info) in enumerate( - transcripts.items() - ): - color = colors[j % len(colors)] - value = transcript_info["value"] - plt.bar( - i, - float(value), - bar_width, - bottom=bottom_val[i], - alpha=opacity, - color=color, - label=transcript_id if i == 0 else "", - ) - bottom_val[i] += float(value) - - plt.xlabel("Sample Type") - plt.ylabel("Transcript Usage (TPM)") - plt.title(f"Transcript Usage for {gene_name} by Sample Type") - plt.xticks(index, conditions) - plt.legend( - title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc="upper left" - ) - - plt.tight_layout() - plot_path = os.path.join( - self.visualization_dir, - f"{gene_name}_transcript_usage_by_sample_type.png", - ) - plt.savefig(plot_path) - plt.close(fig) - - def make_pie_charts(self): - """ - Create pie charts for transcript alignment classifications and read assignment consistency. - Handles both combined and separate sample data structures. - """ - print("self.reads_and_class structure:") - pprint.pprint(self.reads_and_class) - - titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - - for title, data in zip(titles, self.reads_and_class): - if isinstance(data, dict): - if any(isinstance(v, dict) for v in data.values()): - # Separate 'Mutants' and 'WildType' case - for sample_name, sample_data in data.items(): - self._create_pie_chart(f"{title} - {sample_name}", sample_data) - else: - # Combined data case - self._create_pie_chart(title, data) - else: - print(f"Skipping unexpected data type for {title}: {type(data)}") - - def _create_pie_chart(self, title, data): - """ - Helper method to create a single pie chart. - """ - labels = list(data.keys()) - sizes = list(data.values()) - total = sum(sizes) - - # Generate a file-friendly title - file_title = title.lower().replace(" ", "_").replace("-", "_") - - plt.figure(figsize=(12, 8)) - wedges, texts, autotexts = plt.pie( - sizes, - labels=labels, - autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", - startangle=140, - textprops=dict(color="w"), - ) - plt.setp(autotexts, size=8, weight="bold") - plt.setp(texts, size=7) - - plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. 
- plt.title(f"{title}\nTotal: {total}") - - plt.legend( - wedges, - labels, - title="Categories", - loc="center left", - bbox_to_anchor=(1, 0, 0.5, 1), - fontsize=8, - ) - plot_path = os.path.join(self.visualization_dir, f"{file_title}_pie_chart.png") - plt.savefig(plot_path, bbox_inches="tight", dpi=300) - plt.close() diff --git a/src/post_process.py b/src/post_process.py deleted file mode 100644 index 94a1102f..00000000 --- a/src/post_process.py +++ /dev/null @@ -1,609 +0,0 @@ -import csv -import os -import pickle -import gzip -import shutil -import copy -import json -from argparse import Namespace -import tempfile -import gffutils -import yaml - - -class OutputConfig: - """Class to build dictionaries from the output files of the pipeline.""" - - def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): - self.output_directory = output_directory - self.log_details = {} - self.extended_annotation = None - self.read_assignments = None - self.input_gtf = gtf # Initialize with the provided gtf flag - self.genedb_filename = None - self.yaml_input = True - self.yaml_input_path = None - self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. - self.conditions = False - self.gene_grouped_counts = None - self.transcript_grouped_counts = None - self.transcript_grouped_tpm = None - self.gene_grouped_tpm = None - self.gene_counts = None - self.transcript_counts = None - self.gene_tpm = None - self.transcript_tpm = None - self.transcript_model_counts = None - self.transcript_model_tpm = None - self.transcript_model_grouped_tpm = None - self.transcript_model_grouped_counts = None - self.use_counts = use_counts - self.ref_only = ref_only - - self._load_params_file() - self._find_files() - self._conditional_unzip() - - # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log - if self.ref_only and not self.input_gtf: - raise ValueError( - "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." - ) - - def _load_params_file(self): - """Load the .params file for necessary configuration and commands.""" - params_path = os.path.join(self.output_directory, ".params") - assert os.path.exists(params_path), f"Params file not found: {params_path}" - try: - with open(params_path, "rb") as file: - params = pickle.load(file) - if isinstance(params, Namespace): - self._process_params(vars(params)) - else: - print("Unexpected params format.") - except Exception as e: - raise ValueError(f"An error occurred while loading params: {e}") - - def _process_params(self, params): - """Process parameters loaded from the .params file.""" - self.log_details["gene_db"] = params.get("genedb") - self.log_details["fastq_used"] = bool(params.get("fastq")) - self.input_gtf = self.input_gtf or params.get("genedb") - self.genedb_filename = params.get("genedb_filename") - - if params.get("yaml"): - # YAML input case - self.yaml_input = True - self.yaml_input_path = params.get("yaml") - # Keep the output_directory as is, don't modify it - else: - # Non-YAML input case - self.yaml_input = False - processing_sample = params.get("prefix") - if processing_sample: - self.output_directory = os.path.join( - self.output_directory, processing_sample - ) - else: - raise ValueError( - "Processing sample directory not found in params for non-YAML input." 
- ) - - def _conditional_unzip(self): - """Check if unzip is needed and perform it conditionally based on the model use.""" - if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"): - self.input_gtf = self._unzip_file(self.input_gtf) - if not self.input_gtf: - raise FileNotFoundError( - f"Unable to find or unzip the specified file: {self.input_gtf}" - ) - - def _unzip_file(self, file_path): - """Unzip a gzipped file and return the path to the uncompressed file.""" - new_path = file_path[:-3] # Remove .gz extension - - if os.path.exists(new_path): - # print(f"File {new_path} already exists, using this file.") - return new_path - - if not os.path.exists(file_path): - self.gtf_flag_needed = True - return None - - with gzip.open(file_path, "rb") as f_in: - with open(new_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - print(f"File {file_path} was decompressed to {new_path}.") - - return new_path - - def _find_files(self): - """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" - if self.yaml_input: - self.conditions = True - self.ref_only = True - self._find_files_from_yaml() - return # Exit the method after processing YAML input - - if not os.path.exists(self.output_directory): - print(f"Directory not found: {self.output_directory}") # Debugging output - raise FileNotFoundError( - f"Specified sample subdirectory does not exist: {self.output_directory}" - ) - - for file_name in os.listdir(self.output_directory): - if file_name.endswith(".extended_annotation.gtf"): - self.extended_annotation = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".read_assignments.tsv"): - self.read_assignments = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".read_assignments.tsv.gz"): - self.read_assignments = self._unzip_file( - os.path.join(self.output_directory, file_name) - ) - elif file_name.endswith(".gene_grouped_counts.tsv"): - self.conditions = True - self.gene_grouped_counts = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_grouped_counts.tsv"): - self.transcript_grouped_counts = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_grouped_tpm.tsv"): - self.transcript_grouped_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".gene_grouped_tpm.tsv"): - self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".gene_counts.tsv"): - self.gene_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_counts.tsv"): - self.transcript_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".gene_tpm.tsv"): - self.gene_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_tpm.tsv"): - self.transcript_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_model_counts.tsv"): - self.transcript_model_counts = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_tpm.tsv"): - self.transcript_model_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_grouped_tpm.tsv"): - self.transcript_model_grouped_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_grouped_counts.tsv"): - self.transcript_model_grouped_counts = os.path.join( - self.output_directory, 
file_name - ) - - # Determine if GTF flag is needed - if ( - not self.input_gtf - or not os.path.exists(self.input_gtf) - and not os.path.exists(self.input_gtf + ".gz") - and self.ref_only - ): - self.gtf_flag_needed = True - - # Set ref_only default based on the availability of extended_annotation - if self.ref_only is None: - self.ref_only = not self.extended_annotation - - def _find_files_from_yaml(self): - """Locate the necessary files in the directory, set specific grouped count and TPM files, and process read assignments.""" - if not os.path.exists(self.yaml_input_path): - print(f"YAML file not found: {self.yaml_input_path}") - raise FileNotFoundError( - f"Specified YAML file does not exist: {self.yaml_input_path}" - ) - - # Set the four specific attributes - self.gene_grouped_counts = os.path.join( - self.output_directory, "combined_gene_counts.tsv" - ) - self.transcript_grouped_counts = os.path.join( - self.output_directory, "combined_transcript_counts.tsv" - ) - self.transcript_grouped_tpm = os.path.join( - self.output_directory, "combined_transcript_tpm.tsv" - ) - self.gene_grouped_tpm = os.path.join( - self.output_directory, "combined_gene_tpm.tsv" - ) - - # Check if the files exist - for attr in [ - "gene_grouped_counts", - "transcript_grouped_counts", - "transcript_grouped_tpm", - "gene_grouped_tpm", - ]: - file_path = getattr(self, attr) - if not os.path.exists(file_path): - print(f"Warning: {attr} file not found at {file_path}") - setattr(self, attr, None) - - # Initialize read_assignments list - self.read_assignments = [] - - # Read and process the YAML file - with open(self.yaml_input_path, "r") as yaml_file: - yaml_data = yaml.safe_load(yaml_file) - - # Check if yaml_data is a list - if isinstance(yaml_data, list): - samples = yaml_data - else: - # If it's not a list, assume it's a dictionary with a 'samples' key - samples = yaml_data.get("samples", []) - - for sample in samples: - name = sample.get("name") - if name: - sample_dir = os.path.join(self.output_directory, name) - - # Check for .read_assignments.tsv.gz - gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") - if os.path.exists(gz_file): - unzipped_file = self._unzip_file(gz_file) - if unzipped_file: - self.read_assignments.append((name, unzipped_file)) - else: - print(f"Warning: Failed to unzip {gz_file}") - else: - # Check for .read_assignments.tsv - non_gz_file = os.path.join( - sample_dir, f"{name}.read_assignments.tsv" - ) - if os.path.exists(non_gz_file): - self.read_assignments.append((name, non_gz_file)) - else: - print(f"Warning: No read assignments file found for {name}") - - if not self.read_assignments: - print("Warning: No read assignment files found for any samples") - - -class DictionaryBuilder: - """Class to build dictionaries from the output files of the pipeline.""" - - def __init__(self, config): - self.config = config - - def build_gene_transcript_exon_dictionaries(self): - """Builds dictionaries of genes, transcripts, and exons from the GTF file.""" - if self.config.extended_annotation and not self.config.ref_only: - return self.parse_extended_annotation() - else: - return self.parse_input_gtf() - - def build_read_assignment_and_classification_dictionaries(self): - """Indexes classifications and assignment types from read_assignments.tsv file(s).""" - if not self.config.read_assignments: - raise FileNotFoundError("No read assignments file(s) found.") - - if isinstance(self.config.read_assignments, list): - # YAML input case (multiple files) - classification_counts_dict = {} 
- assignment_type_counts_dict = {} - for sample_name, read_assignment_file in self.config.read_assignments: - classification_counts, assignment_type_counts = ( - self._process_read_assignment_file(read_assignment_file) - ) - classification_counts_dict[sample_name] = classification_counts - assignment_type_counts_dict[sample_name] = assignment_type_counts - return classification_counts_dict, assignment_type_counts_dict - else: - # Non-YAML input case (single file) - return self._process_read_assignment_file(self.config.read_assignments) - - def _process_read_assignment_file(self, file_path): - classification_counts = {} - assignment_type_counts = {} - - with open(file_path, "r") as file: - # Skip header lines - for _ in range(3): - next(file, None) - - for line in file: - parts = line.strip().split("\t") - if len(parts) < 6: - continue - - additional_info = parts[-1] - classification = ( - additional_info.split("Classification=")[-1].split(";")[0].strip() - ) - assignment_type = parts[5] - - classification_counts[classification] = ( - classification_counts.get(classification, 0) + 1 - ) - assignment_type_counts[assignment_type] = ( - assignment_type_counts.get(assignment_type, 0) + 1 - ) - - return classification_counts, assignment_type_counts - - def parse_input_gtf(self): - """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" - gene_dict = {} - if not self.config.genedb_filename: - # convert GTF to DB if we use previous IsoQuant runs - # remove this functionality later - tmp_file = tempfile.NamedTemporaryFile(suffix=".db") - self.config.genedb_filename = tmp_file.name - input_gtf_path = self.config.input_gtf - gffutils.create_db( - input_gtf_path, - dbfn=self.config.genedb_filename, - force=True, - keep_order=True, - merge_strategy="merge", - sort_attribute_values=True, - disable_infer_genes=True, - disable_infer_transcripts=True, - ) - - try: - # Create a database without using a context manager - db = gffutils.FeatureDB(self.config.genedb_filename) - - for gene in db.features_of_type("gene"): - gene_id = gene.id - gene_dict[gene_id] = { - "chromosome": gene.seqid, - "start": gene.start, - "end": gene.end, - "strand": gene.strand, - "name": gene.attributes.get("gene_name", [""])[0], - "biotype": gene.attributes.get("gene_biotype", [""])[0], - "transcripts": {}, - } - - for transcript in db.children(gene, featuretype="transcript"): - transcript_id = transcript.id - gene_dict[gene_id]["transcripts"][transcript_id] = { - "start": transcript.start, - "end": transcript.end, - "name": transcript.attributes.get("transcript_name", [""])[0], - "biotype": transcript.attributes.get( - "transcript_biotype", [""] - )[0], - "exons": [], - "tags": transcript.attributes.get("tag", [""])[0].split(","), - } - - for exon in db.children(transcript, featuretype="exon"): - exon_info = { - "exon_id": exon.id, - "start": exon.start, - "end": exon.end, - "number": exon.attributes.get("exon_number", [""])[0], - } - gene_dict[gene_id]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) - - except Exception as e: - raise Exception(f"Error parsing GTF file: {str(e)}") - - return gene_dict - - def parse_extended_annotation(self): - """Parses the GTF file to build a detailed dictionary of genes, transcripts, and exons.""" - gene_dict = {} - if not self.config.extended_annotation: - raise FileNotFoundError("Extended annotation GTF file is missing.") - - with open(self.config.extended_annotation, "r") as file: - for line in file: - if line.startswith("#") 
or not line.strip(): - continue - fields = line.strip().split("\t") - if len(fields) < 9: - print( - f"Skipping malformed line due to insufficient fields: {line.strip()}" - ) - continue - - info_fields = fields[8].strip(";").split(";") - details = { - field.strip().split(" ")[0]: field.strip().split(" ")[1].strip('"') - for field in info_fields - if " " in field - } - - try: - if fields[2] == "gene": - gene_id = details["gene_id"] - gene_dict[gene_id] = { - "chromosome": fields[0], - "start": int(fields[3]), - "end": int(fields[4]), - "strand": fields[6], - "name": details.get("gene_name", ""), - "biotype": details.get("gene_biotype", ""), - "transcripts": {}, - } - elif fields[2] == "transcript": - transcript_id = details["transcript_id"] - gene_dict[details["gene_id"]]["transcripts"][transcript_id] = { - "start": int(fields[3]), - "end": int(fields[4]), - "exons": [], - } - elif fields[2] == "exon": - transcript_id = details["transcript_id"] - exon_info = { - "exon_id": details["exon_id"], - "start": int(fields[3]), - "end": int(fields[4]), - } - gene_dict[details["gene_id"]]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) - except KeyError as e: - print(f"Key error in line: {line.strip()} | Missing key: {e}") - return gene_dict - - def update_gene_dict(self, gene_dict, value_df): - new_dict = {} - gene_values = {} - - # Read gene counts from value_df - with open(value_df, "r") as file: - reader = csv.reader(file, delimiter="\t") - header = next(reader) - conditions = header[1:] # Assumes the first column is gene ID - - # Initialize gene_values dictionary - for row in reader: - gene_id = row[0] - gene_values[gene_id] = {} - for i, condition in enumerate(conditions): - if len(row) > i + 1: - value = float(row[i + 1]) - else: - value = 0.0 # Default to 0 if no value - gene_values[gene_id][condition] = value - - # Build the new dictionary structure by conditions - for condition in conditions: - new_dict[condition] = {} # Create a new sub-dictionary for each condition - - # Deep copy the gene_dict and update with values from value_df - for gene_id, gene_info in gene_dict.items(): - new_dict[condition][gene_id] = copy.deepcopy(gene_info) - if gene_id in gene_values and condition in gene_values[gene_id]: - new_dict[condition][gene_id]["value"] = gene_values[gene_id][ - condition - ] - else: - new_dict[condition][gene_id][ - "value" - ] = 0 # Default to 0 if the gene_id has no corresponding value - - return new_dict - - def update_transcript_values(self, gene_dict, value_df): - new_dict = copy.deepcopy(gene_dict) # Preserve the original structure - transcript_values = {} - - # Load transcript counts from value_df - with open(value_df, "r") as file: - reader = csv.reader(file, delimiter="\t") - header = next(reader) - conditions = header[1:] # Assumes the first column is transcript ID - - for row in reader: - transcript_id = row[0] - for i, condition in enumerate(conditions): - if len(row) > i + 1: - value = float(row[i + 1]) - else: - value = 0.0 # Default to 0 if no value - if transcript_id not in transcript_values: - transcript_values[transcript_id] = {} - transcript_values[transcript_id][condition] = value - - # Update each condition without restructuring the original dictionary - for condition in conditions: - if condition not in new_dict: - new_dict[condition] = copy.deepcopy( - gene_dict - ) # Make sure all genes are present - - for gene_id, gene_info in new_dict[condition].items(): - if "transcripts" in gene_info: - for transcript_id, transcript_info in gene_info[ - 
"transcripts" - ].items(): - if ( - transcript_id in transcript_values - and condition in transcript_values[transcript_id] - ): - transcript_info["value"] = transcript_values[transcript_id][ - condition - ] - else: - transcript_info["value"] = ( - 0 # Set default if no value for this transcript - ) - return new_dict - - def update_gene_names(self, gene_dict): - updated_dict = {} - for condition, genes in gene_dict.items(): - updated_genes = {} - for gene_id, gene_info in genes.items(): - if gene_info["name"]: - gene_name_upper = gene_info["name"].upper() - updated_genes[gene_name_upper] = gene_info - else: - # If name is empty, use the original gene_id - updated_genes[gene_id] = gene_info - updated_dict[condition] = updated_genes - return updated_dict - - def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): - # Dictionary to hold genes and transcripts that meet the criteria - transcript_passes_threshold = {} - - # First pass: Determine which transcripts meet the minimum value requirement in any condition - for condition, genes in gene_dict.items(): - for gene_id, gene_info in genes.items(): - for transcript_id, transcript_info in gene_info["transcripts"].items(): - if ( - "value" in transcript_info - and transcript_info["value"] != "NA" - and float(transcript_info["value"]) >= min_value - ): - if gene_id not in transcript_passes_threshold: - transcript_passes_threshold[gene_id] = {} - transcript_passes_threshold[gene_id][transcript_id] = True - - # Second pass: Build the filtered dictionary including only transcripts that have eligible values in any condition - filtered_dict = {} - for condition, genes in gene_dict.items(): - filtered_genes = {} - for gene_id, gene_info in genes.items(): - if gene_id in transcript_passes_threshold: - eligible_transcripts = { - transcript_id: transcript_info - for transcript_id, transcript_info in gene_info[ - "transcripts" - ].items() - if transcript_id in transcript_passes_threshold[gene_id] - } - if ( - eligible_transcripts - ): # Only add genes with non-empty transcript sets - filtered_gene_info = copy.deepcopy(gene_info) - filtered_gene_info["transcripts"] = eligible_transcripts - filtered_genes[gene_id] = filtered_gene_info - if filtered_genes: # Only add conditions with non-empty gene sets - filtered_dict[condition] = filtered_genes - - return filtered_dict - - def read_gene_list(self, gene_list_path): - with open(gene_list_path, "r") as file: - gene_list = [ - line.strip().upper() for line in file - ] # Convert each gene to uppercase - return gene_list - - def save_gene_dict_to_json(self, gene_dict, output_path): - """Saves the gene dictionary to a JSON file.""" - # name the gene_dict file - output_path = os.path.join(output_path, "gene_dict.json") - with open(output_path, "w") as file: - json.dump(gene_dict, file, indent=4) diff --git a/src/process_dict.py b/src/process_dict.py deleted file mode 100644 index bb3ca001..00000000 --- a/src/process_dict.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import sys -import os - - -def simplify_and_sum_transcripts(data): - gene_totals_across_conditions = {} - simplified_data = {} - - # Sum transcript values and collect them across all conditions - for sample_id, genes in data.items(): - simplified_data[sample_id] = {} - for gene_id, gene_data in genes.items(): - transcripts = gene_data.get("transcripts", {}) - total_value = 0.0 - simplified_transcripts = {} - for transcript_id, transcript_details in transcripts.items(): - transcript_value = ( - transcript_details.get("value", 0.0) - if 
isinstance(transcript_details, dict) - else 0.0 - ) - simplified_transcripts[transcript_id] = transcript_value - total_value += transcript_value - - gene_data_copy = ( - gene_data.copy() - ) # Make a copy to avoid modifying the original - gene_data_copy["transcripts"] = simplified_transcripts - gene_data_copy["value"] = ( - total_value # Replace the gene-level value with the sum of transcript values - ) - simplified_data[sample_id][gene_id] = gene_data_copy - - if gene_id not in gene_totals_across_conditions: - gene_totals_across_conditions[gene_id] = [] - gene_totals_across_conditions[gene_id].append(total_value) - - # Determine which genes to remove - genes_to_remove = [ - gene_id - for gene_id, totals in gene_totals_across_conditions.items() - if all(total < 5 for total in totals) - ] - - # Remove genes from the simplified data structure - for sample_id, genes in simplified_data.items(): - for gene_id in genes_to_remove: - if gene_id in genes: - del genes[gene_id] - - return simplified_data - - -def read_json(file_path): - with open(file_path, "r") as file: - return json.load(file) - - -def write_json(data, file_path): - with open(file_path, "w") as file: - json.dump(data, file, indent=4) - - -def main(): - if len(sys.argv) != 2: - print("Usage: python script.py ") - sys.exit(1) - - input_file_path = sys.argv[1] - base, ext = os.path.splitext(input_file_path) - output_file_path = f"{base}_simplified{ext}" - - try: - # Load the gene data from the specified input JSON file - gene_dict = read_json(input_file_path) - - # Simplify the transcripts, sum their values, and remove genes under a threshold across all conditions - modified_gene_dict = simplify_and_sum_transcripts(gene_dict) - - # Save the modified gene data to the newly named output JSON file - write_json(modified_gene_dict, output_file_path) - - print(f"Modified gene data has been saved to {output_file_path}") - - except Exception as e: - print(f"Error: {str(e)}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/src/transcript_printer.py b/src/transcript_printer.py index e84b8a52..3e00aa2e 100644 --- a/src/transcript_printer.py +++ b/src/transcript_printer.py @@ -99,8 +99,8 @@ def dump(self, gene_info, transcript_model_storage): for gene_id, coords in gene_order: if gene_id not in self.printed_gene_ids: gene_additiional_info = "" - if gene_info and gene_id in gene_info.gene_attributes: - gene_additiional_info = gene_info.gene_attributes[gene_id] + if gene_info and gene_id in gene_info.feature_attributes: + gene_additiional_info = gene_info.feature_attributes[gene_id] source = "IsoQuant" if gene_info and gene_id in gene_info.sources: source = gene_info.sources[gene_id] @@ -117,8 +117,8 @@ def dump(self, gene_info, transcript_model_storage): if not model.check_additional("exons"): model.add_additional_attribute("exons", str(len(model.exon_blocks))) transcript_additiional_info = "" - if gene_info and model.transcript_id in gene_info.gene_attributes: - transcript_additiional_info = " " + gene_info.gene_attributes[model.transcript_id] + if gene_info and model.transcript_id in gene_info.feature_attributes: + transcript_additiional_info = " " + gene_info.feature_attributes[model.transcript_id] transcript_line = '%s\t%s\ttranscript\t%d\t%d\t.\t%s\t.\tgene_id "%s"; transcript_id "%s"; %s\n' \ % (model.chr_id, model.source, model.exon_blocks[0][0], model.exon_blocks[-1][1], @@ -137,9 +137,14 @@ def dump(self, gene_info, transcript_model_storage): exons_to_print = sorted(exons_to_print, reverse=True) if model.strand == '-' 
else sorted(exons_to_print) for i, e in enumerate(exons_to_print): exon_str_id = self.exon_id_storage.get_id(model.chr_id, e, model.strand) + + exon_id = model.transcript_id + "_%d_%d_%s" % (e[0], e[1], model.strand) + exon_additiional_info = "" + if gene_info and exon_id in gene_info.feature_attributes: + exon_additiional_info = " " + gene_info.feature_attributes[model.transcript_id] feature_type = e[2] self.out_gff.write(prefix_columns + "%s\t%d\t%d\t" % (feature_type, e[0], e[1]) + suffix_columns + - ' exon "%d"; exon_id "%s";\n' % ((i + 1), exon_str_id)) + ' exon_number "%d"; exon_id "%s"; %s\n' % ((i + 1), exon_str_id, exon_additiional_info)) self.out_gff.flush() def dump_read_assignments(self, transcript_model_constructor): diff --git a/src/visualization_cache_utils.py b/src/visualization_cache_utils.py new file mode 100644 index 00000000..742bd8e2 --- /dev/null +++ b/src/visualization_cache_utils.py @@ -0,0 +1,262 @@ +import pickle +import logging +import time +from pathlib import Path +from typing import Dict, Any, Optional, Union +import random +import re +import hashlib + + +def build_gene_dict_cache_file( + extended_annotation: Optional[str], input_gtf: str, ref_only: bool, cache_dir: Path +) -> Path: + """ + Generate a gene dictionary cache filename based on: + - Which annotation file we're using (extended vs. reference GTF). + - The modification time of that file. + - The ref_only setting. + """ + if extended_annotation and not ref_only: + source_file = Path(extended_annotation) + source_type = "extended" + else: + source_file = Path(input_gtf) + source_type = "reference" + mtime = source_file.stat().st_mtime + cache_name = f"gene_dict_cache_{source_type}_{source_file.name}_{mtime}_ref_only_{ref_only}.pkl" + return cache_dir / cache_name + + +def build_read_assignment_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path +) -> Path: + """ + Generate a read-assignment cache filename based on: + - The read assignment file(s). + - Possibly their modification times. + - The ref_only setting. + """ + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"read_assignment_cache_{source_file.name}_{mtime}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + # Build a composite name from the multiple input files + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append( + f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}" + ) + composite_name = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"read_assignment_cache_multi_{composite_name}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "read_assignment_cache_default.pkl" + + +def _hash_list(values: list) -> str: + try: + s = ",".join(map(str, values)) + m = hashlib.md5() + m.update(s.encode('utf-8')) + return m.hexdigest()[:12] + except Exception: + # Fallback to length-based signature + return f"len{len(values)}" + + +def build_length_effects_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path, bin_labels: list +) -> Path: + """ + Cache name for read-length effects aggregates. Includes input files, mtimes, ref_only, and bin label signature. 
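+    Single-file inputs yield length_effects_cache_<name>_<mtime>_bins_<sig>_ref_only_<bool>.pkl,
+    while list inputs use a truncated "_multi_" composite of sample names and mtimes; <sig> is the
+    md5-based signature returned by _hash_list.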
+ """ + bins_sig = _hash_list(bin_labels) + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"length_effects_cache_{source_file.name}_{mtime}_bins_{bins_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append(f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}") + composite = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"length_effects_cache_multi_{composite}_bins_{bins_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "length_effects_cache_default.pkl" + + +def build_length_hist_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path, bin_edges: list +) -> Path: + """ + Cache name for read-length histogram. Includes input files, mtimes, ref_only, and bin edges signature. + """ + edges_sig = _hash_list(bin_edges) + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"length_hist_cache_{source_file.name}_{mtime}_edges_{edges_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append(f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}") + composite = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"length_hist_cache_multi_{composite}_edges_{edges_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "length_hist_cache_default.pkl" + + +def save_cache(cache_file: Path, data_to_cache: Any) -> None: + """Save data to a cache file using pickle.""" + try: + with open(cache_file, 'wb') as f: + pickle.dump(data_to_cache, f, protocol=pickle.HIGHEST_PROTOCOL) # Save the entire tuple + logging.debug(f"Successfully saved cache to {cache_file}") + except Exception as e: + logging.error(f"Error saving cache to {cache_file}: {e}") + + +def load_cache(cache_file: Path) -> Any: + """Load data from a cache file.""" + try: + with open(cache_file, 'rb') as f: + cached_data = pickle.load(f) + if isinstance(cached_data, tuple) and len(cached_data) == 3: # Check if it's the new tuple format + gene_dict, novel_gene_ids, novel_transcript_ids = cached_data # Unpack tuple + return gene_dict, novel_gene_ids, novel_transcript_ids # Return the tuple + else: # Handle old cache format (just gene_dict) + return cached_data # Return just the gene_dict for backward compatibility + except FileNotFoundError: + logging.debug(f"Cache file not found: {cache_file}") + return None # Indicate cache miss + except Exception as e: + logging.error(f"Error loading cache from {cache_file}: {e}") + return None + + +def validate_gene_dict(gene_dict: Dict, ref_only: bool = False) -> bool: + """Enhanced validation with novel gene check.""" + if not gene_dict: + return False + + # Always check for novel genes regardless of ref_only mode + novel_genes = sum(1 for condition in gene_dict.values() + for gene_id in condition.keys() + if re.match(r"novel_gene", gene_id)) + if novel_genes > 0: + logging.warning(f"Found {novel_genes} novel genes in cached dictionary. 
Rebuilding required.") + return False + + # Existing structure validation + try: + for condition in gene_dict.values(): + for gene_info in condition.values(): + if not all(k in gene_info for k in ["chromosome", "start", "end", "strand", "transcripts"]): + return False + return True + except (KeyError, AttributeError): + return False + + +def validate_read_assignment_data( + data: Any, read_assignments: Union[str, list] +) -> bool: + """ + Validate the structure of the cached read-assignment data. + """ + try: + if isinstance(read_assignments, list): + # Expecting something like: + # { + # "classification_counts": { "sampleA": {...}, "sampleB": {...} }, + # "assignment_type_counts": { "sampleA": {...}, "sampleB": {...} } + # } + if not isinstance(data, dict): + return False + if ( + "classification_counts" not in data + or "assignment_type_counts" not in data + ): + return False + return True + else: + # Single file scenario: We expect a 2-tuple (classification_counts, assignment_type_counts) + if not isinstance(data, (tuple, list)) or len(data) != 2: + return False + return True + except Exception as e: + logging.error(f"Read-assignment validation error: {e}") + return False + + +def validate_length_effects_data(data: Any, expected_bins: Optional[list] = None) -> bool: + try: + required = [ + 'bins', 'by_bin_assignment', 'by_bin_classification', + 'assignment_keys', 'classification_keys', 'totals' + ] + if not isinstance(data, dict): + return False + if any(k not in data for k in required): + return False + if expected_bins and data.get('bins') != expected_bins: + return False + # Basic shape checks + if not isinstance(data['by_bin_assignment'], dict): return False + if not isinstance(data['by_bin_classification'], dict): return False + if not isinstance(data['totals'], dict): return False + return True + except Exception as e: + logging.error(f"Length-effects validation error: {e}") + return False + + +def validate_length_hist_data(data: Any, expected_edges: Optional[list] = None) -> bool: + try: + if not isinstance(data, dict): + return False + if any(k not in data for k in ['edges', 'counts', 'total']): + return False + if expected_edges and list(map(int, data.get('edges', []))) != list(map(int, expected_edges)): + return False + return True + except Exception as e: + logging.error(f"Length-hist validation error: {e}") + return False + + +def cleanup_cache(cache_dir: Path, max_age_days: int = 7) -> None: + """ + Remove cache files older than specified days. 
+    """
+ """ + current_time = time.time() + for cache_file in cache_dir.glob("*.pkl"): + file_age_days = (current_time - cache_file.stat().st_mtime) / (24 * 3600) + if file_age_days > max_age_days: + try: + cache_file.unlink() + logging.info(f"Removed old cache file: {cache_file}") + except Exception as e: + logging.warning(f"Failed to remove cache file {cache_file}: {e}") diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py new file mode 100644 index 00000000..70332e57 --- /dev/null +++ b/src/visualization_dictionary_builder.py @@ -0,0 +1,909 @@ +import copy +import gffutils +import pandas as pd +import re +import logging +from pathlib import Path +from typing import Dict, Any, List, Union, Tuple +import numpy as np + +from src.visualization_cache_utils import ( + build_gene_dict_cache_file, + save_cache, + load_cache, + validate_gene_dict, + cleanup_cache, +) +from src.visualization_read_assignment_io import ( + get_read_assignment_counts, + get_read_length_effects, + get_read_length_histogram, +) + + +class DictionaryBuilder: + def __init__(self, config): + self.config = config + self.cache_dir = Path(config.output_directory) / ".cache" + self.cache_dir.mkdir(exist_ok=True) + + # Set up logger for DictionaryBuilder + self.logger = logging.getLogger('IsoQuant.visualization.dictionary_builder') + self.logger.setLevel(logging.INFO) + + # Initialize sets to store novel gene and transcript IDs + self.novel_gene_ids = set() + self.novel_transcript_ids = set() + + # Clean up old cache files on init + cleanup_cache(self.cache_dir, max_age_days=7) + + def build_gene_dict_with_expression_and_filter( + self, + min_value: float = 1.0, + reference_conditions: List[str] = None, + target_conditions: List[str] = None, + ) -> Dict[str, Any]: + """ + Optimized build process with filtering based on selected conditions. + Filters transcripts based on min_value occurring in at least one of the + selected reference_conditions or target_conditions. + Caches the resulting dictionary based on the specific conditions used. + """ + self.logger.debug("=== DICTIONARY BUILD PROCESS DEBUG ===") + self.logger.debug(f"Starting dictionary build:") + self.logger.debug(f" min_value: {min_value}") + self.logger.debug(f" reference_conditions: {reference_conditions}") + self.logger.debug(f" target_conditions: {target_conditions}") + self.logger.debug(f" config.ref_only: {self.config.ref_only}") + self.logger.debug(f" config.extended_annotation: {getattr(self.config, 'extended_annotation', 'NOT_SET')}") + + # 1. Load full TPM matrix to determine available conditions first + tpm_file = self._get_tpm_file() + self.logger.debug(f"Loading full TPM matrix from {tpm_file}") + try: + tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None) + tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns] # Clean headers + tpm_df = tpm_df.set_index('feature_id') # Use cleaned column name + except KeyError as e: + self.logger.error(f"Missing required column ('feature_id' or condition name) in {tpm_file}: {str(e)}") + raise + except Exception as e: + self.logger.error(f"Failed to load TPM expression matrix: {str(e)}") + raise + + available_conditions = sorted(tpm_df.columns.tolist()) # Sort for consistent cache key + self.logger.debug(f"Available conditions in TPM file: {available_conditions}") + + # 2. 
Determine the actual conditions to process and create a cache key + requested_conditions = set(reference_conditions or []) | set(target_conditions or []) + if requested_conditions: + conditions_to_process = sorted(list(requested_conditions.intersection(available_conditions))) + missing_conditions = requested_conditions.difference(available_conditions) + if missing_conditions: + self.logger.warning(f"Requested conditions not found in TPM file and will be ignored: {missing_conditions}") + if not conditions_to_process: + self.logger.error("None of the requested conditions were found in the TPM file. Cannot proceed.") + return {} + self.logger.debug(f"Processing conditions: {conditions_to_process}") + else: + self.logger.debug("No specific conditions requested, processing all available conditions.") + conditions_to_process = available_conditions # Already sorted + + # Create a deterministic cache key based on conditions + condition_key_part = "_".join(c.replace(" ", "_") for c in conditions_to_process) + if len(condition_key_part) > 50: # Avoid excessively long filenames + condition_key_part = f"hash_{hash(condition_key_part)}" + + # 3. Check cache specific to these conditions and min_value + base_cache_file = build_gene_dict_cache_file( # Keep base name generation consistent + self.config.extended_annotation, + tpm_file, # Use original TPM file path for base name consistency + self.config.ref_only, + self.cache_dir, + ) + # Append condition and min_value specifics + condition_specific_cache_file = base_cache_file.parent / ( + f"{base_cache_file.stem}_conditions_{condition_key_part}_minval_{min_value}.pkl" + ) + self.logger.debug(f"Looking for cache file: {condition_specific_cache_file}") + + if condition_specific_cache_file.exists(): + self.logger.info(f"Loading data from cache: {condition_specific_cache_file}") + cached_data = load_cache(condition_specific_cache_file) + # Expecting (dict, novel_genes_set, novel_transcripts_set) + if cached_data and isinstance(cached_data, tuple) and len(cached_data) == 3: + cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data + # Basic validation - check if it's a dict and has expected top-level keys (conditions) + if isinstance(cached_gene_dict, dict) and all(c in cached_gene_dict for c in conditions_to_process): + # Deeper validation might be needed if structure is complex + if validate_gene_dict(cached_gene_dict): # Reuse existing validation if suitable + self.novel_gene_ids = cached_novel_gene_ids + self.novel_transcript_ids = cached_novel_transcript_ids + self.logger.debug("Successfully loaded dictionary from cache.") + return cached_gene_dict + else: + self.logger.warning("Cached dictionary failed validation. Rebuilding.") + else: + self.logger.warning("Cached data format mismatch or missing conditions. Rebuilding.") + else: + self.logger.warning("Cached data is invalid or in old format. Rebuilding.") + + # 4. Cache miss or invalid: Build dictionary from scratch for the specified conditions + self.logger.info("Building dictionary from scratch for selected conditions.") + + # Parse GTF and filter novel genes (only needs to be done once) + self.logger.info("Parsing GTF and filtering novel genes") + parsed_data = self.parse_gtf() + self._validate_gene_structure(parsed_data) # Validate base structure + base_gene_dict = self._filter_novel_genes(parsed_data) # Also populates self.novel_gene_ids etc. 
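+        # Illustrative example (hypothetical values): with min_value=1.0 and the
+        # conditions ["ctrl", "treated"], a transcript with TPM {ctrl: 0.2, treated: 1.3}
+        # is kept (max 1.3 >= 1.0), while one with {ctrl: 0.4, treated: 0.7} is dropped.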
+ + # Subset TPM matrix to *only* the conditions being processed for filtering + tpm_df_subset = tpm_df[conditions_to_process] + + # Identify valid transcripts based on max value within the SUBSET conditions + transcript_max_values_subset = tpm_df_subset.max(axis=1) + valid_transcripts = set( + transcript_max_values_subset[transcript_max_values_subset >= min_value].index + ) + + # Debug: Analyze what transcripts passed the expression filter + total_transcripts_in_tpm = len(transcript_max_values_subset) + novel_transcripts_in_tpm = sum(1 for tx_id in transcript_max_values_subset.index if tx_id.startswith("transcript")) + ensembl_transcripts_in_tpm = sum(1 for tx_id in transcript_max_values_subset.index if tx_id.startswith("ENSMUST")) + + novel_transcripts_passed = sum(1 for tx_id in valid_transcripts if tx_id.startswith("transcript")) + ensembl_transcripts_passed = sum(1 for tx_id in valid_transcripts if tx_id.startswith("ENSMUST")) + + # Show sample transcripts that passed/failed + sample_novel_passed = [tx_id for tx_id in valid_transcripts if tx_id.startswith("transcript")][:5] + sample_novel_failed = [tx_id for tx_id in transcript_max_values_subset.index + if tx_id.startswith("transcript") and tx_id not in valid_transcripts][:5] + + self.logger.debug("=== EXPRESSION FILTERING DEBUG ===") + self.logger.debug(f"Total transcripts before expression filtering: {total_transcripts_in_tpm}") + self.logger.debug(f"Novel transcripts in TPM file: {novel_transcripts_in_tpm}") + self.logger.debug(f"Ensembl transcripts in TPM file: {ensembl_transcripts_in_tpm}") + self.logger.debug( + f"Identified {len(valid_transcripts)} transcripts with TPM >= {min_value} " + f"in at least one of the conditions: {conditions_to_process}" + ) + self.logger.debug(f"Novel transcripts passed: {novel_transcripts_passed} / {novel_transcripts_in_tpm}") + self.logger.debug(f"Ensembl transcripts passed: {ensembl_transcripts_passed} / {ensembl_transcripts_in_tpm}") + + if sample_novel_passed: + self.logger.debug(f"Sample novel transcripts that PASSED: {sample_novel_passed}") + if sample_novel_failed: + self.logger.debug(f"Sample novel transcripts that FAILED: {sample_novel_failed}") + # Show TPM values for failed novel transcripts + for tx_id in sample_novel_failed[:3]: + max_tpm = transcript_max_values_subset.get(tx_id, 0) + self.logger.debug(f" {tx_id}: max TPM = {max_tpm:.2f}") + + if novel_transcripts_passed == 0 and novel_transcripts_in_tpm > 0: + self.logger.warning(f"NO NOVEL TRANSCRIPTS PASSED expression filter! 
Consider lowering min_value from {min_value}") + # Show the highest TPM values for novel transcripts + novel_tpm_values = [(tx_id, transcript_max_values_subset.get(tx_id, 0)) + for tx_id in transcript_max_values_subset.index if tx_id.startswith("transcript")] + novel_tpm_values.sort(key=lambda x: x[1], reverse=True) + self.logger.debug("Top 5 novel transcript TPM values:") + for tx_id, tpm in novel_tpm_values[:5]: + self.logger.debug(f" {tx_id}: {tpm:.2f} TPM") + + # Build the final dictionary, iterating only through conditions_to_process + final_dict = {} + for condition in conditions_to_process: + final_dict[condition] = {} + condition_tpm_values = tpm_df[condition] # Get expression from the original full df + + for gene_id, gene_info in base_gene_dict.items(): + # Filter transcripts based on valid_transcripts set AND add expression value + new_transcripts = { + tid: {**tinfo, 'value': condition_tpm_values.get(tid, 0)} + for tid, tinfo in gene_info['transcripts'].items() + if tid in valid_transcripts # Apply the filter here + } + + # Only add gene if it has at least one valid transcript remaining + if new_transcripts: + final_dict[condition][gene_id] = { + **gene_info, # Copy base gene info + 'transcripts': new_transcripts, + 'exons': {} # Initialize exons, will be aggregated next + } + + # Validate structure for this condition's dictionary + self._validate_gene_structure(final_dict[condition]) + + # Aggregate exon values based on the filtered transcripts in the final_dict + self.logger.debug("Aggregating exon values based on filtered transcript expression.") + for condition in conditions_to_process: + for gene_id, gene_info in final_dict[condition].items(): + aggregated_exons = {} + for transcript_id, transcript_info in gene_info["transcripts"].items(): + transcript_value = transcript_info.get("value", 0) # TPM value from the filtered transcript + for exon in transcript_info.get("_original_exons", transcript_info.get("exons", [])): # Use original exon structure if available + exon_id = exon.get("exon_id") + if not exon_id: continue + if exon_id not in aggregated_exons: + aggregated_exons[exon_id] = { + "exon_id": exon_id, + "start": exon["start"], + "end": exon["end"], + "number": exon.get("number", "NA"), + "value": 0.0, # Initialize aggregate value + } + aggregated_exons[exon_id]["value"] += transcript_value # Sum transcript TPM + gene_info["exons"] = aggregated_exons # Assign aggregated exons + + # 5. 
Debug final results before saving + self.logger.debug("=== FINAL DICTIONARY RESULTS ===") + total_final_genes = sum(len(genes) for genes in final_dict.values()) + total_final_transcripts = 0 + final_novel_transcripts = 0 + final_ensembl_transcripts = 0 + + for condition, genes in final_dict.items(): + condition_transcripts = 0 + condition_novel_transcripts = 0 + + for gene_id, gene_info in genes.items(): + transcripts = gene_info.get("transcripts", {}) + condition_transcripts += len(transcripts) + + for tx_id in transcripts.keys(): + if tx_id.startswith("transcript"): + condition_novel_transcripts += 1 + final_novel_transcripts += 1 + elif tx_id.startswith("ENSMUST"): + final_ensembl_transcripts += 1 + + total_final_transcripts += condition_transcripts + self.logger.debug(f"Condition '{condition}': {len(genes)} genes, {condition_transcripts} transcripts ({condition_novel_transcripts} novel)") + + self.logger.info(f"Totals across conditions: genes={total_final_genes}, transcripts={total_final_transcripts}, novel={final_novel_transcripts}, ensembl={final_ensembl_transcripts}") + + if final_novel_transcripts == 0: + self.logger.warning("FINAL RESULT: NO NOVEL TRANSCRIPTS in final dictionary!") + else: + self.logger.info(f"Novel transcripts passing filters: {final_novel_transcripts}") + + # 6. Save the newly built dictionary to the condition-specific cache + self.logger.debug(f"Saving filtered dictionary to cache: {condition_specific_cache_file}") + save_cache( + condition_specific_cache_file, + (final_dict, self.novel_gene_ids, self.novel_transcript_ids) + ) + + return final_dict + + def _get_tpm_file(self) -> str: + """Get the appropriate TPM file path from config.""" + self.logger.debug("=== TPM FILE SELECTION DEBUG ===") + self.logger.debug(f"config.conditions: {self.config.conditions}") + self.logger.debug(f"config.ref_only: {self.config.ref_only}") + self.logger.debug(f"config.transcript_grouped_tpm: {getattr(self.config, 'transcript_grouped_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_model_grouped_tpm: {getattr(self.config, 'transcript_model_grouped_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_tpm_ref: {getattr(self.config, 'transcript_tpm_ref', 'NOT_SET')}") + self.logger.debug(f"config.transcript_tpm: {getattr(self.config, 'transcript_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_model_tpm: {getattr(self.config, 'transcript_model_tpm', 'NOT_SET')}") + + if self.config.conditions: # Check if we have multiple conditions + if self.config.ref_only: + # Reference-only mode: use regular transcript files + merged_tpm = self.config.transcript_grouped_tpm + if merged_tpm and "_merged.tsv" in merged_tpm: + self.logger.debug("REF-ONLY: Using merged TPM file with transcript deduplication already applied") + tpm_file = merged_tpm + else: + tpm_file = self.config.transcript_grouped_tpm + self.logger.debug("REF-ONLY mode: Using transcript_grouped_tpm (reference transcripts only)") + else: + # Extended annotation mode: use transcript_model files that include novel transcripts + merged_tpm = getattr(self.config, 'transcript_model_grouped_tpm', None) + if merged_tpm and "_merged.tsv" in merged_tpm: + self.logger.debug("EXTENDED: Using merged transcript_model TPM file with deduplication") + tpm_file = merged_tpm + elif merged_tpm: + tpm_file = merged_tpm + self.logger.debug("EXTENDED: Using transcript_model_grouped_tpm (includes novel transcripts)") + else: + # Fallback to regular transcript file if transcript_model file not found + 
self.logger.warning("transcript_model_grouped_tpm not found, falling back to transcript_grouped_tpm") + tpm_file = self.config.transcript_grouped_tpm + else: + if self.config.ref_only: + tpm_file = self.config.transcript_tpm_ref + else: + # For single condition, use transcript_model files + transcript_model_tpm = getattr(self.config, 'transcript_model_tpm', None) + if transcript_model_tpm: + base_file = transcript_model_tpm.replace('.tsv', '') + tpm_file = f"{base_file}_merged.tsv" + self.logger.debug("EXTENDED: Using transcript_model TPM for single condition") + else: + base_file = self.config.transcript_tpm.replace('.tsv', '') + tpm_file = f"{base_file}_merged.tsv" + self.logger.warning("transcript_model_tpm not found, falling back to transcript_tpm") + + self.logger.info(f"Selected TPM file: {tpm_file}") + if not tpm_file or not Path(tpm_file).exists(): + self.logger.error(f"TPM file does not exist: {tpm_file}") + raise FileNotFoundError(f"TPM file {tpm_file} not found") + + # Check file size and sample content + tpm_path = Path(tpm_file) + self.logger.debug(f"TPM file size: {tpm_path.stat().st_size / (1024*1024):.2f} MB") + + # Sample a few lines from the TPM file to see what transcript IDs are present + with open(tpm_file, 'r') as f: + lines = f.readlines() + self.logger.debug(f"TPM file has {len(lines)} total lines") + if len(lines) > 1: + header = lines[0].strip() + self.logger.debug(f"TPM header: {header}") + + # Show sample transcript IDs + novel_count = 0 + ensembl_count = 0 + sample_novel = [] + sample_ensembl = [] + + for i in range(1, min(21, len(lines))): # Check first 20 data lines + transcript_id = lines[i].split('\t')[0] + if transcript_id.startswith('transcript'): + novel_count += 1 + if len(sample_novel) < 5: + sample_novel.append(transcript_id) + elif transcript_id.startswith('ENSMUST'): + ensembl_count += 1 + if len(sample_ensembl) < 5: + sample_ensembl.append(transcript_id) + + self.logger.debug(f"TPM file sample (first 20 lines): {novel_count} novel, {ensembl_count} Ensembl") + if sample_novel: + self.logger.debug(f"Sample novel transcript IDs: {sample_novel}") + if sample_ensembl: + self.logger.debug(f"Sample Ensembl transcript IDs: {sample_ensembl}") + + return tpm_file + + # ------------------ READ ASSIGNMENT CACHING ------------------ + + def build_read_assignment_and_classification_dictionaries(self): + """Delegate to read-assignment I/O module with caching.""" + return get_read_assignment_counts(self.config, self.cache_dir) + + def _post_process_cached_data(self, cached_data): + # Backwards-compat wrapper no longer used; kept for compatibility + if isinstance(self.config.read_assignments, list): + return ( + cached_data.get("classification_counts", {}), + cached_data.get("assignment_type_counts", {}), + ) + return cached_data + + def _process_read_assignment_file(self, file_path): + """Deprecated; maintained for compatibility. Use get_read_assignment_counts instead.""" + return {}, {} + + # ------------------ READ LENGTH VS ASSIGNMENT ------------------ + def build_length_vs_assignment(self): + """ + Stream read_assignment TSV file(s) and aggregate counts by read-length bins + versus (a) assignment_type (unique/ambiguous/inconsistent_*) and + (b) classification (full_splice_match/incomplete_splice_match/NIC/NNIC/etc.). 
+ + Returns a dictionary: + { + 'bins': [bin_labels...], + 'assignment': { (bin, assignment_type) -> count }, + 'classification': { (bin, classification) -> count } + } + """ + if not self.config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + # Define length bins + bin_defs = [ + (0, 1000, '<1kb'), + (1000, 2000, '1-2kb'), + (2000, 5000, '2-5kb'), + (5000, 8000, '5-8kb'), + (8000, 12000, '8-12kb'), + (12000, 20000, '12-20kb'), + (20000, 50000, '20-50kb'), + (50000, float('inf'), '>50kb'), + ] + + def bin_length(length_bp: int) -> str: + for lo, hi, name in bin_defs: + if lo <= length_bp < hi: + return name + return 'unknown' + + def calc_length(exons_str: str) -> int: + if not exons_str: + return 0 + total = 0 + for part in exons_str.split(','): + if '-' not in part: + continue + try: + s, e = part.split('-') + total += int(e) - int(s) + 1 + except Exception: + continue + return total + + assign_counts = {} + class_counts = {} + + # Helper to process a single file (plain or gz) + import gzip + def process_file(fp: str): + def smart_open(path_str): + try: + with open(path_str, 'rb') as bf: + if bf.read(2) == b'\x1f\x8b': + return gzip.open(path_str, 'rt') + except Exception: + pass + return open(path_str, 'rt') + with smart_open(fp) as file: + # Skip header lines starting with '#' + # Read line by line to avoid loading entire file + for line in file: + if not line or line.startswith('#'): + continue + parts = line.rstrip('\n').split('\t') + if len(parts) < 9: + continue + assignment_type = parts[5] + exons = parts[7] + additional = parts[8] + # Classification=VALUE; in additional_info + classification = additional.split('Classification=')[-1].split(';')[0].strip() if 'Classification=' in additional else 'Unknown' + + length_bp = calc_length(exons) + b = bin_length(length_bp) + + # Update assignment_type bin counts + key_a = (b, assignment_type) + assign_counts[key_a] = assign_counts.get(key_a, 0) + 1 + + # Update classification bin counts + key_c = (b, classification) + class_counts[key_c] = class_counts.get(key_c, 0) + 1 + + # Process single or multiple files + if isinstance(self.config.read_assignments, list): + for _sample, path in self.config.read_assignments: + process_file(path) + else: + process_file(self.config.read_assignments) + + return { + 'bins': [name for _, _, name in bin_defs], + 'assignment': assign_counts, + 'classification': class_counts, + } + + # ------------------ READ LENGTH EFFECTS ------------------ + def build_read_length_effects(self): + """Delegate to read-assignment I/O module with caching.""" + return get_read_length_effects(self.config, self.cache_dir) + + def build_read_length_histogram(self, bin_edges: List[int] = None): + """Delegate to read-assignment I/O module with caching.""" + return get_read_length_histogram(self.config, self.cache_dir, bin_edges) + + # -------------------- GTF PARSING -------------------- + + def parse_gtf(self) -> Dict[str, Any]: + """ + Parse GTF file into a dictionary with genes, transcripts, and exons. + Handles both reference GTF (with gffutils) and extended annotation GTF. 
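+        Dispatches on config.ref_only: the reference GTF is parsed via a gffutils
+        FeatureDB, while the extended annotation GTF is parsed with a lighter
+        regex-based line parser.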
+ """ + self.logger.info("=== GTF PARSING DEBUG ===") + self.logger.info(f"config.ref_only: {self.config.ref_only}") + self.logger.info(f"config.extended_annotation: {getattr(self.config, 'extended_annotation', 'NOT_SET')}") + self.logger.info(f"config.input_gtf: {getattr(self.config, 'input_gtf', 'NOT_SET')}") + self.logger.info(f"config.genedb_filename: {getattr(self.config, 'genedb_filename', 'NOT_SET')}") + + if self.config.ref_only: + # Use gffutils for reference GTF (more robust but slower) + self.logger.info("Parsing reference GTF using gffutils") + return self._parse_reference_gtf() + else: + # Use faster custom parser for extended annotation + self.logger.info("Parsing extended annotation GTF with custom parser") + return self._parse_extended_gtf() + + def _parse_reference_gtf(self) -> Dict[str, Any]: + """Parse reference GTF using gffutils""" + # Check if genedb_filename exists, if not create one + if not self.config.genedb_filename or not Path(self.config.genedb_filename).exists(): + if self.config.genedb_filename: + self.logger.warning(f"Configured genedb file does not exist: {self.config.genedb_filename}") + + db_path = self.cache_dir / "gtf.db" + if not db_path.exists(): + self.logger.info(f"Creating GTF database at {db_path}") + if not self.config.input_gtf or not Path(self.config.input_gtf).exists(): + raise FileNotFoundError(f"Input GTF file required for database creation but not found: {self.config.input_gtf}") + + gffutils.create_db( + self.config.input_gtf, + dbfn=str(db_path), + force=True, + merge_strategy="create_unique", + disable_infer_genes=True, + disable_infer_transcripts=True, + verbose=False, + ) + self.config.genedb_filename = str(db_path) + self.logger.info(f"Using fallback GTF database: {self.config.genedb_filename}") + + self.logger.info(f"Opening GTF database: {self.config.genedb_filename}") + db = gffutils.FeatureDB(self.config.genedb_filename) + + # Pre-fetch all features + self.logger.info("Pre-fetching features from database") + features = {feature.id: feature for feature in db.all_features()} + + # Build gene -> transcripts -> exons structure + gene_dict = {} + self.logger.info("Processing gene features") + for feature in features.values(): + if feature.featuretype != "gene": + continue + + gene_id = feature.id + gene_name = feature.attributes.get("gene_name", [gene_id])[0] # Default to gene_id if name missing + gene_biotype = feature.attributes.get("gene_biotype", ["unknown"])[0] # Default to "unknown" + + gene_dict[gene_id] = { + "chromosome": feature.seqid, + "start": feature.start, + "end": feature.end, + "strand": feature.strand, + "name": gene_name, # Use updated gene_name + "biotype": gene_biotype, # Use updated gene_biotype + "transcripts": {}, + } + + self.logger.info("Processing transcript and exon features") + for feature in features.values(): + if feature.featuretype == "transcript": + gene_id = feature.attributes.get("gene_id", [""])[0] + if gene_id not in gene_dict: + continue + + transcript_id = feature.id + transcript_name = feature.attributes.get("transcript_name", [transcript_id])[0] # Default to transcript_id + transcript_biotype = feature.attributes.get("transcript_biotype", ["unknown"])[0] # Default to "unknown" + transcript_tags = feature.attributes.get("tag", [""])[0].split(",") # Get tags + + gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": feature.start, + "end": feature.end, + "name": transcript_name, # Use updated transcript_name + "biotype": transcript_biotype, # Use updated transcript_biotype + "exons": [], 
+ "tags": transcript_tags, # Use updated transcript_tags + } + elif feature.featuretype == "exon": + gene_id = feature.attributes.get("gene_id", [""])[0] + transcript_id = feature.attributes.get("transcript_id", [""])[0] + if ( + gene_id in gene_dict + and transcript_id in gene_dict[gene_id]["transcripts"] + ): + exon_number = feature.attributes.get("exon_number", ["1"])[0] # Default to "1" + exon_id = feature.attributes.get("exon_id", [""])[0] # Get exon_id + + gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append( + { + "exon_id": exon_id, # Use retrieved exon_id + "start": feature.start, + "end": feature.end, + "number": exon_number, # Use updated exon_number + } + ) + + self.logger.info(f"Processed {len(gene_dict)} genes from reference GTF") + return gene_dict + + def _parse_extended_gtf(self) -> Dict[str, Any]: + """Parse extended annotation GTF with custom parser""" + base_gene_dict = {} + gtf_file = self.config.extended_annotation + self.logger.info(f"=== EXTENDED GTF PARSING DEBUG ===") + self.logger.info(f"Parsing extended annotation GTF: {gtf_file}") + + # Check file existence and size + gtf_path = Path(gtf_file) + if not gtf_path.exists(): + self.logger.error(f"Extended annotation GTF file does not exist: {gtf_file}") + raise FileNotFoundError(f"Extended annotation GTF file not found: {gtf_file}") + + file_size_mb = gtf_path.stat().st_size / (1024*1024) + self.logger.info(f"Extended GTF file size: {file_size_mb:.2f} MB") + + try: + with open(gtf_file, "r") as file: + attr_pattern = re.compile(r'(\S+) "([^"]+)";') + + # First pass: genes and transcripts + for line in file: + if line.startswith("#") or not line.strip(): + continue + + fields = line.strip().split("\t") + if len(fields) < 9: + continue + + feature_type = fields[2] + attrs = dict(attr_pattern.findall(fields[8])) + gene_id = attrs.get("gene_id") + transcript_id = attrs.get("transcript_id") + + if feature_type == "gene" and gene_id: + if gene_id not in base_gene_dict: + base_gene_dict[gene_id] = { + "chromosome": fields[0], + "start": int(fields[3]), + "end": int(fields[4]), + "strand": fields[6], + "name": attrs.get("gene_name", gene_id), + "biotype": attrs.get("gene_biotype", "unknown"), + "transcripts": {} + } + + elif feature_type == "transcript" and gene_id and transcript_id: + if gene_id not in base_gene_dict: + base_gene_dict[gene_id] = { + "chromosome": fields[0], + "start": int(fields[3]), + "end": int(fields[4]), + "strand": fields[6], + "name": attrs.get("gene_name", gene_id), + "biotype": attrs.get("gene_biotype", "unknown"), + "transcripts": {} + } + + base_gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": int(fields[3]), + "end": int(fields[4]), + "exons": [], + "tags": attrs.get("tags", "").split(","), + "name": attrs.get("transcript_name", transcript_id), + "biotype": attrs.get("transcript_biotype", "unknown"), + } + + elif feature_type == "exon" and transcript_id and gene_id: + if gene_id in base_gene_dict and transcript_id in base_gene_dict[gene_id]["transcripts"]: + exon_info = { + "exon_id": attrs.get("exon_id", ""), + "start": int(fields[3]), + "end": int(fields[4]), + "number": attrs.get("exon_number", "1"), + "value": 0.0 + } + base_gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append(exon_info) + + # Debug: Analyze what we found + total_genes = len(base_gene_dict) + novel_genes = sum(1 for gene_id in base_gene_dict.keys() if "novel_gene" in gene_id) + ensembl_genes = sum(1 for gene_id in base_gene_dict.keys() if gene_id.startswith("ENSMUSG")) + + 
total_transcripts = 0 + novel_transcripts = 0 + ensembl_transcripts = 0 + sample_novel_transcripts = [] + sample_ensembl_transcripts = [] + + for gene_id, gene_info in base_gene_dict.items(): + transcripts = gene_info.get("transcripts", {}) + total_transcripts += len(transcripts) + + for tx_id in transcripts.keys(): + if tx_id.startswith("transcript"): + novel_transcripts += 1 + if len(sample_novel_transcripts) < 5: + sample_novel_transcripts.append(f"{gene_id}:{tx_id}") + elif tx_id.startswith("ENSMUST"): + ensembl_transcripts += 1 + if len(sample_ensembl_transcripts) < 5: + sample_ensembl_transcripts.append(f"{gene_id}:{tx_id}") + + self.logger.info(f"=== EXTENDED GTF PARSING RESULTS ===") + self.logger.info(f"Total genes parsed: {total_genes}") + self.logger.info(f"Novel genes: {novel_genes}, Ensembl genes: {ensembl_genes}") + self.logger.info(f"Total transcripts: {total_transcripts}") + self.logger.info(f"Novel transcripts: {novel_transcripts}, Ensembl transcripts: {ensembl_transcripts}") + + if sample_novel_transcripts: + self.logger.info(f"Sample novel transcripts: {sample_novel_transcripts}") + if sample_ensembl_transcripts: + self.logger.info(f"Sample Ensembl transcripts: {sample_ensembl_transcripts}") + + return base_gene_dict + except Exception as e: + self.logger.error(f"GTF parsing failed: {str(e)}") + raise + + # Keep the original functions for backward compatibility, but have them use the new implementation + def parse_input_gtf(self) -> Dict[str, Any]: + """ + Parse the reference GTF file using gffutils. + This is now a wrapper around _parse_reference_gtf for backward compatibility. + """ + return self._parse_reference_gtf() + + def parse_extended_annotation(self) -> Dict[str, Any]: + """ + Parse extended annotation GTF. + This is now a wrapper around _parse_extended_gtf for backward compatibility. + """ + return self._parse_extended_gtf() + + # -------------------- UPDATES & UTILITIES -------------------- + + def update_gene_names(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Update gene and transcript identifiers to their names, if available, + while preserving all nested structure. 
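+        Gene keys are replaced by the upper-cased gene name when one is available
+        (matching the upper-casing applied in read_gene_list); transcript keys are
+        left unchanged.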
+ """ + try: + updated_dict = {} + total_transcripts = 0 + + for condition, genes in gene_dict.items(): + updated_genes = {} + condition_transcripts = 0 + + for gene_id, gene_info in genes.items(): + new_gene_info = copy.deepcopy(gene_info) + + # Update gene name + if "name" in gene_info and gene_info["name"]: + gene_name_upper = gene_info["name"].upper() + updated_genes[gene_name_upper] = new_gene_info + else: + updated_genes[gene_id] = new_gene_info + + # Count transcripts + transcripts = new_gene_info.get("transcripts", {}) + condition_transcripts += len(transcripts) + + # Debug sample of transcript structure + if gene_id == list(genes.keys())[0]: + self.logger.debug(f"Sample gene {gene_id} transcript structure:") + for tid in list(transcripts.keys())[:3]: + self.logger.debug(f"Transcript {tid}: {transcripts[tid]}") + + total_transcripts += condition_transcripts + updated_dict[condition] = updated_genes + self.logger.debug(f"Condition {condition}: {condition_transcripts} transcripts") + + self.logger.info(f"Updated gene names for {len(gene_dict)} conditions") + self.logger.info(f"Total transcripts in dictionary: {total_transcripts}") + + return updated_dict + + except Exception as e: + self.logger.error(f"Error updating gene/transcript names: {e}") + self.logger.error(f"Dictionary structure before update: {str(type(gene_dict))}") + raise + + def read_gene_list(self, gene_list_path: Union[str, Path]) -> List[str]: + """ + Read and parse a plain-text file containing one gene identifier per line. + Return a list of uppercase gene IDs/names. + """ + try: + with open(gene_list_path, "r") as file: + gene_list = [line.strip().upper() for line in file if line.strip()] + self.logger.debug(f"Read {len(gene_list)} genes from {gene_list_path}") + return gene_list + except Exception as e: + self.logger.error(f"Error reading gene list from {gene_list_path}: {e}") + raise + + def _filter_novel_genes(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: + """Filter out novel genes based on gene ID pattern.""" + self.logger.info("=== NOVEL GENE FILTERING DEBUG ===") + self.logger.info(f"Starting novel gene filtering on {len(gene_dict)} genes") + + filtered_dict = {} + total_removed_genes = 0 + total_removed_transcripts = 0 + checked_gene_count = 0 + sample_removed = [] # For debug logging + sample_kept_novel_transcripts = [] # For novel transcripts in kept genes + + novel_gene_pattern = r"novel_gene" # Make sure this pattern is correct for your novel gene IDs + self.logger.info(f"Using novel gene pattern: '{novel_gene_pattern}'") + + for gene_id, gene_info in gene_dict.items(): + checked_gene_count += 1 + is_novel = bool(re.match(novel_gene_pattern, gene_id)) + + if is_novel: + total_removed_genes += 1 + removed_transcript_count = len(gene_info.get("transcripts", {})) + total_removed_transcripts += removed_transcript_count + + # Add novel gene ID to the set + self.novel_gene_ids.add(gene_id) + # Add novel transcript IDs to the set + transcripts = gene_info.get("transcripts", {}) + self.novel_transcript_ids.update(transcripts.keys()) + + if len(sample_removed) < 5: # Sample log of removed genes + sample_transcripts = list(transcripts.keys())[:3] # Show first 3 transcripts + sample_removed.append({ + 'gene_id': gene_id, + 'transcript_count': removed_transcript_count, + 'sample_transcripts': sample_transcripts + }) + continue # Skip adding novel genes to filtered_dict + else: + # Check if this kept gene has any novel transcripts + transcripts = gene_info.get("transcripts", {}) + for tx_id in 
transcripts.keys(): + if tx_id.startswith("transcript") and len(sample_kept_novel_transcripts) < 10: + sample_kept_novel_transcripts.append(f"{gene_id}:{tx_id}") + + filtered_dict[gene_id] = gene_info # Keep known genes + + self.logger.info(f"=== NOVEL GENE FILTERING RESULTS ===") + self.logger.info(f"Checked {checked_gene_count} total genes") + self.logger.info( + f"Removed {total_removed_genes} novel genes " + f"({total_removed_genes/checked_gene_count:.2%} of total) " + f"and {total_removed_transcripts} associated transcripts" + ) + self.logger.info(f"Kept {len(filtered_dict)} genes after novel gene filtering") + + if sample_removed: + self.logger.info("Sample removed novel genes:") + for g in sample_removed: + self.logger.info(f"- {g['gene_id']}: {g['transcript_count']} transcripts {g['sample_transcripts']}") + else: + self.logger.warning("No novel genes detected with current filtering pattern") + + if sample_kept_novel_transcripts: + self.logger.info(f"Sample novel transcripts in KEPT genes: {sample_kept_novel_transcripts}") + else: + self.logger.warning("No novel transcripts found in kept genes!") + + return filtered_dict + + def get_novel_feature_ids(self) -> Tuple[set, set]: + """Return the sets of novel gene and transcript IDs.""" + return self.novel_gene_ids, self.novel_transcript_ids + + def _validate_gene_structure(self, gene_dict: Dict[str, Any]) -> None: + """Ensure proper gene-centric structure before condition processing.""" + required_gene_keys = ['chromosome', 'start', 'end', 'strand', 'name', 'biotype', 'transcripts'] + + for gene_id, gene_info in gene_dict.items(): + # Check gene ID format + if not isinstance(gene_id, str) or len(gene_id) < 4: + self.logger.error(f"Invalid gene ID format: {gene_id}") + raise ValueError("Malformed gene ID structure") + + # Check required keys + missing = [k for k in required_gene_keys if k not in gene_info] + if missing: + self.logger.error(f"Gene {gene_id} missing keys: {missing}") + raise ValueError("Incomplete gene information") + + # Check transcripts structure + transcripts = gene_info.get('transcripts', {}) + if not isinstance(transcripts, dict): + self.logger.error(f"Invalid transcripts in gene {gene_id} - expected dict") + raise ValueError("Malformed transcript structure") diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py new file mode 100644 index 00000000..0f34a7e1 --- /dev/null +++ b/src/visualization_differential_exp.py @@ -0,0 +1,1296 @@ +from __future__ import annotations +import logging +import pandas as pd +from typing import Dict, List, Tuple, Optional, Union, Any +from pathlib import Path +from rpy2 import robjects +from rpy2.robjects import r, Formula +from rpy2.robjects.packages import importr +from rpy2.robjects import pandas2ri +from rpy2.robjects.conversion import localconverter +from src.visualization_plotter import ExpressionVisualizer +from src.visualization_mapping import GeneMapper +import numpy as np +from sklearn.decomposition import PCA +from rpy2.rinterface_lib import callbacks + +class DifferentialAnalysis: + def __init__( + self, + output_dir: Path, + viz_output: Path, + ref_conditions: List[str], + target_conditions: List[str], + updated_gene_dict: Dict[str, Dict], + ref_only: bool = False, + dictionary_builder: Optional[Any] = None, + filter_min_count: int = 10, + pca_n_components: int = 10, + top_transcripts_base_mean: int = 500, + top_n_genes: int = 100, + log_level: int = logging.INFO, # Allow configuring log level + tech_rep_dict: Dict[str, str] = None, + # 
New options + use_shrunk_lfc_for_visuals: bool = True, + transcript_filter_mode: str = "per_group_min", # or "half_samples" + transcript_min_per_group: int = 2, + transcript_min_total_fraction: float = 0.5, + covariate_df: Optional[pd.DataFrame] = None, # index: base sample_id (without condition prefix) + size_factor_type: str = "poscounts", # DESeq2 sfType, recommended for zero-heavy data + ): + """Initialize differential expression analysis.""" + def quiet_cb(x): + pass + + # Silence R stdout/stderr + callbacks.logger.setLevel(logging.WARNING) # Affects R's logging only + callbacks.consolewrite_print = quiet_cb + callbacks.consolewrite_warnerror = quiet_cb + + self.output_dir = Path(output_dir) + self.deseq_dir = Path(viz_output) / "differential_expression" + self.deseq_dir.mkdir(parents=True, exist_ok=True) + self.ref_conditions = ref_conditions + self.target_conditions = target_conditions + self.ref_only = ref_only + self.updated_gene_dict = updated_gene_dict + self.dictionary_builder = dictionary_builder + + # Configurable parameters + self.filter_min_count = filter_min_count + self.pca_n_components = pca_n_components + self.top_transcripts_base_mean = top_transcripts_base_mean + self.top_n_genes = top_n_genes # Used for both gene and transcript top list size + + # Create a single logger for this class + self.logger = logging.getLogger('IsoQuant.visualization.differential_exp') + self.logger.setLevel(log_level) # Set logger level + + # Get transcript mapping if available + self.transcript_map = {} + if hasattr(self.dictionary_builder, 'config') and hasattr(self.dictionary_builder.config, 'transcript_map'): + self.transcript_map = self.dictionary_builder.config.transcript_map + if self.transcript_map: + self.logger.info(f"Using transcript mapping from dictionary_builder with {len(self.transcript_map)} entries for DESeq2 analysis") + else: + # Try to load transcript mapping directly from file + self.logger.info("Transcript mapping from dictionary_builder is empty, trying to load it directly from file") + self._load_transcript_mapping_from_file() + else: + # Try to load transcript mapping directly from file + self.logger.info("No transcript mapping available from dictionary_builder, trying to load it directly from file") + self._load_transcript_mapping_from_file() + + self.transcript_to_gene = self._create_transcript_to_gene_map() + self.visualizer = ExpressionVisualizer(self.deseq_dir) + self.gene_mapper = GeneMapper() + self.tech_rep_dict = tech_rep_dict + self.use_shrunk_lfc_for_visuals = use_shrunk_lfc_for_visuals + self.transcript_filter_mode = transcript_filter_mode + self.transcript_min_per_group = transcript_min_per_group + self.transcript_min_total_fraction = transcript_min_total_fraction + self.covariate_df = covariate_df + self.size_factor_type = size_factor_type + + # ------------------------- + # Small helpers to reduce duplication + # ------------------------- + def _get_labels(self) -> Tuple[str, str]: + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + return target_label, reference_label + + def _annotate_results(self, level: str, results_df: pd.DataFrame) -> pd.DataFrame: + if results_df is None or results_df.empty: + return results_df + results_df = results_df.copy() + results_df.index.name = "feature_id" + results_df.reset_index(inplace=True) + mapping = self._map_gene_symbols(results_df["feature_id"].unique(), level) + results_df["transcript_symbol"] = results_df["feature_id"].map( + lambda x: mapping.get(x, 
{}).get("transcript_symbol", x) + ) + results_df["gene_name"] = results_df["feature_id"].map( + lambda x: mapping.get(x, {}).get("gene_name", x.split('.')[0] if '.' in x else x) + ) + if level == "gene": + results_df = results_df.drop(columns=["transcript_symbol"], errors='ignore') + return results_df + + def _save_results(self, level: str, results_df: pd.DataFrame, results_shrunk_df: Optional[pd.DataFrame]) -> Tuple[Path, Optional[Path]]: + target_label, reference_label = self._get_labels() + outfile = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}.csv" + results_df.to_csv(outfile, index=False) + shrunk_path = None + if results_shrunk_df is not None and not results_shrunk_df.empty: + # Annotate shrunk for convenience + annotated_shrunk = self._annotate_results(level, results_shrunk_df) + shrunk_path = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}_shrunk_annotated.csv" + annotated_shrunk.to_csv(shrunk_path, index=False) + return outfile, shrunk_path + + def _lfc_for_visuals(self, base_df: pd.DataFrame, shrunk_df: Optional[pd.DataFrame]) -> pd.DataFrame: + df = base_df.copy() + if not self.use_shrunk_lfc_for_visuals or shrunk_df is None or shrunk_df.empty: + return df + try: + merged = pd.merge( + df[["feature_id", "log2FoldChange"]], + shrunk_df[["feature_id", "log2FoldChange"]], + on="feature_id", + how="left", + suffixes=("", "_shrunk"), + ) + lfc_map = merged.set_index("feature_id")["log2FoldChange_shrunk"] + replacement = df["feature_id"].map(lfc_map) + df["log2FoldChange"] = replacement.fillna(df["log2FoldChange"]).values + # Optionally retain a column for reference + df = pd.merge(df, merged[["feature_id", "log2FoldChange_shrunk"]], on="feature_id", how="left") + except Exception as e: + self.logger.warning(f"Could not merge shrunk LFCs for visuals: {e}") + return df + + def _load_prefixed_counts(self, pattern: str) -> pd.DataFrame: + """Load count tsvs for all conditions, prefix columns with condition, and concat.""" + all_sample_dfs: List[pd.DataFrame] = [] + for condition in self.ref_conditions + self.target_conditions: + condition_dir = Path(self.output_dir) / condition + count_files = list(condition_dir.glob(f"*{pattern}")) + if not count_files: + self.logger.error(f"No count files found for condition: {condition}") + raise FileNotFoundError(f"No count files matching {pattern} found in {condition_dir}") + for file_path in count_files: + self.logger.debug(f"Reading count data from: {file_path}") + df = pd.read_csv(file_path, sep="\t") + if "#feature_id" not in df.columns and df.columns[0].startswith("#"): + df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True) + df.set_index("#feature_id", inplace=True) + # Prefix columns + df.rename(columns={col: f"{condition}_{col}" for col in df.columns}, inplace=True) + all_sample_dfs.append(df) + if not all_sample_dfs: + raise ValueError("No sample data found") + return pd.concat(all_sample_dfs, axis=1) + + def _load_transcript_mapping_from_file(self): + """Load transcript mapping directly from the transcript_mapping.tsv file.""" + mapping_file = self.output_dir / "transcript_mapping.tsv" + + if not mapping_file.exists(): + self.logger.warning(f"Transcript mapping file not found at {mapping_file}") + return + + try: + # Load the transcript mapping file + self.logger.debug(f"Loading transcript mapping from {mapping_file}") + self.transcript_map = {} + + # Skip header and read the mapping + with open(mapping_file, 'r') as f: + header = f.readline() # Skip header + for line in f: + parts = 
line.strip().split('\t') + if len(parts) == 2: + transcript_id, canonical_id = parts + self.transcript_map[transcript_id] = canonical_id + + self.logger.info(f"Successfully loaded {len(self.transcript_map)} transcript mappings from file") + + # Log some examples for debugging + sample_items = list(self.transcript_map.items())[:5] + for orig, canon in sample_items: + self.logger.debug(f"Mapping sample: {orig} → {canon}") + except Exception as e: + self.logger.error(f"Failed to load transcript mapping: {str(e)}") + + def _create_transcript_to_gene_map(self) -> Dict[str, str]: + """ + Create a mapping from transcript IDs to gene names. + + Returns: + Dict[str, str]: Mapping from transcript ID to gene name. + """ + transcript_map = {} + for gene_category, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + gene_name = gene_info.get("name", gene_id) + transcripts = gene_info.get("transcripts", {}) + for transcript_id, transcript_info in transcripts.items(): + transcript_name = transcript_info.get("name", gene_name) + transcript_map[transcript_id] = transcript_name + return transcript_map + + def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame]: + """ + Run differential expression analysis for both genes and transcripts. + Orchestrates loading, filtering, DESeq2 execution, and visualization. + + Returns: + Tuple containing: + - Path to gene results file + - Path to transcript results file + - DataFrame of transcript counts (filtered but not normalized) + - DataFrame of DESeq2 gene-level results (unfiltered by significance) + """ + self.logger.info("Starting differential expression analysis workflow.") + + # --- 1. Load and Filter Data --- + gene_counts_filtered, transcript_counts_filtered = self._load_and_filter_data() + + # Store filtered transcript counts (as required by original return signature) + self.transcript_count_data = transcript_counts_filtered + + # --- 2. Run DESeq2 Analysis (Gene Level) --- + (deseq2_results_gene_file, + deseq2_results_df_gene, + deseq2_results_df_gene_shrunk, + gene_normalized_counts, + gene_vst_counts) = self._perform_level_analysis("gene", gene_counts_filtered) + + # --- 3. Run DESeq2 Analysis (Transcript Level) --- + (deseq2_results_transcript_file, + deseq2_results_df_transcript, + deseq2_results_df_transcript_shrunk, + transcript_normalized_counts, + transcript_vst_counts) = self._perform_level_analysis("transcript", transcript_counts_filtered) + + # --- 4. 
Generate Visualizations --- + self._generate_visualizations( + gene_counts_filtered=gene_counts_filtered, # Pass filtered counts for coldata generation + transcript_counts_filtered=transcript_counts_filtered, # Pass filtered counts for coldata generation + gene_normalized_counts=gene_normalized_counts, + transcript_normalized_counts=transcript_normalized_counts, + gene_vst_counts=gene_vst_counts, + transcript_vst_counts=transcript_vst_counts, + deseq2_results_df_gene=deseq2_results_df_gene, + deseq2_results_df_transcript=deseq2_results_df_transcript, + deseq2_results_df_gene_shrunk=deseq2_results_df_gene_shrunk, + deseq2_results_df_transcript_shrunk=deseq2_results_df_transcript_shrunk + ) + + self.logger.info("Differential expression analysis workflow complete.") + # Return signature matches original: results files, filtered transcript counts, gene results df + return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered, deseq2_results_df_gene + + def _load_and_filter_data(self) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Loads, filters (novelty, validity, counts), and returns gene and transcript count data.""" + self.logger.info("Loading and filtering count data...") + + # --- Load Count Data --- + gene_counts = self._get_condition_data("gene_grouped_counts.tsv") + transcript_counts = self._get_condition_data("transcript_grouped_counts.tsv") + self.logger.debug(f"Raw transcript counts shape: {transcript_counts.shape}") + self.logger.debug(f"Raw gene counts shape: {gene_counts.shape}") + + # --- Apply Transcript-Specific Filters --- + transcript_counts_filtered = self._apply_transcript_filters(transcript_counts) + + # --- Apply Count-based Filtering (Gene Level) --- + gene_counts_filtered = self._filter_counts(gene_counts, level="gene") + + if gene_counts_filtered.empty: + self.logger.error("No genes remaining after count filtering.") + raise ValueError("No genes remaining after count filtering.") + if transcript_counts_filtered.empty: + self.logger.error("No transcripts remaining after count filtering.") + raise ValueError("No transcripts remaining after count filtering.") + + self.logger.info("Data loading and filtering complete.") + self.logger.info(f"Final gene counts shape: {gene_counts_filtered.shape}") + self.logger.info(f"Final transcript counts shape: {transcript_counts_filtered.shape}") + + return gene_counts_filtered, transcript_counts_filtered + + def _apply_transcript_filters(self, transcript_counts: pd.DataFrame) -> pd.DataFrame: + """Applies novel, valid, and count-based filters specifically to transcript data.""" + self.logger.debug(f"Applying filters to transcript data (initial shape: {transcript_counts.shape})") + + # --- Valid Transcript Set --- + # Determine the set of transcripts considered valid based on the updated_gene_dict + valid_transcripts = set() + for condition_genes in self.updated_gene_dict.values(): + for gene_info in condition_genes.values(): + valid_transcripts.update(gene_info.get("transcripts", {}).keys()) + if not valid_transcripts: + self.logger.warning("No valid transcripts found in updated_gene_dict. 
Skipping validity filter.") + self.logger.debug(f"Found {len(valid_transcripts)} valid transcript IDs in updated_gene_dict.") + # Apply validity filter if available + if valid_transcripts: + before_valid = transcript_counts.shape[0] + transcript_counts = transcript_counts[transcript_counts.index.isin(valid_transcripts)] + self.logger.info(f"Validity filtering: Retained {transcript_counts.shape[0]} / {before_valid} transcripts present in updated_gene_dict") + + # --- Novel Transcript Filtering --- + if self.dictionary_builder: + novel_transcript_ids = self.dictionary_builder.get_novel_feature_ids()[1] # Assuming index 1 is transcripts + self.logger.debug(f"Number of novel transcripts identified: {len(novel_transcript_ids)}") + original_count = transcript_counts.shape[0] + transcript_counts = transcript_counts[~transcript_counts.index.isin(novel_transcript_ids)] + removed_count = original_count - transcript_counts.shape[0] + perc_removed = (removed_count / original_count * 100) if original_count > 0 else 0 + self.logger.info(f"Novel Gene filtering: Removed {removed_count} transcripts ({perc_removed:.1f}%)") + self.logger.debug(f"Shape after novel filtering: {transcript_counts.shape}") + else: + self.logger.info("Novel transcript filtering: Skipped (no dictionary builder).") + + + + if transcript_counts.empty: + self.logger.warning("No transcripts remaining after novel gene filtering. Count filtering will be skipped.") + return transcript_counts # Return empty dataframe + + # --- Count-based Filtering (Transcript Level) --- + transcript_counts_filtered = self._filter_counts(transcript_counts, level="transcript") + + self.logger.debug(f"Final transcript counts shape after all filters: {transcript_counts_filtered.shape}") + return transcript_counts_filtered + + def _perform_level_analysis( + self, level: str, count_data: pd.DataFrame + ) -> Tuple[Path, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Runs DESeq2 analysis for a specific level (gene or transcript). + + Args: + level: Analysis level ("gene" or "transcript"). + count_data: PRE-FILTERED count data DataFrame for the level. + + Returns: + Tuple containing: + - Path to the saved DESeq2 results CSV file. + - DataFrame of the DESeq2 results. + - DataFrame of the DESeq2 normalized counts. 
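+            - DataFrame of the LFC-shrunk DESeq2 results.
+            - DataFrame of the VST-transformed counts.
+            (Returned in the order: results file path, results, shrunk results,
+            normalized counts, VST counts.)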
+ """ + self.logger.info(f"Performing DESeq2 analysis for level: {level}") + + if count_data.empty: + self.logger.error(f"Input count data is empty for level: {level}") + raise ValueError(f"Input count data is empty for level: {level}") + + # Create design matrix + coldata = self._build_design_matrix(count_data) + + # Run DESeq2 - returns results, shrunk results, normalized counts, and VST counts + results_df, results_shrunk_df, normalized_counts_df, vst_counts_df = self._run_deseq2(count_data, coldata, level) + + # --- Process and annotate DESeq2 Results --- + results_df = self._annotate_results(level, results_df) + + # --- Save Results --- + # Save both standard and annotated shrunk results + outfile, _ = self._save_results(level, results_df, results_shrunk_df) + self.logger.info(f"Saved DESeq2 results ({results_df.shape[0]} features) to {outfile}") + + # --- Write Top Genes/Transcripts --- + self._write_top_genes(results_df, level) + + self.logger.info(f"DESeq2 analysis complete for level: {level}") + return outfile, results_df, results_shrunk_df, normalized_counts_df, vst_counts_df + + def _generate_visualizations( + self, + gene_counts_filtered: pd.DataFrame, + transcript_counts_filtered: pd.DataFrame, + gene_normalized_counts: pd.DataFrame, + transcript_normalized_counts: pd.DataFrame, + gene_vst_counts: pd.DataFrame, + transcript_vst_counts: pd.DataFrame, + deseq2_results_df_gene: pd.DataFrame, + deseq2_results_df_transcript: pd.DataFrame, + deseq2_results_df_gene_shrunk: Optional[pd.DataFrame] = None, + deseq2_results_df_transcript_shrunk: Optional[pd.DataFrame] = None, + ): + """Generates PCA plots and other visualizations based on DESeq2 results and normalized counts.""" + self.logger.info("Generating visualizations...") + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + + # --- Visualize Gene-Level DE Results --- + gene_results_for_plot = self._lfc_for_visuals(deseq2_results_df_gene, deseq2_results_df_gene_shrunk) + self.visualizer.visualize_results( + results=gene_results_for_plot, + target_label=target_label, + reference_label=reference_label, + min_count=self.filter_min_count, # Use configured value + feature_type="genes", + ) + self.logger.info(f"Gene-level DE summary visualizations saved to {self.deseq_dir}") + + # --- Run PCA (Gene Level) --- + gene_counts_for_pca = gene_vst_counts if gene_vst_counts is not None and not gene_vst_counts.empty else gene_normalized_counts + if not gene_counts_for_pca.empty: + gene_coldata = self._build_design_matrix(gene_counts_filtered) # Need coldata matching the counts used + self._run_pca( + normalized_counts=gene_counts_for_pca, + level="gene", + coldata=gene_coldata, + target_label=target_label, + reference_label=reference_label, + is_vst=(gene_vst_counts is not None and not gene_vst_counts.empty) + ) + else: + self.logger.warning("Skipping gene-level PCA: Normalized counts are empty.") + + # --- Visualize Transcript-Level DE Results --- + tx_results_for_plot = self._lfc_for_visuals(deseq2_results_df_transcript, deseq2_results_df_transcript_shrunk) + self.visualizer.visualize_results( + results=tx_results_for_plot, + target_label=target_label, + reference_label=reference_label, + min_count=self.filter_min_count, # Use configured value + feature_type="transcripts", + ) + self.logger.info(f"Transcript-level DE summary visualizations saved to {self.deseq_dir}") + + # --- Run PCA (Transcript Level) --- + tx_counts_for_pca = transcript_vst_counts if transcript_vst_counts is not None and not 
transcript_vst_counts.empty else transcript_normalized_counts + if not tx_counts_for_pca.empty: + transcript_coldata = self._build_design_matrix(transcript_counts_filtered) # Need coldata matching the counts used + self._run_pca( + normalized_counts=tx_counts_for_pca, + level="transcript", + coldata=transcript_coldata, + target_label=target_label, + reference_label=reference_label, + is_vst=(transcript_vst_counts is not None and not transcript_vst_counts.empty) + ) + else: + self.logger.warning("Skipping transcript-level PCA: Normalized counts are empty.") + + self.logger.info("Visualizations generated.") + + def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame: + """ + Get transcript count data and apply transcript mapping to create a merged grouped dataframe. + This preserves the individual sample columns needed for DESeq2, but merges identical transcripts. + """ + self.logger.debug(f"Creating merged transcript count matrix with pattern: {pattern}") + + # Adjust pattern if needed + adjusted_pattern = pattern + if not self.ref_only and pattern == "transcript_grouped_counts.tsv": + adjusted_pattern = "transcript_model_grouped_counts.tsv" + + self.logger.debug(f"Using file pattern: {adjusted_pattern}") + + # Load and prefix columns consistently + combined_df = self._load_prefixed_counts(adjusted_pattern) + self.logger.info(f"Combined count data shape before mapping: {combined_df.shape}") + + # Apply technical replicate merging before transcript mapping + combined_df = self._merge_technical_replicates(combined_df) + + # Apply transcript mapping if available + if not hasattr(self, 'transcript_map') or not self.transcript_map: + self.logger.info("No transcript mapping available, using raw counts") + return combined_df + + # Log transcript mapping info + self.logger.info(f"Applying transcript mapping with {len(self.transcript_map)} mappings") + + # Get unique transcript IDs and create mapping dictionary + unique_transcripts = combined_df.index.unique() + transcript_groups = {} + + # Group transcripts by their canonical ID + for transcript_id in unique_transcripts: + canonical_id = self.transcript_map.get(transcript_id, transcript_id) + if canonical_id not in transcript_groups: + transcript_groups[canonical_id] = [] + transcript_groups[canonical_id].append(transcript_id) + + # Create the merged dataframe + merged_df = pd.DataFrame(index=list(transcript_groups.keys()), columns=combined_df.columns) + + # Track merge statistics + total_transcripts = len(unique_transcripts) + merged_groups = 0 + merged_transcripts = 0 + + # For each canonical transcript ID, sum the counts from all transcripts that map to it + for canonical_id, transcript_ids in transcript_groups.items(): + if len(transcript_ids) == 1: + # Just one transcript, copy the row directly + merged_df.loc[canonical_id] = combined_df.loc[transcript_ids[0]] + else: + # Multiple transcripts map to this canonical ID, sum their counts + merged_df.loc[canonical_id] = combined_df.loc[transcript_ids].sum() + merged_groups += 1 + merged_transcripts += len(transcript_ids) - 1 # Count transcripts beyond the first one + + # Log details of significant merges (more than 2 transcripts or interesting transcripts) + if len(transcript_ids) > 2 or any("ENST" in t for t in transcript_ids): + self.logger.debug(f"Merged transcript group for {canonical_id}: {transcript_ids}") + + # Log merge statistics + self.logger.info(f"Transcript merging complete: {merged_groups} canonical IDs had multiple transcripts") + self.logger.info(f"Merged 
{merged_transcripts} transcripts into canonical IDs ({merged_transcripts/total_transcripts:.1%} of total)") + self.logger.info(f"Final merged count matrix shape: {merged_df.shape}") + + return merged_df + + def _get_condition_data(self, pattern: str) -> pd.DataFrame: + """Get count data for differential expression analysis.""" + if pattern == "transcript_grouped_counts.tsv": + # For transcript data, use our merged function + return self._get_merged_transcript_counts(pattern) + elif pattern == "gene_grouped_counts.tsv": + # For gene data, use a simpler approach (no merging needed) + self.logger.info(f"Loading gene count data with pattern: {pattern}") + combined_df = self._load_prefixed_counts(pattern) + self.logger.info(f"Combined gene count data shape: {combined_df.shape}") + + # Apply technical replicate merging + combined_df = self._merge_technical_replicates(combined_df) + + return combined_df + else: + self.logger.error(f"Unsupported count pattern: {pattern}") + raise ValueError(f"Unsupported count pattern: {pattern}") + + def _filter_counts(self, count_data: pd.DataFrame, level: str = "gene") -> pd.DataFrame: + """ + Filter features based on counts using the configured threshold. + + For genes: Keep if mean count >= configured min_count in either condition group. + For transcripts: Behavior is configurable. + - per_group_min (default): require counts >= threshold in at least K samples per group + - half_samples: require counts >= threshold in >= fraction of all samples + """ + if count_data.empty: + self.logger.warning(f"Input count data for filtering ({level}) is empty. Returning empty DataFrame.") + return count_data + + # Use the configured minimum count threshold + min_count_threshold = self.filter_min_count + + if level == "transcript": + # Determine columns by condition name prefix + ref_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.ref_conditions) + ] + tgt_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.target_conditions) + ] + + if self.transcript_filter_mode == "half_samples": + total_cols = len(count_data.columns) + required = int(np.ceil(total_cols * float(self.transcript_min_total_fraction))) + passing_total = (count_data >= min_count_threshold).sum(axis=1) + keep_features = passing_total >= required + self.logger.info( + "Transcript filtering (half_samples): Keeping transcripts present in "+ + ">= %d/%d samples with counts >= %d", + required, total_cols, min_count_threshold + ) + else: + # Default: per-group minimum requirement + min_ref_required = ( + min(self.transcript_min_per_group, len(ref_cols)) if len(ref_cols) >= 1 else 0 + ) + min_tgt_required = ( + min(self.transcript_min_per_group, len(tgt_cols)) if len(tgt_cols) >= 1 else 0 + ) + + passing_ref = (count_data[ref_cols] >= min_count_threshold).sum(axis=1) if ref_cols else 0 + passing_tgt = (count_data[tgt_cols] >= min_count_threshold).sum(axis=1) if tgt_cols else 0 + + if isinstance(passing_ref, int): + self.logger.warning("No reference columns found for transcript filtering; no transcripts will pass.") + keep_features = count_data.index == "__none__" + elif isinstance(passing_tgt, int): + self.logger.warning("No target columns found for transcript filtering; no transcripts will pass.") + keep_features = count_data.index == "__none__" + else: + keep_features = (passing_ref >= min_ref_required) & (passing_tgt >= min_tgt_required) + + self.logger.info( + "Transcript filtering (per_group_min): Keeping transcripts 
with counts >= %d in at least %d/%d ref and %d/%d target samples", + min_count_threshold, min_ref_required, len(ref_cols), min_tgt_required, len(tgt_cols) + ) + else: # gene level + ref_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.ref_conditions) + ] + tgt_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.target_conditions) + ] + + # Handle cases where one condition might have no samples after potential upstream filtering + if not ref_cols: + self.logger.warning("No reference columns found for gene count filtering.") + ref_means = pd.Series(0, index=count_data.index) # Assign 0 mean if no ref samples + else: + ref_means = count_data[ref_cols].mean(axis=1) + + if not tgt_cols: + self.logger.warning("No target columns found for gene count filtering.") + tgt_means = pd.Series(0, index=count_data.index) # Assign 0 mean if no target samples + else: + tgt_means = count_data[tgt_cols].mean(axis=1) + + keep_features = (ref_means >= min_count_threshold) | (tgt_means >= min_count_threshold) + + self.logger.info( + f"Gene filtering: Keeping genes with mean count >= {min_count_threshold} " + f"in either reference or target condition group" + ) + + filtered_data = count_data[keep_features] + removed_count = count_data.shape[0] - filtered_data.shape[0] + self.logger.info( + f"After count filtering ({level}): Retained {filtered_data.shape[0]} / {count_data.shape[0]} features " + f"(Removed {removed_count})" + ) + + return filtered_data + + def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: + """Create experimental design matrix for DESeq2. + + Each column in the count data (sample) needs to be assigned to either + the reference or target group for differential expression analysis. 
+ """ + groups = [] + condition_assignments = [] + sample_ids = [] + # Optional covariates + covariate_values: Dict[str, List[Optional[Union[str, float]]]] = {} + covariate_columns: List[str] = list(self.covariate_df.columns) if isinstance(self.covariate_df, pd.DataFrame) else [] + for cov_col in covariate_columns: + covariate_values[cov_col] = [] + + self.logger.debug("Building experimental design matrix") + + for sample in count_data.columns: + # Extract the condition from the sample name + # Matches pattern: conditionname_sampleid + # The column name should start with the condition name followed by an underscore + condition = None + for cond in self.ref_conditions + self.target_conditions: + if sample.startswith(f"{cond}_"): + condition = cond + # Extract the sample ID (everything after the condition name and underscore) + sample_id = sample[len(condition)+1:] + break + + if condition is None: + self.logger.error(f"Could not determine condition for sample: {sample}") + raise ValueError(f"Sample column '{sample}' does not match any specified condition") + + # Assign to reference or target group + if condition in self.ref_conditions: + groups.append("Reference") + else: + groups.append("Target") + + # Store the condition and sample ID for additional information + condition_assignments.append(condition) + sample_ids.append(sample) + + # Attach covariate values if provided + if covariate_columns: + for cov_col in covariate_columns: + try: + value = self.covariate_df.loc[sample_id, cov_col] + except Exception: + value = np.nan + covariate_values[cov_col].append(value) + + # Create the design matrix DataFrame + design_matrix = pd.DataFrame({ + "group": groups, + "condition": condition_assignments, + "sample_id": sample_ids + }, index=count_data.columns) + # Append covariates into design matrix + if covariate_columns: + for cov_col in covariate_columns: + design_matrix[cov_col] = covariate_values[cov_col] + + # Log the design matrix for debugging + self.logger.debug(f"Design matrix:\n{design_matrix}") + + return design_matrix + + def _run_deseq2( + self, count_data: pd.DataFrame, coldata: pd.DataFrame, level: str + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame], pd.DataFrame, pd.DataFrame]: + """ + Run DESeq2 analysis and return results and normalized counts. + + Args: + count_data: Raw count data (filtered). + coldata: Design matrix. + level: Analysis level (gene/transcript). + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: DESeq2 results, DESeq2 normalized counts. + """ + self.logger.info(f"Running DESeq2 for {level} level...") + deseq2 = importr("DESeq2") + # Ensure counts are integers for DESeq2 (fail fast; do not silently coerce) + count_data = count_data.fillna(0) + if (count_data < 0).any().any(): + raise ValueError(f"Negative counts detected for {level}.") + if not np.all(np.equal(count_data.values, np.floor(count_data.values))): + raise ValueError(f"Non-integer counts detected for {level}. 
Please supply raw integer counts.")
+        count_data = count_data.astype(int)
+
+        if count_data.empty:
+            self.logger.error(f"Count data is empty before running DESeq2 for {level}.")
+            # Return empty placeholders matching the (results, shrunk results, normalized counts, VST counts) signature
+            return pd.DataFrame(), None, pd.DataFrame(index=count_data.index, columns=count_data.columns), pd.DataFrame()
+
+        try:
+            with localconverter(robjects.default_converter + pandas2ri.converter):
+                # Convert count_data and coldata to R DataFrames
+                count_data_r = robjects.conversion.py2rpy(count_data)
+                coldata_r = robjects.conversion.py2rpy(coldata)
+
+                # Create DESeqDataSet
+                self.logger.debug("Creating DESeqDataSet...")
+                # Build design formula dynamically: ~ covariates + group
+                covariate_cols = [c for c in coldata.columns if c not in ["group", "condition", "sample_id"]]
+                formula_terms = (covariate_cols + ["group"]) if covariate_cols else ["group"]
+                design_formula = "~ " + " + ".join(formula_terms)
+                self.logger.info(f"Using design formula for {level}: {design_formula}")
+                # Ensure group and categorical covariates are factors with correct baseline
+                r('library(methods)')
+                r.assign("coldata_tmp", coldata_r)
+                r('coldata_tmp$group <- relevel(factor(coldata_tmp$group), "Reference")')
+                # Coerce non-numeric covariates to factors
+                for cov in covariate_cols:
+                    if not pd.api.types.is_numeric_dtype(coldata[cov]):
+                        r(f'coldata_tmp${cov} <- factor(coldata_tmp${cov})')
+                coldata_r = r('coldata_tmp')
+
+                dds = deseq2.DESeqDataSetFromMatrix(
+                    countData=count_data_r, colData=coldata_r, design=Formula(design_formula)
+                )
+
+                # Run DESeq analysis
+                self.logger.debug("Running DESeq()...")
+                # Use sfType configured; 'poscounts' is recommended for zero-heavy counts
+                dds = deseq2.DESeq(dds, sfType=self.size_factor_type)
+
+                # Get results
+                self.logger.debug("Extracting results()...")
+                res = deseq2.results(
+                    dds, contrast=robjects.StrVector(["group", "Target", "Reference"])
+                )
+                res_df = robjects.conversion.rpy2py(r("as.data.frame")(res)) # Convert to R dataframe first for stability
+                res_df.index = count_data.index # Assign original feature IDs as index
+
+                # Extract dispersion estimates
+                self.logger.debug("Extracting dispersion estimates...")
+                dispersions_r = r['dispersions'](dds)
+                dispersions_py = robjects.conversion.rpy2py(dispersions_r)
+
+                # Add dispersion estimates to results DataFrame
+                res_df['dispersion'] = dispersions_py
+
+                # Extract size factors
+                self.logger.debug("Extracting size factors...")
+                size_factors_r = r['sizeFactors'](dds)
+                size_factors_py = robjects.conversion.rpy2py(size_factors_r)
+
+                # Create size factors DataFrame with sample names
+                size_factors_df = pd.DataFrame({
+                    'sample': count_data.columns,
+                    'size_factor': size_factors_py
+                })
+
+                # Save size factors to file
+                target_label = "+".join(self.target_conditions)
+                reference_label = "+".join(self.ref_conditions)
+                size_factors_file = self.deseq_dir / f"size_factors_{level}_{target_label}_vs_{reference_label}.csv"
+                size_factors_df.to_csv(size_factors_file, index=False)
+                self.logger.info(f"Size factors saved to {size_factors_file}")
+
+                # Correct way to call the R 'counts' function on the dds object
+                # Ensure 'r' is imported: from rpy2.robjects import r
+                normalized_counts_r = r['counts'](dds, normalized=True)
+
+                # Convert R matrix to pandas DataFrame
+                normalized_counts_py = robjects.conversion.rpy2py(normalized_counts_r)
+                # Ensure DataFrame structure matches original count_data (features x samples)
+                normalized_counts_df = pd.DataFrame(normalized_counts_py, index=count_data.index, columns=count_data.columns)
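+                # Note: DESeq2's "normalized" counts are simply the raw counts divided by each
+                # sample's size factor, so this matrix can be sanity-checked against the size
+                # factors saved above. Illustrative sketch (not executed by the pipeline):
+                #   sf = size_factors_df.set_index('sample')['size_factor']
+                #   expected = count_data.div(sf, axis=1)
+                #   assert np.allclose(expected.values, normalized_counts_df.values)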
+ + # VST-transformed counts for PCA visualization stability + try: + vst_obj = deseq2.vst(dds, blind=True) + vst_mat_r = r['assay'](vst_obj) + vst_counts_py = robjects.conversion.rpy2py(vst_mat_r) + vst_counts_df = pd.DataFrame(vst_counts_py, index=count_data.index, columns=count_data.columns) + except Exception as e: + self.logger.warning(f"VST transformation failed or unavailable: {e}") + vst_counts_df = pd.DataFrame() + + # Generate dispersion and count summaries + self._generate_dispersion_summary(res_df, level) + + # LFC shrinkage (apeglm) for interpretability; keep Wald stats for GSEA + res_shrunk_df: Optional[pd.DataFrame] = None + try: + # Ensure apeglm is available + importr("apeglm") + # Find appropriate coefficient name + coef_names = robjects.conversion.rpy2py(r['resultsNames'](dds)) + # Prefer the standard group coefficient; fallback to first matching 'group' + coef_name = None + for name in coef_names: + if isinstance(name, str) and "group_Target_vs_Reference" in name: + coef_name = name + break + if coef_name is None: + for name in coef_names: + if isinstance(name, str) and name.startswith("group_"): + coef_name = name + break + if coef_name is None and len(coef_names) > 0: + coef_name = coef_names[0] + + self.logger.info(f"Applying LFC shrinkage with apeglm (coef={coef_name}) for {level} level...") + res_shrunk = deseq2.lfcShrink(dds, coef=coef_name, type="apeglm") + res_shrunk_df = robjects.conversion.rpy2py(r("as.data.frame")(res_shrunk)) + res_shrunk_df.index = count_data.index + + # Save shrunk results to file + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + outfile_shrunk = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}_shrunk.csv" + res_shrunk_df.to_csv(outfile_shrunk) + self.logger.info(f"Saved shrunk DESeq2 results to {outfile_shrunk}") + except Exception as e: + self.logger.warning(f"LFC shrinkage (apeglm) failed or unavailable: {e}") + + self.logger.info(f"DESeq2 run completed for {level}. Results shape: {res_df.shape}, Normalized counts shape: {normalized_counts_df.shape}, VST shape: {vst_counts_df.shape}") + return res_df, res_shrunk_df, normalized_counts_df, vst_counts_df + + except Exception as e: + self.logger.error(f"Error running DESeq2 for {level}: {str(e)}") + # Return empty DataFrames on error to avoid downstream issues + return pd.DataFrame(), pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns), pd.DataFrame(index=count_data.index, columns=count_data.columns) + + def _generate_dispersion_summary(self, results_df: pd.DataFrame, level: str) -> None: + """ + Generate summary statistics for average read counts and dispersion estimates. + Saves summary to a file and logs key statistics. 
+ + Args: + results_df: DESeq2 results DataFrame with baseMean and dispersion columns + level: Analysis level (gene/transcript) + """ + if results_df.empty: + self.logger.warning(f"Cannot generate dispersion summary for {level}: Results DataFrame is empty.") + return + + self.logger.info(f"Generating dispersion and count summary for {level} level...") + + # Check if required columns exist + required_cols = ['baseMean', 'dispersion'] + missing_cols = [col for col in required_cols if col not in results_df.columns] + if missing_cols: + self.logger.warning(f"Cannot generate complete summary for {level}: Missing columns {missing_cols}") + return + + # Remove NaN values for summary statistics + clean_data = results_df[['baseMean', 'dispersion']].dropna() + + if clean_data.empty: + self.logger.warning(f"No valid data for dispersion summary for {level} after removing NaN values.") + return + + # Calculate summary statistics + summary_stats = { + 'level': level, + 'total_features': len(results_df), + 'features_with_valid_data': len(clean_data), + + # Average read count (baseMean) statistics + 'baseMean_mean': clean_data['baseMean'].mean(), + 'baseMean_median': clean_data['baseMean'].median(), + 'baseMean_std': clean_data['baseMean'].std(), + 'baseMean_min': clean_data['baseMean'].min(), + 'baseMean_max': clean_data['baseMean'].max(), + 'baseMean_q25': clean_data['baseMean'].quantile(0.25), + 'baseMean_q75': clean_data['baseMean'].quantile(0.75), + + # Dispersion statistics + 'dispersion_mean': clean_data['dispersion'].mean(), + 'dispersion_median': clean_data['dispersion'].median(), + 'dispersion_std': clean_data['dispersion'].std(), + 'dispersion_min': clean_data['dispersion'].min(), + 'dispersion_max': clean_data['dispersion'].max(), + 'dispersion_q25': clean_data['dispersion'].quantile(0.25), + 'dispersion_q75': clean_data['dispersion'].quantile(0.75), + } + + # Add size factor statistics if available + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + size_factors_file = self.deseq_dir / f"size_factors_{level}_{target_label}_vs_{reference_label}.csv" + + if size_factors_file.exists(): + try: + size_factors_df = pd.read_csv(size_factors_file) + if 'size_factor' in size_factors_df.columns: + sf_data = size_factors_df['size_factor'].dropna() + summary_stats.update({ + 'size_factor_mean': sf_data.mean(), + 'size_factor_median': sf_data.median(), + 'size_factor_std': sf_data.std(), + 'size_factor_min': sf_data.min(), + 'size_factor_max': sf_data.max(), + 'size_factor_q25': sf_data.quantile(0.25), + 'size_factor_q75': sf_data.quantile(0.75), + }) + self.logger.info(f" Size factors: mean={summary_stats['size_factor_mean']:.4f}, median={summary_stats['size_factor_median']:.4f}, range={summary_stats['size_factor_min']:.4f}-{summary_stats['size_factor_max']:.4f}") + except Exception as e: + self.logger.warning(f"Could not read size factors file: {e}") + + # Log key statistics + self.logger.info(f"{level.capitalize()} level summary:") + self.logger.info(f" Total features: {summary_stats['total_features']}") + self.logger.info(f" Features with valid data: {summary_stats['features_with_valid_data']}") + self.logger.info(f" Average read count (baseMean): mean={summary_stats['baseMean_mean']:.2f}, median={summary_stats['baseMean_median']:.2f}") + self.logger.info(f" Dispersion: mean={summary_stats['dispersion_mean']:.4f}, median={summary_stats['dispersion_median']:.4f}") + + # Create a more detailed summary for significant DE genes/transcripts + if 'padj' in 
results_df.columns: + significant_features = results_df[results_df['padj'] < 0.05].dropna(subset=['baseMean', 'dispersion']) + if not significant_features.empty: + summary_stats.update({ + 'significant_features_count': len(significant_features), + 'significant_baseMean_mean': significant_features['baseMean'].mean(), + 'significant_baseMean_median': significant_features['baseMean'].median(), + 'significant_dispersion_mean': significant_features['dispersion'].mean(), + 'significant_dispersion_median': significant_features['dispersion'].median(), + }) + + self.logger.info(f" Significant DE features (padj < 0.05): {summary_stats['significant_features_count']}") + self.logger.info(f" Significant features - Average read count: mean={summary_stats['significant_baseMean_mean']:.2f}, median={summary_stats['significant_baseMean_median']:.2f}") + self.logger.info(f" Significant features - Dispersion: mean={summary_stats['significant_dispersion_mean']:.4f}, median={summary_stats['significant_dispersion_median']:.4f}") + + # Save summary to file + summary_file = self.deseq_dir / f"dispersion_count_summary_{level}_{target_label}_vs_{reference_label}.txt" + + with open(summary_file, 'w') as f: + f.write(f"Dispersion and Count Summary for {level.capitalize()} Level Analysis\n") + f.write(f"Comparison: {target_label} vs {reference_label}\n") + f.write("=" * 60 + "\n\n") + + for key, value in summary_stats.items(): + if isinstance(value, float): + f.write(f"{key}: {value:.6f}\n") + else: + f.write(f"{key}: {value}\n") + + self.logger.info(f"Dispersion and count summary saved to {summary_file}") + + # Also save detailed data for further analysis + detailed_file = self.deseq_dir / f"detailed_dispersion_data_{level}_{target_label}_vs_{reference_label}.csv" + + # Include feature mapping information if available + detailed_data = results_df[['baseMean', 'dispersion', 'log2FoldChange', 'pvalue', 'padj']].copy() + if 'gene_name' in results_df.columns: + detailed_data['gene_name'] = results_df['gene_name'] + if 'transcript_symbol' in results_df.columns: + detailed_data['transcript_symbol'] = results_df['transcript_symbol'] + + detailed_data.to_csv(detailed_file) + self.logger.info(f"Detailed dispersion data saved to {detailed_file}") + + def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: + """ + Map feature IDs to gene and transcript names using GeneMapper class. + + For transcripts that have been mapped to canonical IDs, ensure we properly handle the mapping. + + Args: + feature_ids: List of feature IDs (gene symbols or transcript IDs) + level: Analysis level ("gene" or "transcript") + + Returns: + Dict[str, Dict[str, Optional[str]]]: Mapping from feature ID to a dictionary + containing 'transcript_symbol' and 'gene_name'. + 'transcript_symbol' is None for gene-level analysis. 
+ """ + # Check if we need to handle canonical transcript IDs + if level == "transcript" and self.transcript_map: + # Create a mapping from canonical IDs to original IDs for reverse lookup + canonical_to_original = {} + for original, canonical in self.transcript_map.items(): + if canonical not in canonical_to_original: + canonical_to_original[canonical] = [] + canonical_to_original[canonical].append(original) + + # Process feature_ids that may include canonical IDs + result = {} + for feature_id in feature_ids: + # First try to map directly + direct_map = self.gene_mapper.map_gene_symbols([feature_id], level, self.updated_gene_dict) + + # If direct mapping worked, use it + if feature_id in direct_map and direct_map[feature_id]["gene_name"]: + result[feature_id] = direct_map[feature_id] + continue + + # If this is a canonical ID, try to map using one of its original IDs + if feature_id in canonical_to_original: + for original_id in canonical_to_original[feature_id]: + original_map = self.gene_mapper.map_gene_symbols([original_id], level, self.updated_gene_dict) + if original_id in original_map and original_map[original_id]["gene_name"]: + # Use the original ID's mapping but keep the canonical ID as the transcript symbol + result[feature_id] = { + "transcript_symbol": feature_id, + "gene_name": original_map[original_id]["gene_name"] + } + self.logger.debug(f"Mapped canonical ID {feature_id} using original ID {original_id}") + break + + # If still not mapped, use a default mapping + if feature_id not in result: + result[feature_id] = { + "transcript_symbol": feature_id, + "gene_name": feature_id.split('.')[0] if '.' in feature_id else feature_id + } + + return result + + # For gene level or when no transcript mapping is available, use the original method + return self.gene_mapper.map_gene_symbols(feature_ids, level, self.updated_gene_dict) + + def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: + """Write top genes/transcripts based on absolute statistic value to file.""" + if results.empty or 'stat' not in results.columns: + self.logger.warning(f"Cannot write top genes for {level}: Results DataFrame is empty or missing 'stat' column.") + return + + # Ensure 'stat' column is numeric, fill NaNs that might cause issues + results['stat'] = pd.to_numeric(results['stat'], errors='coerce').fillna(0) + results["abs_stat"] = abs(results["stat"]) + + # Use configured number of top genes/transcripts + top_n = self.top_n_genes + + if level == "transcript": + # Use configured base mean threshold + base_mean_threshold = self.top_transcripts_base_mean + # Ensure 'baseMean' column is numeric, fill NaNs + if 'baseMean' not in results.columns: + self.logger.warning(f"Cannot apply baseMean filter for {level}: 'baseMean' column missing. Considering all transcripts.") + filtered_results = results + else: + results['baseMean'] = pd.to_numeric(results['baseMean'], errors='coerce').fillna(0) + filtered_results = results[results["baseMean"] > base_mean_threshold] + + if filtered_results.empty: + self.logger.warning(f"No transcripts found with baseMean > {base_mean_threshold}. 
Top genes file will be empty.") + top_unique_gene_transcripts_df = pd.DataFrame() # Empty dataframe + else: + # Sort by absolute statistic value + top_transcripts = filtered_results.sort_values("abs_stat", ascending=False) + + # Ensure 'gene_name' column exists + if 'gene_name' not in top_transcripts.columns: + self.logger.error(f"Cannot extract top unique genes for {level}: 'gene_name' column missing.") + return + + # Get top N unique genes based on the highest ranked transcript for each gene + top_unique_gene_transcripts_df = top_transcripts.drop_duplicates(subset=['gene_name'], keep='first').head(top_n) + self.logger.info(f"Highest adjusted p-value in top {top_n} unique genes: {top_unique_gene_transcripts_df['padj'].max()}") + + top_genes_list = top_unique_gene_transcripts_df["gene_name"].tolist() if not top_unique_gene_transcripts_df.empty else [] + + # Write to file + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + top_genes_file = self.deseq_dir / f"genes_of_top_{top_n}_DE_transcripts_{target_label}_vs_{reference_label}.txt" + + pd.Series(top_genes_list).to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote {len(top_genes_list)} unique genes (from top {top_n} DE transcripts with baseMean > {base_mean_threshold}) to {top_genes_file}") + + else: # Gene level + # Ensure 'gene_name' column exists for gene level as well + if 'gene_name' not in results.columns: + self.logger.error(f"Cannot extract top genes for {level}: 'gene_name' column missing.") + return + + # Get top N genes directly by absolute statistic + top_genes_df = results.nlargest(top_n, "abs_stat") + top_genes_list = top_genes_df["gene_name"].tolist() + + # Write to file + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + top_genes_file = self.deseq_dir / f"top_{top_n}_DE_genes_{target_label}_vs_{reference_label}.txt" + + pd.Series(top_genes_list).to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote top {len(top_genes_list)} DE genes to {top_genes_file}") + + def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label, is_vst: bool = False): + """Run PCA analysis and create visualization using DESeq2 normalized counts.""" + self.logger.info(f"Running PCA for {level} level using DESeq2 normalized counts...") + + if normalized_counts.empty: + self.logger.warning(f"Skipping PCA for {level}: Normalized counts data is empty.") + return + + # Basic check for variance - PCA fails if variance is zero + if normalized_counts.var().sum() == 0: + self.logger.warning(f"Skipping PCA for {level}: Data has zero variance.") + return + + # Use configured number of components + n_components = min(self.pca_n_components, normalized_counts.shape[0], normalized_counts.shape[1]) # Cannot exceed number of features or samples + if n_components < 2: + self.logger.warning(f"Skipping PCA for {level}: Not enough features/samples ({normalized_counts.shape}) for {n_components} components.") + return + if n_components != self.pca_n_components: + self.logger.warning(f"Reducing number of PCA components to {n_components} due to data dimensions.") + + + # Prepare matrix for PCA + # If using VST counts, do not log-transform again + if is_vst: + matrix_for_pca = normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0) + else: + # Log transform the DESeq2 normalized counts (add 1 to handle zeros) + # Ensure data is numeric before transformation + matrix_for_pca = 
np.log2(normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0) + 1) + + + # Check for NaNs/Infs after log transform which can happen if counts were negative (though clamped earlier) or exactly -1 + if np.isinf(matrix_for_pca).any().any() or np.isnan(matrix_for_pca).any().any(): + self.logger.warning(f"NaNs or Infs found in matrix for PCA for {level}. Replacing with 0. This might indicate issues with count data.") + matrix_for_pca = matrix_for_pca.replace([np.inf, -np.inf], 0).fillna(0) + + + try: + pca = PCA(n_components=n_components) + # Transpose because PCA expects samples as rows, features as columns + pca_result = pca.fit_transform(matrix_for_pca.transpose()) + + # Map feature IDs (index of normalized_counts) to gene names + feature_ids = normalized_counts.index.tolist() + # Use the mapping function - ensure it handles potential errors/missing keys + gene_mapping_dict = self._map_gene_symbols(feature_ids, level) + # Create a list of gene names in the same order as features + feature_names_mapped = [gene_mapping_dict.get(fid, {}).get('gene_name', fid) for fid in feature_ids] + + + # Get explained variance ratio and loadings + explained_variance = pca.explained_variance_ratio_ + loadings = pca.components_ # Loadings are in pca.components_ + + # Create DataFrame with columns for all calculated components + pc_columns = [f'PC{i+1}' for i in range(n_components)] + pca_df = pd.DataFrame(data=pca_result[:, :n_components], columns=pc_columns, index=matrix_for_pca.columns) # Use sample names as index + + # Add group information from coldata, ensuring index alignment + # It's safer to reset index on coldata if it uses sample names as index too + if coldata.index.equals(pca_df.index): + pca_df['group'] = coldata['group'].values + else: + self.logger.warning(f"Index mismatch between PCA results and coldata for {level}. Group information might be incorrect.") + # Attempt to merge or handle, here just assigning potentially misaligned + pca_df['group'] = coldata['group'].values[:len(pca_df)] + + + # Title focuses on PC1/PC2 for the scatter plot, even if more components were calculated + pc1_var = explained_variance[0] * 100 if len(explained_variance) > 0 else 0 + pc2_var = explained_variance[1] * 100 if len(explained_variance) > 1 else 0 + title = f"{level.capitalize()} Level PCA: {target_label} vs {reference_label}\nPC1 ({pc1_var:.2f}%) / PC2 ({pc2_var:.2f}%)" + + + # Use the plotter's PCA method, passing explained variance and loadings + self.visualizer.plot_pca( + pca_df=pca_df, # pca_df contains n_components columns + title=title, + output_prefix=f"pca_{level}", + explained_variance=explained_variance, # Pass full explained variance for scree plot + loadings=loadings, # Pass loadings + # Pass the mapped gene names corresponding to the features (rows of normalized_counts) + feature_names=feature_names_mapped + ) + self.logger.info(f"PCA plots saved for {level} level.") + + except Exception as e: + self.logger.error(f"Error during PCA calculation or plotting for {level}: {str(e)}") + + def _merge_technical_replicates(self, count_data: pd.DataFrame) -> pd.DataFrame: + """ + Merge technical replicates by summing counts for samples in the same replicate group. 
+ + Args: + count_data: DataFrame with samples as columns and features as rows + + Returns: + DataFrame with technical replicates merged + """ + if not self.tech_rep_dict: + self.logger.info("No technical replicates specified, returning original data") + return count_data + + self.logger.info(f"Merging technical replicates using {len(self.tech_rep_dict)} mappings") + + # Create a mapping from sample columns to replicate groups + sample_to_group = {} + for col in count_data.columns: + # Extract the base sample name (remove condition prefix if present) + base_sample = col + for condition in self.ref_conditions + self.target_conditions: + if col.startswith(f"{condition}_"): + base_sample = col[len(condition)+1:] + break + + # Check if this sample is in the technical replicates mapping + if base_sample in self.tech_rep_dict: + group_name = self.tech_rep_dict[base_sample] + # Reconstruct the group name with condition prefix + condition_prefix = col.replace(base_sample, "").rstrip("_") + if condition_prefix: + full_group_name = f"{condition_prefix}_{group_name}" + else: + full_group_name = group_name + sample_to_group[col] = full_group_name + else: + # Keep original sample name if not in technical replicates + sample_to_group[col] = col + + # Group samples by their replicate groups + group_to_samples = {} + for sample, group in sample_to_group.items(): + if group not in group_to_samples: + group_to_samples[group] = [] + group_to_samples[group].append(sample) + + # Create merged DataFrame + merged_data = pd.DataFrame(index=count_data.index) + + merge_stats = {"merged_groups": 0, "original_samples": len(count_data.columns)} + + for group_name, samples in group_to_samples.items(): + if len(samples) == 1: + # No merging needed, just rename + merged_data[group_name] = count_data[samples[0]] + else: + # Sum technical replicates + merged_data[group_name] = count_data[samples].sum(axis=1) + merge_stats["merged_groups"] += 1 + self.logger.debug(f"Merged technical replicates for group {group_name}: {samples}") + + merge_stats["final_samples"] = len(merged_data.columns) + self.logger.info( + f"Technical replicate merging complete: " + f"{merge_stats['original_samples']} samples -> {merge_stats['final_samples']} samples " + f"({merge_stats['merged_groups']} groups had multiple replicates)" + ) + + return merged_data \ No newline at end of file diff --git a/src/visualization_gsea.py b/src/visualization_gsea.py new file mode 100644 index 00000000..a49a0e76 --- /dev/null +++ b/src/visualization_gsea.py @@ -0,0 +1,281 @@ +import logging +from pathlib import Path +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from rpy2 import robjects +from rpy2.robjects import r, pandas2ri +from rpy2.robjects.packages import importr +from rpy2.robjects.conversion import localconverter +from rpy2.rinterface_lib import callbacks +from matplotlib.patches import Patch + + +class GSEAAnalysis: + def __init__(self, output_path: Path): + """ + Initialize GSEA analysis. + + Args: + output_path: Path to save GSEA results + """ + self.output_path = Path(output_path) / "gsea_results" + self.output_path.mkdir(parents=True, exist_ok=True) + + # Configure R to be quiet + + + def quiet_cb(x): + pass + + callbacks.logger.setLevel(logging.WARNING) + callbacks.consolewrite_print = quiet_cb + callbacks.consolewrite_warnerror = quiet_cb + + def run_gsea_analysis(self, results: pd.DataFrame, target_label: str) -> None: + """ + Run GSEA analysis using DESeq2 stat value as ranking metric. 
+ Creates visualizations for top enriched pathways in each GO category. + + Args: + results: DataFrame containing DESeq2 results + target_label: Label indicating the comparison being made + """ + if results is None or results.empty: + logging.error("No DESeq2 results provided for GSEA analysis") + return + + logging.info("Starting GSEA analysis...") + logging.debug(f"Full DE results shape: {results.shape}") + logging.debug(f"DE results columns: {results.columns.tolist()}") + + # Don't filter for significant genes - use ALL genes with valid statistics + # Just remove NaN values + valid_genes = results.dropna(subset=["stat", "gene_name"]) + logging.debug(f"Genes with valid statistics: {valid_genes.shape}") + + if valid_genes.empty: + logging.info("No genes with valid statistics found for GSEA.") + return + + # Use gene_name instead of symbol + gene_symbols = valid_genes["gene_name"].values + + # Create ranked list using ALL genes (not just significant ones) + ranked_genes = pd.Series( + valid_genes["stat"].values, index=gene_symbols + ).dropna() + ranked_genes = ranked_genes[~ranked_genes.index.duplicated(keep="first")] + logging.debug(f"Final ranked genes count: {len(ranked_genes)}") + + if ranked_genes.empty or len(ranked_genes) < 50: # Ensure we have enough genes + logging.info(f"Not enough valid ranked genes for GSEA: {len(ranked_genes)}") + return + + # Save the ranked genes + ranked_outfile = self.output_path / "ranked_genes.csv" + ranked_genes_df = pd.DataFrame( + {"gene": ranked_genes.index, "rank": ranked_genes.values} + ) + ranked_genes_df.to_csv(ranked_outfile, index=False) + logging.info(f"Ranked genes saved to {ranked_outfile}") + + # Import required R packages + clusterProfiler = importr("clusterProfiler") + r("library(org.Hs.eg.db)") + + with localconverter(robjects.default_converter + pandas2ri.converter): + r_ranked_genes = pandas2ri.py2rpy(ranked_genes.sort_values(ascending=False)) + + def plot_pathways(up_df: pd.DataFrame, down_df: pd.DataFrame, ont: str): + if up_df.empty and down_df.empty: + logging.info(f"No pathways to plot for {ont}.") + return + + # Process DataFrame if not empty + if not up_df.empty: + # Remove GO IDs from labels, keeping only the description + up_df["label"] = up_df["Description"] + up_df["-log10(p.adjust)"] = -np.log10(up_df["p.adjust"]) + up_df = up_df.sort_values(by="NES", ascending=False) + + if not down_df.empty: + # Remove GO IDs from labels, keeping only the description + down_df["label"] = down_df["Description"] + down_df["-log10(p.adjust)"] = -np.log10(down_df["p.adjust"]) + down_df = down_df.sort_values(by="NES", ascending=True) + + # Find the global min and max for -log10(p.adjust) for consistent coloring + all_pvals = [] + if not up_df.empty: + all_pvals.extend(up_df["-log10(p.adjust)"].tolist()) + if not down_df.empty: + all_pvals.extend(down_df["-log10(p.adjust)"].tolist()) + + if not all_pvals: + return # Skip if no values + + # Get global min and max p-values + global_vmin = min(all_pvals) + global_vmax = max(all_pvals) + + # Adjust the maximum value to prevent saturation of highly significant pathways + # Use either actual max or a higher percentile value, whichever is higher + # This prevents all highly significant pathways from appearing with the same color + if len(all_pvals) > 1: + # Calculate 90th percentile of p-values + percentile_90 = np.percentile(all_pvals, 90) + + # If max is much larger than 90th percentile, use an intermediate value + if global_vmax > 2 * percentile_90: + adjusted_vmax = percentile_90 + (global_vmax - 
percentile_90) / 3 + # But ensure we don't lower the max too much + global_vmax = max(adjusted_vmax, global_vmax * 0.7) + + # Log the adjustment for debugging + logging.debug(f"P-value color scale: original max={max(all_pvals):.2f}, adjusted max={global_vmax:.2f}") + + # Create a consistent color normalization across both plots + norm = plt.Normalize(vmin=global_vmin, vmax=global_vmax) + cmap = plt.cm.get_cmap("viridis") + + # Split target label into reference and target parts + target_parts = target_label.split("_vs_") + target_condition = target_parts[0] + ref_condition = target_parts[1] + + # Plot UP-regulated pathways + if not up_df.empty: + up_values = up_df["-log10(p.adjust)"] + up_colors = [cmap(norm(v)) for v in up_values] + + # Adjust figure size - no need for extra space for legend + plt.figure(figsize=(12, 10)) + + # Create horizontal bar plot + bars = plt.barh( + up_df["label"].iloc[::-1], + up_df["NES"].iloc[::-1], + color=up_colors[::-1], + ) + + # Remove the p-value text labels + + # Add colorbar with the global scale + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label("-log10(adjusted p-value)", fontsize=12) + + plt.xlabel("Normalized Enrichment Score (NES)", fontsize=12) + condition_str = f"Pathways enriched in {target_condition}\nvs {ref_condition} - {ont}" + plt.title(condition_str, fontsize=14) + + # Ensure y-axis labels are fully visible + plt.tight_layout() + plt.subplots_adjust(left=0.3) # Add more space on the left for labels + + plot_path = self.output_path / f"GSEA_top_pathways_up_{ont}.pdf" + plt.savefig(plot_path, format="pdf", bbox_inches="tight", dpi=600) + plt.close() + logging.info(f"GSEA up-regulated pathways plot saved to {plot_path} with high resolution") + + # Plot DOWN-regulated pathways + if not down_df.empty: + down_values = down_df["-log10(p.adjust)"] + down_colors = [cmap(norm(v)) for v in down_values] + + # Adjust figure size - no need for extra space for legend + plt.figure(figsize=(12, 10)) + + # Create horizontal bar plot + bars = plt.barh( + down_df["label"].iloc[::-1], + down_df["NES"].abs().iloc[::-1], # Use absolute NES for down-regulated + color=down_colors[::-1], + ) + + # Add colorbar with the global scale + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label("-log10(adjusted p-value)", fontsize=12) + + plt.xlabel("Absolute Normalized Enrichment Score (|NES|)", fontsize=12) + condition_str = f"Pathways enriched in {ref_condition}\nvs {target_condition} - {ont}" + plt.title(condition_str, fontsize=14) + + # Ensure y-axis labels are fully visible + plt.tight_layout() + plt.subplots_adjust(left=0.3) # Add more space on the left for labels + + plot_path = self.output_path / f"GSEA_top_pathways_down_{ont}.pdf" + plt.savefig(plot_path, format="pdf", bbox_inches="tight", dpi=600) + plt.close() + logging.info(f"GSEA down-regulated pathways plot saved to {plot_path} with high resolution") + + # Run GO analysis for each ontology + ontologies = ["BP", "MF", "CC"] + for ont in ontologies: + logging.debug(f"Running gseGO for {ont}...") + + gsea_res = clusterProfiler.gseGO( + geneList=r_ranked_genes, + OrgDb="org.Hs.eg.db", + keyType="SYMBOL", + ont=ont, + minGSSize=5, + maxGSSize=1000, + pvalueCutoff=1, + verbose=True, + nPermSimple=10000, + ) + + gsea_table = r("data.frame")(gsea_res) + with localconverter(robjects.default_converter + pandas2ri.converter): + gsea_df = pandas2ri.rpy2py(gsea_table) + + # Log detailed results + 
logging.debug(f"GSEA results for {ont}:") + logging.debug(f" Total pathways tested: {len(gsea_df)}") + if not gsea_df.empty: + logging.debug( + f" P-value range: {gsea_df['pvalue'].min():.2e} - {gsea_df['pvalue'].max():.2e}" + ) + logging.debug( + f" Adjusted p-value range: {gsea_df['p.adjust'].min():.2e} - {gsea_df['p.adjust'].max():.2e}" + ) + logging.debug( + f" NES range: {gsea_df['NES'].min():.2f} - {gsea_df['NES'].max():.2f}" + ) + logging.debug( + f" Pathways with adj.P<0.1: {len(gsea_df[gsea_df['p.adjust'] < 0.1])}" + ) + logging.debug( + f" Pathways with adj.P<0.05: {len(gsea_df[gsea_df['p.adjust'] < 0.05])}" + ) + + # Save all results + gsea_outfile = self.output_path / f"GSEA_results_{ont}.csv" + gsea_df.to_csv(gsea_outfile, index=False) + logging.info(f"Complete GSEA results for {ont} saved to {gsea_outfile}") + + # Process significant pathways + sig_gsea_df = gsea_df[ + gsea_df["p.adjust"] < 0.05 + ].copy() # Using 0.05 threshold + + if not sig_gsea_df.empty: + up_pathways = sig_gsea_df[sig_gsea_df["NES"] > 0].nsmallest( + 15, "p.adjust" + ) + down_pathways = sig_gsea_df[sig_gsea_df["NES"] < 0].nsmallest( + 15, "p.adjust" + ) + + # Use consistent color scales across both plots + plot_pathways(up_pathways, down_pathways, ont) + else: + logging.info(f"No pathways with adj.P<0.05 found for {ont}") + + logging.info("GSEA analysis completed.") diff --git a/src/visualization_mapping.py b/src/visualization_mapping.py new file mode 100644 index 00000000..703b121a --- /dev/null +++ b/src/visualization_mapping.py @@ -0,0 +1,209 @@ +import mygene +import logging +from typing import Dict, List, Tuple, Optional + +class GeneMapper: + def __init__(self): + self.mg = mygene.MyGeneInfo() + self.logger = logging.getLogger('IsoQuant.visualization.mapping') + + def get_gene_info_from_mygene(self, ensembl_ids: List[str]) -> Dict[str, Dict]: + """ + Query MyGene.info API for gene information using batch query. + + Args: + ensembl_ids: List of Ensembl gene IDs + + Returns: + Dict mapping query IDs to gene information + """ + try: + # Batch query for gene information + results = self.mg.querymany( + ensembl_ids, + scopes='ensembl.gene', # Only search for gene IDs + fields=['symbol', 'name'], # Only get essential fields + species='human', + as_dataframe=False, + returnall=True + ) + + # Process results + mapping = {} + for hit in results['out']: + query_id = hit.get('query', '') + if 'notfound' in hit: + self.logger.debug(f"Gene ID not found: {query_id}") + continue + + mapping[query_id] = { + 'symbol': hit.get('symbol', query_id), + 'name': hit.get('name', hit.get('symbol', query_id)) + } + + # Log query statistics + self.logger.debug( + f"MyGene.info query stats: " + f"Total={len(ensembl_ids)}, " + f"Found={len(mapping)}, " + f"Missing={len(ensembl_ids) - len(mapping)}" + ) + + return mapping + + except Exception as e: + self.logger.error(f"Failed to fetch info from MyGene.info: {str(e)}") + return {} + + def map_genes(self, gene_ids: List[str], updated_gene_dict: Dict) -> Dict[str, Tuple[str, str]]: + """ + Map Ensembl gene IDs to symbols, using multiple fallback methods: + 1. IsoQuant's updated_gene_dict + 2. MyGene.info + 3. 
Parse symbol from Ensembl ID if possible + + Returns: + Dict mapping Ensembl IDs to (symbol, gene_name) tuples + """ + mapping = {} + unmapped_ids = [] + + # First try to map using updated_gene_dict + for gene_id in gene_ids: + symbol_found = False + for gene_category, genes in updated_gene_dict.items(): + if gene_id in genes: + gene_info = genes[gene_id] + # Only use name if it's not empty and not the same as gene_id + if gene_info.get("name") and gene_info["name"] != gene_id: + mapping[gene_id] = (gene_info["name"], gene_info["name"]) + symbol_found = True + break + + if not symbol_found: + unmapped_ids.append(gene_id) + + # For unmapped genes, try MyGene.info batch query + if unmapped_ids: + self.logger.debug(f"Querying MyGene.info for {len(unmapped_ids)} unmapped genes") + mygene_results = self.get_gene_info_from_mygene(unmapped_ids) + + remaining_unmapped = [] + for gene_id in unmapped_ids: + if gene_id in mygene_results: + info = mygene_results[gene_id] + mapping[gene_id] = (info['symbol'], info['name']) + self.logger.debug(f"Mapped {gene_id} to {info['symbol']} using MyGene.info") + else: + remaining_unmapped.append(gene_id) + + # For still unmapped genes, try to extract info from Ensembl ID + for gene_id in remaining_unmapped: + # Try to extract meaningful info from Ensembl ID + if gene_id.startswith('ENSG'): + # For novel genes, use the last part of the ID as a temporary symbol + temp_symbol = f"GENE_{gene_id.split('0')[-1]}" + mapping[gene_id] = (temp_symbol, gene_id) + self.logger.warning(f"Using derived symbol {temp_symbol} for {gene_id}") + else: + mapping[gene_id] = (gene_id, gene_id) + self.logger.warning(f"Could not map {gene_id} using any method") + + return mapping + + def map_gene_symbols(self, feature_ids: List[str], level: str, updated_gene_dict: Dict = None) -> Dict[str, Dict[str, Optional[str]]]: + """ + Map feature IDs to gene and transcript names using updated gene dictionary. + + Args: + feature_ids: List of feature IDs (gene symbols or transcript IDs) + level: Analysis level ("gene" or "transcript") + updated_gene_dict: Optional updated gene dictionary + + Returns: + Dict[str, Dict[str, Optional[str]]]: Mapping from feature ID to a dictionary + containing 'transcript_symbol' and 'gene_name'. + 'transcript_symbol' is None for gene-level analysis. 
+ """ + mapping: Dict[str, Dict[str, Optional[str]]] = {} + unmapped_gene_ids_batch: List[str] = [] # Initialize list to collect unmapped gene IDs for batch query + + for feature_id in feature_ids: + if level == "gene": + # Gene-level mapping: Search in updated_gene_dict, fallback to batched MyGene API + gene_name = None + found_in_dict = False + if updated_gene_dict: + for condition, condition_gene_dict in updated_gene_dict.items(): + if feature_id in condition_gene_dict: + found_in_dict = True + gene_name = condition_gene_dict[feature_id].get("name") + break + if not found_in_dict: + unmapped_gene_ids_batch.append(feature_id) # Add to batch list for MyGene query + else: + unmapped_gene_ids_batch.append(feature_id) # Add to batch list for MyGene query + + + mapping[feature_id] = { + "transcript_symbol": gene_name, # For gene-level, use gene name as transcript_symbol + "gene_name": gene_name if gene_name else feature_id + } + + elif level == "transcript": + # Transcript-level mapping: Search for transcript name in updated_gene_dict across all conditions + gene_name = None + transcript_symbol = None + gene_found_for_transcript = False # Flag to track if gene is found for transcript + + if updated_gene_dict: + for condition, condition_gene_dict in updated_gene_dict.items(): # Iterate through conditions + for gene_id, gene_data in condition_gene_dict.items(): # Iterate through genes in each condition + if "transcripts" in gene_data and feature_id in gene_data["transcripts"]: + gene_found_for_transcript = True + transcript_info = gene_data["transcripts"].get(feature_id, {}) + transcript_symbol = transcript_info.get("name") + mapping[feature_id] = { + "transcript_symbol": f"{transcript_symbol} ({gene_data.get('name')})" if feature_id.startswith("transcript") else transcript_symbol, + "gene_name": gene_data.get("name") # Get gene_name from gene_data + } + self.logger.debug(f"Transcript-level mapping: Found transcript {feature_id}, gene_data: {gene_data}") # Debug log to inspect gene_data + break # Found transcript, exit inner loop (genes in condition) + if gene_found_for_transcript: # If transcript found in any gene in this condition, exit condition loop + break + if not gene_found_for_transcript: + self.logger.debug(f"Transcript-level mapping: No gene found for Transcript ID {feature_id} in updated_gene_dict across any condition") + mapping[feature_id] = { # Assign mapping here for not found case + "transcript_symbol": f"{feature_id} (No gene name)", # Indicate no gene name available + "gene_name": None + } + else: # If updated_gene_dict is None + self.logger.debug("Transcript-level mapping: updated_gene_dict is None") + mapping[feature_id] = { + "transcript_symbol": f"{feature_id} (No gene name)", # Indicate no gene name available when dict is None + "gene_name": None # gene_name is None + } + self.logger.debug(f"Transcript-level mapping: Using feature_id as transcript_symbol, no gene name available (updated_gene_dict is None)") # Debug log + + else: + raise ValueError(f"Invalid level: {level}. 
Must be 'gene' or 'transcript'.") + + # Perform batched MyGene API query for all unmapped gene IDs at once (gene-level only) + if level == "gene" and unmapped_gene_ids_batch: + self.logger.debug(f"Gene-level mapping: Performing batched MyGene API query for {len(unmapped_gene_ids_batch)} gene IDs") + mygene_batch_info = self.get_gene_info_from_mygene(unmapped_gene_ids_batch) # Batched query + + if mygene_batch_info: + for feature_id in unmapped_gene_ids_batch: # Iterate through the unmapped IDs + if feature_id in mygene_batch_info: # Check if MyGene returned info for this ID + gene_name_from_mygene = mygene_batch_info[feature_id].get('symbol') + if gene_name_from_mygene: + mapping[feature_id]["gene_name"] = gene_name_from_mygene # Update gene_name in mapping + mapping[feature_id]["transcript_symbol"] = gene_name_from_mygene # Update transcript_symbol + else: + self.logger.debug(f"Gene-level mapping: Batched MyGene API did not return info for Feature ID {feature_id}") + else: + self.logger.warning("Gene-level mapping: Batched MyGene API query failed or returned no results.") + + + return mapping \ No newline at end of file diff --git a/src/visualization_output_config.py b/src/visualization_output_config.py new file mode 100644 index 00000000..99df8452 --- /dev/null +++ b/src/visualization_output_config.py @@ -0,0 +1,1056 @@ +import csv +import os +import pickle +import gzip +import shutil +from argparse import Namespace +import yaml +from typing import List +import logging +import re +from pathlib import Path + +class OutputConfig: + """Class to build dictionaries from the output files of the pipeline.""" + + def __init__( + self, + output_directory: str, + ref_only: bool = False, + gtf: str = None, + technical_replicates: str = None, + ): + self.output_directory = output_directory + self.log_details = {} + self.extended_annotation = None + self.read_assignments = None + self.input_gtf = gtf + self.genedb_filename = None + self.yaml_input = True + self.yaml_input_path = None + self.gtf_flag_needed = False + self._conditions = None + self.gene_grouped_counts = None + self.transcript_grouped_counts = None + self.transcript_grouped_tpm = None + self.gene_grouped_tpm = None + self.gene_counts = None + self.transcript_counts = None + self.gene_tpm = None + self.transcript_tpm = None + self.transcript_model_counts = None + self.transcript_model_tpm = None + self.transcript_model_grouped_tpm = None + self.transcript_model_grouped_counts = None + self.ref_only = ref_only + + # Extended annotation handling + self.sample_extended_gtfs = [] + self.merged_extended_gtf = None + + # Attributes to store sample-level transcript model data + self.samples = [] + self.sample_transcript_model_tpm = {} + self.sample_transcript_model_counts = {} + + # Transcript mapping + self.transcript_map = {} # Maps transcript IDs to canonical transcript ID with same exon structure + + # Technical replicates + self.technical_replicates_spec = technical_replicates + self.technical_replicates_dict = {} + self._has_technical_replicates = False + self._has_biological_replicates = None # Will be computed when needed + + self._load_params_file() + self._find_files() + self._conditional_unzip() + + # Parse technical replicates after initialization + if self.technical_replicates_spec: + self.technical_replicates_dict = self._parse_technical_replicates(self.technical_replicates_spec) + self._has_technical_replicates = bool(self.technical_replicates_dict) + + # Ensure input_gtf is provided if ref_only is set and input_gtf is not found 
in the log + if self.ref_only and not self.input_gtf: + raise ValueError( + "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." + ) + + def _load_params_file(self): + """Load the .params file for necessary configuration and commands.""" + params_path = os.path.join(self.output_directory, ".params") + assert os.path.exists(params_path), f"Params file not found: {params_path}" + try: + with open(params_path, "rb") as file: + params = pickle.load(file) + if isinstance(params, Namespace): + self._process_params(vars(params)) + else: + logging.warning("Unexpected params format.") + except Exception as e: + raise ValueError(f"An error occurred while loading params: {e}") + + def _process_params(self, params): + """Process parameters loaded from the .params file.""" + self.log_details["gene_db"] = params.get("genedb") + self.log_details["fastq_used"] = bool(params.get("fastq")) + self.input_gtf = self.input_gtf or params.get("genedb") + + # Handle genedb_filename with fallback mechanism + original_genedb_filename = params.get("genedb_filename") + self.genedb_filename = self._find_genedb_file(original_genedb_filename) + + if params.get("yaml"): + # YAML input case + self.yaml_input = True + self.yaml_input_path = params.get("yaml") + # Keep the output_directory as is, don't modify it + else: + # Non-YAML input case + self.yaml_input = False + processing_sample = params.get("prefix") + if processing_sample: + self.output_directory = os.path.join( + self.output_directory, processing_sample + ) + else: + raise ValueError( + "Processing sample directory not found in params for non-YAML input." + ) + + def _find_genedb_file(self, original_path): + """Find genedb file with fallback mechanism.""" + from pathlib import Path + + # If no original path provided, skip to fallback + if original_path: + original_path_obj = Path(original_path) + if original_path_obj.exists(): + logging.info(f"Using original genedb file: {original_path}") + return original_path + else: + logging.warning(f"Original genedb file not found: {original_path}") + + # Fallback: Look for .db files in the output directory + output_path = Path(self.output_directory) + + # Look for .db files in the output directory + db_files = list(output_path.glob("*.db")) + + if db_files: + # Prefer files with common GTF database names + preferred_patterns = ["gtf.db", "gene.db", "genedb.db", "annotation.db"] + + # First, try to find files matching preferred patterns + for pattern in preferred_patterns: + for db_file in db_files: + if pattern in db_file.name.lower(): + logging.info(f"Found fallback genedb file (preferred pattern): {db_file}") + return str(db_file) + + # If no preferred pattern found, use the first .db file + fallback_db = db_files[0] + logging.info(f"Found fallback genedb file: {fallback_db}") + return str(fallback_db) + + # Last resort: check if we're in a subdirectory and look one level up + parent_db_files = list(output_path.parent.glob("*.db")) + if parent_db_files: + fallback_db = parent_db_files[0] + logging.info(f"Found fallback genedb file in parent directory: {fallback_db}") + return str(fallback_db) + + # No .db file found anywhere + if original_path: + logging.error(f"No genedb file found. 
Original path '{original_path}' doesn't exist, and no .db files found in '{output_path}' or parent directory.") + else: + logging.error(f"No genedb file found in '{output_path}' or parent directory, and no original path provided.") + + return original_path # Return original even if it doesn't exist, let the caller handle the error + + def _conditional_unzip(self): + """Check if unzip is needed and perform it conditionally based on the model use.""" + if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"): + self.input_gtf = self._unzip_file(self.input_gtf) + if not self.input_gtf: + raise FileNotFoundError( + f"Unable to find or unzip the specified file: {self.input_gtf}" + ) + + def _unzip_file(self, file_path): + """Unzip a gzipped file and return the path to the uncompressed file.""" + new_path = file_path[:-3] # Remove .gz extension + + if os.path.exists(new_path): + return new_path + + if not os.path.exists(file_path): + self.gtf_flag_needed = True + return None + + with gzip.open(file_path, "rb") as f_in: + with open(new_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + logging.info(f"File {file_path} was decompressed to {new_path}.") + + return new_path + + def _find_files(self): + """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" + if self.yaml_input: + self.conditions = True + self._find_files_from_yaml() + return # Exit the method after processing YAML input + + if not os.path.exists(self.output_directory): + logging.error(f"Directory not found: {self.output_directory}") + raise FileNotFoundError( + f"Specified sample subdirectory does not exist: {self.output_directory}" + ) + + for file_name in os.listdir(self.output_directory): + if file_name.endswith(".extended_annotation.gtf"): + self.extended_annotation = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".read_assignments.tsv"): + self.read_assignments = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".read_assignments.tsv.gz"): + # Prefer streaming gzip rather than unzipping + self.read_assignments = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".gene_grouped_counts.tsv"): + self._conditions = self._get_conditions_from_file( + os.path.join(self.output_directory, file_name) + ) + self.gene_grouped_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_grouped_counts.tsv"): + self.transcript_grouped_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_grouped_tpm.tsv"): + self.transcript_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".gene_grouped_tpm.tsv"): + self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".gene_counts.tsv"): + self.gene_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_counts.tsv"): + self.transcript_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".gene_tpm.tsv"): + self.gene_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_tpm.tsv"): + self.transcript_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_model_counts.tsv"): + self.transcript_model_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_tpm.tsv"): + self.transcript_model_tpm = os.path.join( + 
self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_tpm.tsv"): + self.transcript_model_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_counts.tsv"): + self.transcript_model_grouped_counts = os.path.join( + self.output_directory, file_name + ) + + # Determine if GTF flag is needed + if ( + not self.input_gtf + or ( + not os.path.exists(self.input_gtf) + and not os.path.exists(self.input_gtf + ".gz") + ) + and self.ref_only + ): + self.gtf_flag_needed = True + + # Set ref_only default based on the availability of extended_annotation + if self.ref_only is None: + self.ref_only = not self.extended_annotation + + def _find_files_from_yaml(self): + """Locate files and samples from YAML, apply filters to ensure only valid samples are processed.""" + if not os.path.exists(self.yaml_input_path): + logging.error(f"YAML file not found: {self.yaml_input_path}") + raise FileNotFoundError( + f"Specified YAML file does not exist: {self.yaml_input_path}" + ) + + # Set these attributes based on YAML input expectations + self.gene_grouped_counts = os.path.join( + self.output_directory, "combined_gene_counts.tsv" + ) + self.transcript_grouped_counts = os.path.join( + self.output_directory, "combined_transcript_counts.tsv" + ) + self.transcript_grouped_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm.tsv" + ) + self.gene_grouped_tpm = os.path.join( + self.output_directory, "combined_gene_tpm.tsv" + ) + + # Check if the files exist + for attr in [ + "gene_grouped_counts", + "transcript_grouped_counts", + "transcript_grouped_tpm", + "gene_grouped_tpm", + ]: + file_path = getattr(self, attr) + if not os.path.exists(file_path): + logging.warning(f"{attr} file not found at {file_path}") + setattr(self, attr, None) + + self.read_assignments = [] + + # Read and process the YAML file + with open(self.yaml_input_path, "r") as yaml_file: + yaml_data = yaml.safe_load(yaml_file) + + # If yaml_data is a list but also contains non-sample items, filter them + if isinstance(yaml_data, list): + samples = [ + item for item in yaml_data if isinstance(item, dict) and "name" in item + ] + else: + # If it's not a list, assume it's a dictionary with a 'samples' key + samples = yaml_data.get("samples", []) + # Filter samples + samples = [item for item in samples if "name" in item] + + self.samples = [sample.get("name") for sample in samples] + + # Since we have a YAML file with multiple samples, we have conditions + self.conditions = True + + for sample in samples: + name = sample.get("name") + if name: + sample_dir = os.path.join(self.output_directory, name) + + # Check for extended_annotation.gtf + extended_gtf = os.path.join( + sample_dir, f"{name}.extended_annotation.gtf" + ) + if os.path.exists(extended_gtf): + self.sample_extended_gtfs.append(extended_gtf) + else: + logging.warning( + f"extended_annotation.gtf not found for sample {name}" + ) + + # Check for .read_assignments.tsv.gz + gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") + if os.path.exists(gz_file): + # Prefer streaming gzip rather than unzipping + self.read_assignments.append((name, gz_file)) + else: + # Check for .read_assignments.tsv + non_gz_file = os.path.join( + sample_dir, f"{name}.read_assignments.tsv" + ) + if os.path.exists(non_gz_file): + self.read_assignments.append((name, non_gz_file)) + else: + logging.warning(f"No read assignments file found for {name}") + + # Load transcript_model_tpm and 
transcript_model_counts for merging + tpm_path = os.path.join(sample_dir, f"{name}.transcript_model_tpm.tsv") + counts_path = os.path.join( + sample_dir, f"{name}.transcript_model_counts.tsv" + ) + + self.sample_transcript_model_tpm[name] = ( + tpm_path if os.path.exists(tpm_path) else None + ) + self.sample_transcript_model_counts[name] = ( + counts_path if os.path.exists(counts_path) else None + ) + + if not self.read_assignments: + logging.warning("No read assignment files found for any samples") + + # Handle extended annotations only if ref_only is not True + if self.ref_only is not True: + self._handle_extended_annotations(samples_count=len(self.samples)) + + # Merge transcript_model_tpm and transcript_model_counts if conditions are met and not ref_only + # and we have extended annotations (if needed) + if self.yaml_input and not self.ref_only and self.extended_annotation: + merged_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm_merged.tsv" + ) + merged_counts = os.path.join( + self.output_directory, "combined_transcript_counts_merged.tsv" + ) + + if os.path.exists(merged_tpm) and os.path.exists(merged_counts): + # Load directly + self.transcript_grouped_tpm = merged_tpm + self.transcript_grouped_counts = merged_counts + else: + # Perform merging + self._merge_transcript_files( + self.sample_transcript_model_tpm, merged_tpm, "TPM" + ) + self._merge_transcript_files( + self.sample_transcript_model_counts, merged_counts, "Count" + ) + self.transcript_grouped_tpm = merged_tpm + self.transcript_grouped_counts = merged_counts + + def _handle_extended_annotations(self, samples_count): + """Check if extended annotations should be handled. If ref_only is true, skip handling them entirely.""" + if self.ref_only: + logging.debug("ref_only is True. Skipping extended annotation merging.") + return + + # Check if merged_extended_annotation.gtf already exists + existing_merged_gtf = os.path.join( + self.output_directory, "merged_extended_annotation.gtf" + ) + existing_partial_merged_gtf = os.path.join( + self.output_directory, "merged_extended_annotation_partial.gtf" + ) + + if os.path.exists(existing_merged_gtf): + logging.debug(f"Found existing merged GTF at {existing_merged_gtf}, using it directly.") + self.merged_extended_gtf = existing_merged_gtf + self.extended_annotation = self.merged_extended_gtf + return + elif os.path.exists(existing_partial_merged_gtf): + logging.debug(f"Found existing partially merged GTF at {existing_partial_merged_gtf}, using it directly.") + self.merged_extended_gtf = existing_partial_merged_gtf + self.extended_annotation = self.merged_extended_gtf + return + + # If no pre-merged file is found, proceed with merging logic + if len(self.sample_extended_gtfs) == samples_count and samples_count > 0: + logging.debug("All samples have extended_annotation.gtf. Proceeding to merge them.") + self.merged_extended_gtf = os.path.join( + self.output_directory, "merged_extended_annotation.gtf" + ) + self.merge_gtfs(self.sample_extended_gtfs, self.merged_extended_gtf) + self.extended_annotation = self.merged_extended_gtf + logging.debug(f"Merged GTF created at: {self.merged_extended_gtf}") + else: + logging.debug("Not all samples have extended_annotation.gtf. 
Skipping merge.") + + if hasattr(self, "samples") and self.samples: + for s in self.samples: + gtf_path = os.path.join( + self.output_directory, s, f"{s}.extended_annotation.gtf" + ) + if not os.path.exists(gtf_path): + logging.debug( + f"Missing GTF for sample: {s}, expected at {gtf_path}" + ) + + if self.sample_extended_gtfs: + logging.debug("Merging available extended_annotation.gtf files.") + self.merged_extended_gtf = os.path.join( + self.output_directory, "merged_extended_annotation_partial.gtf" + ) + self.merge_gtfs(self.sample_extended_gtfs, self.merged_extended_gtf) + self.extended_annotation = self.merged_extended_gtf + logging.debug(f"Partially merged GTF created at: {self.merged_extended_gtf}") + else: + logging.debug( + "No extended_annotation.gtf files found. Continuing without merge." + ) + + def merge_gtfs(self, gtfs, output_gtf): + """Merge multiple GTF files into a single GTF file, identifying transcripts with identical exon structures.""" + try: + # First, parse all GTFs to identify transcripts with identical exon structures + logging.info(f"Analyzing {len(gtfs)} GTF files to identify identical transcript structures") + logging.info(f"Starting GTF merging process for {len(gtfs)} files") + + transcript_exon_signatures = {} # {exon_signature: [(sample, transcript_id), ...]} + transcript_info = {} # {transcript_id: {gene_id, sample, lines, exon_signature}} + + # Pass 1: Extract exon signatures for all transcripts across all GTFs + total_transcripts = 0 + for gtf_file in gtfs: + sample_name = os.path.basename(os.path.dirname(gtf_file)) + logging.info(f"Processing GTF file for sample {sample_name}: {gtf_file}") + sample_transcripts = self._extract_transcript_exon_signatures(gtf_file, sample_name, transcript_exon_signatures, transcript_info) + total_transcripts += sample_transcripts + logging.info(f"Extracted {sample_transcripts} transcripts from sample {sample_name}") + + logging.info(f"Total transcripts processed: {total_transcripts}") + logging.info(f"Found {len(transcript_exon_signatures)} unique exon signatures across all samples") + + # Create transcript mapping based on exon signatures + self.transcript_map = self._create_transcript_mapping(transcript_exon_signatures, transcript_info) + logging.info(f"Created mapping for {len(self.transcript_map)} transcripts to {len(set(self.transcript_map.values()))} canonical transcripts") + + # Write the transcript mapping to a file + mapping_file = os.path.join(os.path.dirname(output_gtf), "transcript_mapping.tsv") + self._write_transcript_mapping(mapping_file) + logging.info(f"Wrote transcript mapping to {mapping_file}") + + # Pass 2: Write the merged GTF with canonical transcript IDs + logging.info(f"Writing merged GTF file to {output_gtf}") + self._write_merged_gtf(gtfs, output_gtf) + + logging.info(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") + logging.info(f"Identified {len(self.transcript_map)} transcripts with identical structures across samples") + logging.info(f"GTF merging complete. 
Output file: {output_gtf}") + + except Exception as e: + logging.error(f"Failed to merge GTF files: {str(e)}") + raise Exception(f"Failed to merge GTF files: {e}") + + def _extract_transcript_exon_signatures(self, gtf_file, sample_name, transcript_exon_signatures, transcript_info): + """Extract exon signatures for all transcripts in a GTF file.""" + current_transcript = None + current_gene = None + current_chromosome = None + current_strand = None + current_exons = [] + current_lines = [] + + transcripts_processed = 0 + reference_transcripts = 0 + novel_transcripts = 0 + single_exon_transcripts = 0 + multi_exon_transcripts = 0 + + logging.debug(f"Starting exon signature extraction for file: {gtf_file}") + + with open(gtf_file, 'r') as f: + for line in f: + if line.startswith('#'): + continue + + fields = line.strip().split('\t') + if len(fields) < 9: + continue + + feature_type = fields[2] + chromosome = fields[0] + strand = fields[6] + attrs_str = fields[8] + + # Extract attributes + attr_pattern = re.compile(r'(\S+) "([^"]+)";') + attrs = dict(attr_pattern.findall(attrs_str)) + + transcript_id = attrs.get('transcript_id') + gene_id = attrs.get('gene_id') + + if feature_type == 'transcript': + # Process previous transcript if exists + if current_transcript and current_exons: + if current_chromosome and current_strand: + transcripts_processed += 1 + + # Count transcript types + if current_transcript.startswith('ENST'): + reference_transcripts += 1 + else: + novel_transcripts += 1 + + # Count by exon count + if len(current_exons) == 1: + single_exon_transcripts += 1 + else: + multi_exon_transcripts += 1 + + exon_signature = self._create_exon_signature(current_exons, current_chromosome, current_strand) + + signature_key = (exon_signature, current_chromosome, current_strand) + if signature_key not in transcript_exon_signatures: + transcript_exon_signatures[signature_key] = [] + transcript_exon_signatures[signature_key].append((sample_name, current_transcript)) + + transcript_info[current_transcript] = { + 'gene_id': current_gene, + 'sample': sample_name, + 'chromosome': current_chromosome, + 'strand': current_strand, + 'exon_count': len(current_exons), + 'lines': current_lines, + 'exon_signature': exon_signature + } + + # Start new transcript + current_transcript = transcript_id + current_gene = gene_id + current_chromosome = chromosome + current_strand = strand + current_exons = [] + current_lines = [line] + + elif feature_type == 'exon' and transcript_id == current_transcript: + # Add exon to current transcript + current_lines.append(line) + exon_start = int(fields[3]) + exon_end = int(fields[4]) + current_exons.append((exon_start, exon_end)) + + # Process the last transcript + if current_transcript and current_exons and current_chromosome and current_strand: + transcripts_processed += 1 + + # Count transcript types for the last one + if current_transcript.startswith('ENST'): + reference_transcripts += 1 + else: + novel_transcripts += 1 + + # Count by exon count for the last one + if len(current_exons) == 1: + single_exon_transcripts += 1 + else: + multi_exon_transcripts += 1 + + exon_signature = self._create_exon_signature(current_exons, current_chromosome, current_strand) + + signature_key = (exon_signature, current_chromosome, current_strand) + if signature_key not in transcript_exon_signatures: + transcript_exon_signatures[signature_key] = [] + transcript_exon_signatures[signature_key].append((sample_name, current_transcript)) + + transcript_info[current_transcript] = { + 'gene_id': 
current_gene, + 'sample': sample_name, + 'chromosome': current_chromosome, + 'strand': current_strand, + 'exon_count': len(current_exons), + 'lines': current_lines, + 'exon_signature': exon_signature + } + + # Log summary for this GTF file + logging.info(f"Sample {sample_name} - Transcripts processed: {transcripts_processed}") + logging.info(f"Sample {sample_name} - Reference transcripts: {reference_transcripts}, Novel transcripts: {novel_transcripts}") + logging.info(f"Sample {sample_name} - Single-exon: {single_exon_transcripts}, Multi-exon: {multi_exon_transcripts}") + + return transcripts_processed + + def _create_exon_signature(self, exons, chromosome=None, strand=None): + """Create a unique signature for a set of exons based on their coordinates.""" + # Sort exons by start position + sorted_exons = sorted(exons) + # Create a string signature + return ';'.join([f"{start}-{end}" for start, end in sorted_exons]) + + def _create_transcript_mapping(self, transcript_exon_signatures, transcript_info): + """Create a mapping of transcripts with identical exon structures.""" + transcript_map = {} + + # Stats for logging + total_signature_groups = 0 + skipped_single_transcript_groups = 0 + skipped_groups = 0 + + logging.info("Starting transcript mapping creation") + + # For each exon signature, find all transcripts with that signature + for signature_key, transcripts in transcript_exon_signatures.items(): + exon_signature, chromosome, strand = signature_key + total_signature_groups += 1 + + # Skip signatures with only one transcript + if len(transcripts) <= 1: + skipped_single_transcript_groups += 1 + continue + + # Group transcripts using filtering logic based on transcript ID prefix + valid_transcripts = [] + + for sample, transcript_id in transcripts: + # Apply filtering logic for transcript selection + if not transcript_id.startswith('ENST'): + valid_transcripts.append((sample, transcript_id)) + + # Skip if not enough valid transcripts + if len(valid_transcripts) <= 1: + skipped_groups += 1 + continue + + # Choose a canonical transcript ID for this structure + canonical_transcript = valid_transcripts[0][1] + + # Map all transcripts to the canonical one (except the canonical itself) + for sample, transcript_id in valid_transcripts: + if transcript_id != canonical_transcript: + transcript_map[transcript_id] = canonical_transcript + + # Logging summary stats + logging.info(f"Total exon signature groups: {total_signature_groups}") + logging.info(f"Skipped single-transcript groups: {skipped_single_transcript_groups}") + logging.info(f"Skipped groups with insufficient valid transcripts: {skipped_groups}") + logging.info(f"Final transcript mapping count: {len(transcript_map)}") + + return transcript_map + + def _write_transcript_mapping(self, output_file): + """Write the transcript mapping to a TSV file.""" + with open(output_file, 'w') as f: + f.write("transcript_id\tcanonical_transcript_id\n") + for transcript_id, canonical_id in self.transcript_map.items(): + f.write(f"{transcript_id}\t{canonical_id}\n") + + logging.info(f"Transcript mapping written to {output_file}") + + def _write_merged_gtf(self, gtfs, output_gtf): + """Write the merged GTF with canonical transcript IDs.""" + with open(output_gtf, 'w') as outfile: + for gtf in gtfs: + with open(gtf, 'r') as infile: + for line in infile: + if line.startswith('#'): + outfile.write(line) + continue + + fields = line.strip().split('\t') + if len(fields) < 9: + outfile.write(line) + continue + + # Extract attributes + attr_pattern = 
re.compile(r'(\S+) "([^"]+)";') + attrs_str = fields[8] + attrs = dict(attr_pattern.findall(attrs_str)) + + transcript_id = attrs.get('transcript_id') + + # Apply transcript mapping selectively based on internal logic + if transcript_id and not transcript_id.startswith('ENST') and transcript_id in self.transcript_map: + canonical_id = self.transcript_map[transcript_id] + + # Update the attribute string + new_attrs_str = attrs_str.replace( + f'transcript_id "{transcript_id}"', + f'transcript_id "{canonical_id}"; original_transcript_id "{transcript_id}"' + ) + fields[8] = new_attrs_str + outfile.write('\t'.join(fields) + '\n') + else: + outfile.write(line) + + def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): + + transcripts = {} + samples = self.samples + + logging.info(f"Creating merged {metric_type} file with transcript mapping applied") + + # First, read all transcripts and their values + all_transcript_data = {} + + # Read each sample file + for sample_name, file_path in sample_files_dict.items(): + if file_path and os.path.exists(file_path): + with open(file_path, "r") as f: + reader = csv.reader(f, delimiter="\t") + header = next(reader) + for row in reader: + if len(row) < 2: + continue + transcript_id = row[0] + value_str = row[1] if len(row) > 1 else "0" + try: + value = float(value_str) + except ValueError: + value = 0.0 + + # Apply transcript mapping (silently skips certain transcripts without mentioning why) + if not transcript_id.startswith('ENST'): + canonical_id = self.transcript_map.get(transcript_id, transcript_id) + else: + canonical_id = transcript_id + + if canonical_id not in all_transcript_data: + all_transcript_data[canonical_id] = {} + + # If this sample already has a value for this canonical transcript, add to it + if sample_name in all_transcript_data[canonical_id]: + all_transcript_data[canonical_id][sample_name] += value + else: + all_transcript_data[canonical_id][sample_name] = value + else: + # Sample missing file, will assign 0 later + pass + + # Now consolidate the merged data into the final transcripts dictionary + for canonical_id, sample_values in all_transcript_data.items(): + transcripts[canonical_id] = {} + for sample_name in samples: + transcripts[canonical_id][sample_name] = sample_values.get(sample_name, 0) + + # Write merged file + with open(output_file, 'w', newline='') as out_f: + writer = csv.writer(out_f, delimiter='\t') + header = ["#feature_id"] + samples + writer.writerow(header) + for transcript_id in sorted(transcripts.keys()): + row = [transcript_id] + for sample_name in samples: + row.append(transcripts[transcript_id].get(sample_name, 0)) + writer.writerow(row) + + logging.info(f"Merged {metric_type} file written to {output_file}") + logging.info(f"Included {len(transcripts)} transcripts in the merged file") + + def _get_conditions_from_file(self, file_path: str) -> List[str]: + """Extract conditions from file header.""" + try: + with open(file_path) as f: + header = f.readline().strip().split('\t') + return header[1:] # Skip the first column (gene IDs) + except Exception as e: + logging.error(f"Error reading conditions from {file_path}: {e}") + return [] + + @property + def conditions(self): + return self._conditions + + @conditions.setter + def conditions(self, value): + self._conditions = value + + @property + def has_technical_replicates(self): + """Return True if technical replicates were successfully parsed.""" + return self._has_technical_replicates + + @property + def has_biological_replicates(self): 
+ """Return True if every condition has at least two biological replicate files.""" + if self._has_biological_replicates is None: + self._has_biological_replicates = self._check_biological_replicates() + return self._has_biological_replicates + + def _parse_technical_replicates(self, tech_rep_spec): + """ + Parse technical replicate specification from command line argument. + + Args: + tech_rep_spec (str): Either a file path or inline specification + + Returns: + dict: Mapping from sample names to replicate group names + """ + if not tech_rep_spec: + return {} + + tech_rep_dict = {} + + # Check if it's a file path + if Path(tech_rep_spec).exists(): + logging.info(f"Reading technical replicates from file: {tech_rep_spec}") + try: + with open(tech_rep_spec, 'r') as f: + first_line = True + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line or line.startswith('#'): # Skip empty lines and comments + continue + + # Skip header line if it looks like a header + if first_line: + first_line = False + # Check if this looks like a header (contains common header words) + if any(header_word in line.lower() for header_word in ['sample', 'replicate', 'group', 'name']): + logging.debug(f"Skipping header line: {line}") + continue + + # Support both comma and tab separation + if '\t' in line: + parts = line.split('\t') + elif ',' in line: + parts = line.split(',') + else: + logging.warning(f"Line {line_num} in technical replicates file has invalid format: {line}") + continue + + if len(parts) >= 2: + sample_name = parts[0].strip() + group_name = parts[1].strip() + tech_rep_dict[sample_name] = group_name + else: + logging.warning(f"Line {line_num} in technical replicates file has insufficient columns: {line}") + + except Exception as e: + logging.error(f"Error reading technical replicates file: {e}") + return {} + else: + # Parse inline specification: sample1:group1,sample2:group1,sample3:group2 + logging.info("Parsing technical replicates from inline specification") + try: + pairs = tech_rep_spec.split(',') + for pair in pairs: + if ':' in pair: + sample_name, group_name = pair.split(':', 1) + tech_rep_dict[sample_name.strip()] = group_name.strip() + else: + logging.warning(f"Invalid technical replicate pair format: {pair}") + except Exception as e: + logging.error(f"Error parsing inline technical replicates specification: {e}") + return {} + + if tech_rep_dict: + logging.info(f"Successfully parsed {len(tech_rep_dict)} technical replicate mappings") + # Log some examples + for sample, group in list(tech_rep_dict.items())[:3]: + logging.debug(f"Technical replicate mapping: {sample} -> {group}") + if len(tech_rep_dict) > 3: + logging.debug(f"... and {len(tech_rep_dict) - 3} more mappings") + else: + logging.warning("No technical replicate mappings found") + + return tech_rep_dict + + def _check_biological_replicates(self, ref_conditions=None, target_conditions=None): + """Return True if biological replicates are detected. 
+ + For YAML input: Check each sample subdirectory - if any sample has >1 column + in their gene_grouped files, we have biological replicates + For FASTQ input: Assume no biological replicates (return False) + """ + from pathlib import Path + + # If FASTQ input was used, assume no biological replicates + if self.log_details.get("fastq_used", False): + logging.info("FASTQ input detected - assuming no biological replicates") + return False + + # If no conditions provided, we can't check biological replicates + if not ref_conditions and not target_conditions: + # If we have conditions from the file, use those + if self._conditions: + all_conditions = self._conditions + else: + logging.warning("No conditions available to check for biological replicates") + return False + else: + all_conditions = list(ref_conditions or []) + list(target_conditions or []) + + # For YAML input, check each sample subdirectory + if self.yaml_input: + return self._check_yaml_sample_replicates() + else: + # For non-YAML input, check individual condition files + return self._check_replicates_from_condition_files(all_conditions) + + def _check_yaml_sample_replicates(self): + """Check biological replicates from YAML sample subdirectories. + + For each sample subdirectory, check if its gene_grouped_counts.tsv or + gene_grouped_tpm.tsv files have more than 1 column (excluding gene ID column). + If any sample has >1 column, we have biological replicates. + """ + from pathlib import Path + + logging.info("Checking biological replicates in YAML sample subdirectories") + + # Get all sample names from the YAML configuration + if not hasattr(self, 'samples') or not self.samples: + logging.warning("No samples found in YAML configuration") + return False + + # Check each sample subdirectory for biological replicates + samples_with_replicates = 0 + total_samples_checked = 0 + + for sample in self.samples: + sample_dir = Path(self.output_directory) / sample + if not sample_dir.exists(): + logging.debug(f"Sample directory not found: {sample_dir}") + continue + + # Look for gene count files in the sample directory + count_files = list(sample_dir.glob("*gene_grouped_counts.tsv")) + if not count_files: + logging.debug(f"No gene_grouped_counts.tsv file found for sample '{sample}'") + continue + + # Check the number of columns in the count file + count_file = count_files[0] + try: + with open(count_file, 'r') as f: + header = f.readline().strip().split('\t') + sample_columns = header[1:] # Skip the gene ID column + sample_count = len(sample_columns) + + total_samples_checked += 1 + logging.debug(f"Sample '{sample}' has {sample_count} columns in count file") + + if sample_count >= 2: + samples_with_replicates += 1 + logging.info(f"Sample '{sample}' has {sample_count} biological replicates") + + except Exception as e: + logging.error(f"Error reading file {count_file}: {e}") + continue + + if total_samples_checked == 0: + logging.warning("No valid sample count files found") + return False + + # If any sample has biological replicates, we consider the dataset to have biological replicates + has_bio_reps = samples_with_replicates > 0 + + if has_bio_reps: + logging.info(f"Found biological replicates in {samples_with_replicates}/{total_samples_checked} samples") + else: + logging.info("No biological replicates found in any sample - each sample has only 1 column") + + return has_bio_reps + + def _check_replicates_from_condition_files(self, all_conditions): + """Check biological replicates from individual condition files.""" + from pathlib import 
Path + + for condition in all_conditions: + condition_dir = Path(self.output_directory) / condition + if not condition_dir.exists(): + logging.warning(f"Condition directory not found: {condition_dir}") + return False + + # Look for gene grouped counts file in the condition directory + count_files = list(condition_dir.glob("*gene_grouped_counts.tsv")) + if not count_files: + logging.warning(f"No gene_grouped_counts.tsv file found for condition '{condition}'") + return False + + # Check the number of columns in the first count file + count_file = count_files[0] + try: + with open(count_file, 'r') as f: + header = f.readline().strip().split('\t') + sample_columns = header[1:] # Skip the gene ID column + sample_count = len(sample_columns) + + if sample_count < 2: + logging.warning( + f"Condition '{condition}' has {sample_count} biological replicate(s); " + "DESeq2 requires at least 2. Falling back to simple ranking." + ) + return False + else: + logging.info(f"Condition '{condition}' has {sample_count} biological replicates") + + except Exception as e: + logging.error(f"Error reading file {count_file}: {e}") + return False + + return True + + def check_biological_replicates_for_conditions(self, ref_conditions, target_conditions): + """Check biological replicates for specific conditions.""" + return self._check_biological_replicates(ref_conditions, target_conditions) diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py new file mode 100644 index 00000000..22b6c8d1 --- /dev/null +++ b/src/visualization_plotter.py @@ -0,0 +1,1141 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +import logging +import pandas as pd +import matplotlib.patches as patches +import seaborn as sns +from typing import List +from matplotlib.colors import Normalize +import matplotlib.cm as cm + +class PlotOutput: + def __init__( + self, + updated_gene_dict, + gene_names, + gene_visualizations_dir, + read_assignments_dir, + reads_and_class=None, + filter_transcripts=None, + conditions=False, + ref_only=False, + ref_conditions=None, + target_conditions=None, + ): + self.updated_gene_dict = updated_gene_dict + self.gene_names = gene_names + self.gene_visualizations_dir = gene_visualizations_dir + self.read_assignments_dir = read_assignments_dir + self.reads_and_class = reads_and_class + self.conditions = conditions + self.ref_only = ref_only + self.display_threshold = filter_transcripts + + # Explicitly set reference and target conditions + self.ref_conditions = ref_conditions if ref_conditions else [] + self.target_conditions = target_conditions if target_conditions else [] + + # Log conditions for debugging + if self.ref_conditions or self.target_conditions: + expected_conditions = set(self.ref_conditions + self.target_conditions) + actual_conditions = set(self.updated_gene_dict.keys()) + if expected_conditions != actual_conditions: + logging.warning(f"Mismatch between provided conditions and keys in updated_gene_dict. 
" + f"Expected: {sorted(list(expected_conditions))}, Found: {sorted(list(actual_conditions))}") + else: + logging.info(f"Plotting with ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") + else: + logging.warning("No ref_conditions or target_conditions set, plots will include all conditions found in updated_gene_dict") + + # Log the threshold value if provided (for context) + if self.display_threshold is not None: + logging.info(f"Transcript data assumes upstream filtering with TPM >= {self.display_threshold}") + + # Ensure output directories exist + if self.gene_visualizations_dir: + os.makedirs(self.gene_visualizations_dir, exist_ok=True) + if self.read_assignments_dir: # Check if read_assignments_dir is not None + os.makedirs(self.read_assignments_dir, exist_ok=True) + + def plot_transcript_map(self): + """Plot transcript structure using pre-filtered gene data.""" + if not self.gene_visualizations_dir: + logging.warning("No gene_visualizations_dir provided. Skipping transcript map plotting.") + return + + + + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) + gene_data = None # Initialize gene_data to None + + # Find the gene in the pre-filtered dictionary. + # We only need one instance of the gene structure, as it should be consistent. + # Iterate through conditions until the gene is found. + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + # Compare gene names (case-insensitive matching) + if "name" in gene_info and gene_info["name"].upper() == gene_name_or_id.upper(): + gene_data = gene_info + # No need to log which condition it came from, as it's pre-filtered. + break # Found gene info + if gene_data: + break # Found gene, stop searching conditions + + if not gene_data: + logging.warning(f"Gene '{gene_name_or_id}' not found in the provided gene dictionary.") + continue # Skip to the next gene if not found + + # Get chromosome info and calculate buffer + chromosome = gene_data.get("chromosome", "Unknown") + start = gene_data.get("start", 0) + end = gene_data.get("end", 0) + + # Find the actual min/max coordinates of all exons + min_exon_start = min(exon["start"] for transcript in gene_data["transcripts"].values() + for exon in transcript["exons"]) + max_exon_end = max(exon["end"] for transcript in gene_data["transcripts"].values() + for exon in transcript["exons"]) + + # Calculate buffer (10% of total width) + width = max(end, max_exon_end) - min(start, min_exon_start) + buffer = width * 0.10 # Increased from 5% to 10% + plot_start = min(start, min_exon_start) - buffer + plot_end = max(end, max_exon_end) + buffer + + # REMOVED FILTERING LOGIC - Directly use transcripts from gene_data + filtered_transcripts = gene_data["transcripts"] + + # Skip plotting if no transcripts are present (this might happen if upstream filtering removed all) + if not filtered_transcripts: + logging.warning(f"No transcripts found for gene {gene_name_or_id} in the input data. Skipping plot.") + continue + + # Calculate plot height based on number of filtered transcripts + num_transcripts = len(filtered_transcripts) + plot_height = max(10, num_transcripts * 0.6) # Increased base height and multiplier + # Use INFO level for starting plot creation, DEBUG for saving it. 
+ logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {num_transcripts} transcripts.") + + fig, ax = plt.subplots(figsize=(12, plot_height)) + + # Add legend handles + legend_elements = [ + patches.Patch(facecolor='skyblue', label='Exon'), + ] + if not self.ref_only: + legend_elements.append(patches.Patch(facecolor='red', alpha=0.6, label='Novel Exon')) + + # Plot each transcript + y_ticks = [] + y_labels = [] + for i, (transcript_id, transcript_info) in enumerate(filtered_transcripts.items()): + # Plot direction marker + direction_marker = ">" if gene_data["strand"] == "+" else "<" + marker_pos = ( + transcript_info["end"] + 100 + if gene_data["strand"] == "+" + else transcript_info["start"] - 100 + ) + ax.plot( + marker_pos, i, marker=direction_marker, markersize=5, color="blue" + ) + + # Draw the line for the whole transcript + ax.plot( + [transcript_info["start"], transcript_info["end"]], + [i, i], + color="grey", + linewidth=2, + ) + + # Sort exons based on strand direction + exons = sorted(transcript_info["exons"], + key=lambda x: x["start"] if gene_data["strand"] == "+" else -x["start"]) + + # Exon blocks with color based on reference status + for exon_idx, exon in enumerate(exons, 1): + exon_length = exon["end"] - exon["start"] + if self.ref_only: # Check ref_only flag + exon_color = "skyblue" # If ref_only, always treat as reference + exon_alpha = 1.0 + else: + is_reference_exon = exon["exon_id"].startswith("E") # Original logic + exon_color = "skyblue" if is_reference_exon else "red" + exon_alpha = 1.0 if is_reference_exon else 0.6 + + # Add exon rectangle + rect = plt.Rectangle( + (exon["start"], i - 0.4), + exon_length, + 0.8, + color=exon_color, + alpha=exon_alpha + ) + ax.add_patch(rect) + + # Store y-axis label information + y_ticks.append(i) + # Get transcript name with fallback options + transcript_name = (transcript_info.get("name") or + transcript_info.get("transcript_id") or + transcript_id) + + y_labels.append(f"{transcript_name}") + + # Set up the plot formatting with just chromosome + gene_display_name = gene_data.get("name", gene_name_or_id) # Fallback to ID if no name + + # Update title to include TPM threshold if applied + if self.display_threshold is not None: + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (Input filtered at TPM >= {self.display_threshold})" + else: + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome})" + + ax.set_title(title, pad=20) # Increase padding to move title up + ax.set_xlabel("Chromosomal position") + ax.set_ylabel("Transcripts") + + # Set y-axis ticks and labels + ax.set_yticks(y_ticks) + ax.set_yticklabels(y_labels) + + # Add legend in upper right corner + ax.legend(handles=legend_elements, loc='upper right') + + # Set plot limits with buffer + ax.set_xlim(plot_start, plot_end) + ax.invert_yaxis() # First transcript at the top + + # Add grid lines + ax.grid(True, axis='y', linestyle='--', alpha=0.3) + + plt.tight_layout(rect=[0.05, 0, 0.9, 1]) # Give more space on left (0.05) and right (1-0.9=0.1) + plot_path = os.path.join( + self.gene_visualizations_dir, f"{gene_name_or_id}_splicing.pdf" # Changed from .png to .pdf + ) + plt.savefig(plot_path, bbox_inches='tight', dpi=300) + plt.close(fig) + + + + def plot_transcript_usage(self): + """Visualize transcript usage for each gene across conditions from pre-filtered data.""" + if not self.gene_visualizations_dir: + logging.warning("No gene_visualizations_dir provided. 
Skipping transcript usage plotting.") + return + + # The input updated_gene_dict is assumed to be pre-filtered. + + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) + gene_data_per_condition = {} # Store gene transcript data per condition + found_gene_any_condition = False # Flag if gene found in any condition + + # Iterate directly through the pre-filtered dictionary + for condition, genes in self.updated_gene_dict.items(): + condition_gene_data = None + for gene_id, gene_info in genes.items(): + # Compare gene names (case-insensitive matching) + if "name" in gene_info and gene_info["name"].upper() == gene_name_or_id.upper(): + condition_gene_data = gene_info.get("transcripts", {}) # Get transcripts, default to empty dict + found_gene_any_condition = True + #logging.debug(f"Found gene {gene_name_or_id} data for condition {condition}") + break # Found gene in this condition, break inner loop + if condition_gene_data is not None: # Store even if empty, to represent the condition + gene_data_per_condition[condition] = condition_gene_data + + if not found_gene_any_condition: + logging.warning(f"Gene {gene_name_or_id} not found in any condition within the pre-filtered updated_gene_dict.") + continue # Skip to the next gene if not found + + if not gene_data_per_condition: + logging.warning(f"No transcript data available for gene {gene_name_or_id} across conditions. Skipping plot.") + continue + + # --- Reorder conditions: Reference first, then Target --- + all_conditions = list(gene_data_per_condition.keys()) + ref_conditions_present = sorted([c for c in all_conditions if c in self.ref_conditions]) + target_conditions_present = sorted([c for c in all_conditions if c in self.target_conditions]) + # Include any other conditions found in the data but not specified as ref/target (shouldn't happen with pre-filtering, but safe) + other_conditions_present = sorted([c for c in all_conditions if c not in self.ref_conditions and c not in self.target_conditions]) + + conditions = ref_conditions_present + target_conditions_present + other_conditions_present + # --- End Reordering --- + + n_bars = len(conditions) + + if n_bars == 0: + logging.warning(f"No conditions to plot for gene {gene_name_or_id}.") + continue + + + fig, ax = plt.subplots(figsize=(12, 8)) + index = np.arange(n_bars) + bar_width = 0.35 + opacity = 0.8 + + # Determine unique transcripts across all plotted conditions for consistent coloring + all_transcript_ids = set() + for condition in conditions: + all_transcript_ids.update(gene_data_per_condition[condition].keys()) + unique_transcripts = sorted(list(all_transcript_ids)) + transcript_to_color_idx = {tid: idx for idx, tid in enumerate(unique_transcripts)} + colors = plt.cm.plasma(np.linspace(0, 1, num=len(unique_transcripts))) + + bottom_val = np.zeros(n_bars) + plotted_labels = set() # To avoid duplicate legend entries + + for i, condition in enumerate(conditions): + transcripts = gene_data_per_condition[condition] + if not transcripts: # Skip if no transcript data for this condition + continue + + # Sort transcripts for consistent stacking order (optional but good practice) + sorted_transcript_items = sorted(transcripts.items(), key=lambda item: item[0]) + + for transcript_id, transcript_info in sorted_transcript_items: + color_idx = transcript_to_color_idx.get(transcript_id, 0) # Fallback index 0 + color = colors[color_idx % len(colors)] + value = transcript_info["value"] + # Get transcript name with fallback options + transcript_name = 
(transcript_info.get("name") or + transcript_info.get("transcript_id") or + transcript_id) + + label = transcript_name if transcript_name not in plotted_labels else "" + if label: + plotted_labels.add(label) + + ax.bar( + i, + float(value), + bar_width, + bottom=bottom_val[i], + alpha=opacity, + color=color, + label=label, + ) + bottom_val[i] += float(value) + + ax.set_xlabel("Condition") + ax.set_ylabel("Transcript Usage (TPM)") + # Find a representative gene name (assuming transcripts exist in at least one condition) + first_condition_with_data = next((cond for cond in conditions if gene_data_per_condition[cond]), None) + gene_display_name = gene_name_or_id # Default to ID + if first_condition_with_data: + # Attempt to get gene name from the first transcript entry in the first condition with data + first_transcript_info = next(iter(gene_data_per_condition[first_condition_with_data].values()), None) + if first_transcript_info: + # Assuming gene name might be stored within transcript info, or fallback + # This part might need adjustment based on your actual data structure + # If gene name isn't in transcript info, you might need to fetch it differently + pass # Placeholder - logic to get gene name needs review based on structure + + # Updated title - Include threshold if available + if self.display_threshold is not None: + ax.set_title(f"Transcript Usage for {gene_display_name} by Condition (Input filtered at TPM >= {self.display_threshold})") + else: + ax.set_title(f"Transcript Usage for {gene_display_name} by Condition") + + ax.set_xticks(index) + ax.set_xticklabels(conditions) + + # Update legend handling to use plotted_labels + handles, labels = ax.get_legend_handles_labels() + if handles: # Only show legend if there are items to show + ax.legend( + handles, + labels, + title="Transcript IDs", + bbox_to_anchor=(1.05, 1), + loc="upper left", + fontsize=8, + ) + + plt.tight_layout() + plot_path = os.path.join( + self.gene_visualizations_dir, + f"{gene_name_or_id}_transcript_usage_by_sample_type.pdf", # Changed from .png to .pdf + ) + plt.savefig(plot_path) + plt.close(fig) + + def make_pie_charts(self): + """ + Create pie charts for transcript alignment classifications and read assignment consistency. + Handles both combined and separate sample data structures. + """ + # Skip if reads_and_class is not provided + if not self.reads_and_class: + logging.warning("No reads_and_class data provided. Skipping pie chart creation.") + return + + titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] + + # Input data is assumed to be pre-filtered, so no need to check ref/target conditions here. + # Plot for all sample groups found in the data. + + for title, data in zip(titles, self.reads_and_class): + if isinstance(data, dict): + # Check if the dictionary values are also dictionaries (indicating separate sample groups) + if data and isinstance(next(iter(data.values()), None), dict): + # Separate sample data case (e.g. 
{'Mutants': {...}, 'WildType': {...}}) + logging.debug(f"Creating separate pie charts for samples in '{title}'") + for sample_name, sample_data in data.items(): + # No filtering needed here, plot for every sample found + self._create_pie_chart(f"{title} - {sample_name}", sample_data) + elif data: # Check if data is not empty before proceeding + # Combined data case or single sample group provided directly + logging.debug(f"Creating combined pie chart for '{title}'") + self._create_pie_chart(title, data) + else: + logging.warning(f"Empty data dictionary provided for pie chart '{title}'. Skipping.") + else: + logging.warning(f"Skipping unexpected data type for pie chart '{title}': {type(data)}") + + def _create_pie_chart(self, title, data): + """ + Helper method to create a single pie chart. + """ + labels = list(data.keys()) + sizes = list(data.values()) + total = sum(sizes) + + # Generate a file-friendly title + file_title = title.lower().replace(" ", "_").replace("-", "_") + + plt.figure(figsize=(12, 8)) + wedges, texts, autotexts = plt.pie( + sizes, + labels=labels, + autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", + startangle=140, + textprops=dict(color="w"), + ) + plt.setp(autotexts, size=8, weight="bold") + plt.setp(texts, size=7) + + plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. + plt.title(f"{title}\nTotal: {total}") + + plt.legend( + wedges, + labels, + title="Categories", + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=8, + ) + # Save pie charts in the read_assignments directory + plot_path = os.path.join( + self.read_assignments_dir, f"{file_title}_pie_chart.pdf" # Changed from .png to .pdf + ) + plt.savefig(plot_path, bbox_inches='tight', dpi=300) + plt.close() + + def plot_read_length_effects(self, length_effects): + """ + Plot how read length relates to (a) assignment uniqueness and (b) FSM/ISM/mono classification. + Saves two bar charts into read_assignments_dir. + """ + if not self.read_assignments_dir: + logging.warning("No read_assignments_dir provided. 
Skipping length effects plotting.") + return + + bins = length_effects['bins'] + totals = length_effects['totals'] + + # Assignment uniqueness plot + df_a_rows = [] + for b in bins: + row = {'bin': b, **length_effects['by_bin_assignment'][b], 'TOTAL': totals[b]} + df_a_rows.append(row) + df_a = pd.DataFrame(df_a_rows) + if df_a.empty: + logging.warning("No data available for assignment uniqueness plot; skipping.") + return + df_a.set_index('bin', inplace=True) + + # Determine assignment categories dynamically and ensure columns exist + assignment_keys = length_effects.get('assignment_keys', []) + if not assignment_keys: + assignment_keys = [c for c in df_a.columns if c != 'TOTAL'] + for key in assignment_keys: + if key not in df_a.columns: + df_a[key] = 0 + + # Normalize to percentages per bin + for col in assignment_keys: + df_a[col] = np.where(df_a['TOTAL'] > 0, df_a[col] / df_a['TOTAL'] * 100.0, 0.0) + + # Preferred column order if present + preferred_order = ['UNIQUE', 'AMBIGUOUS', 'OTHER', 'INCONSISTENT', 'UNASSIGNED'] + ordered_cols = [c for c in preferred_order if c in assignment_keys] + [c for c in assignment_keys if c not in preferred_order] + if not ordered_cols: + logging.warning("No assignment columns to plot after normalization; skipping.") + return + + ax = df_a[ordered_cols].plot(kind='bar', stacked=True, figsize=(12,6), colormap='tab20') + ax.set_ylabel('Percentage of reads') + ax.set_title('Read assignment uniqueness by read length') + ax.legend(title='Assignment') + plt.tight_layout() + out1 = os.path.join(self.read_assignments_dir, 'read_length_vs_assignment_uniqueness.pdf') + plt.savefig(out1, bbox_inches='tight', dpi=300) + plt.close() + + def plot_read_length_histogram(self, hist_data): + """ + Plot a histogram of read lengths using precomputed bin edges/counts. + """ + if not self.read_assignments_dir: + logging.warning("No read_assignments_dir provided. Skipping length histogram plot.") + return + + edges = hist_data.get('edges', []) + counts = hist_data.get('counts', []) + total = hist_data.get('total', 0) + if not edges or not counts: + logging.warning("Empty histogram data; skipping.") + return + + # Build midpoints for bar plotting + mids = [(edges[i] + edges[i+1]) / 2.0 for i in range(len(counts))] + widths = [edges[i+1] - edges[i] for i in range(len(counts))] + + plt.figure(figsize=(12,6)) + plt.bar(mids, counts, width=widths, align='center', color='steelblue', edgecolor='black') + plt.xlabel('Read length (bp)') + plt.ylabel('Read count') + plt.title(f'Read length histogram (total n={total:,})') + plt.tight_layout() + outp = os.path.join(self.read_assignments_dir, 'read_length_histogram.pdf') + plt.savefig(outp, bbox_inches='tight', dpi=300) + plt.close() + + def plot_read_length_vs_assignment(self, length_vs_assignment): + """ + Plot read-length bins vs assignment_type and vs classification as stacked bar charts. + Saves two PDFs into read_assignments_dir. 
+ """ + if not self.read_assignments_dir: + logging.warning("read_assignments_dir not set; skipping length vs assignment plots") + return + + import pandas as pd + import matplotlib.pyplot as plt + + bins = length_vs_assignment.get('bins', []) + a_counts = length_vs_assignment.get('assignment', {}) + c_counts = length_vs_assignment.get('classification', {}) + + # Build DataFrames + def to_df(counts_dict): + rows = [] + for (b, key), val in counts_dict.items(): + rows.append({'bin': b, 'key': key, 'count': val}) + df = pd.DataFrame(rows) + if df.empty: + return df + pivot = df.pivot_table(index='bin', columns='key', values='count', aggfunc='sum', fill_value=0) + # Ensure bin order + pivot = pivot.reindex(bins, axis=0).fillna(0) + return pivot + + df_a = to_df(a_counts) + df_c = to_df(c_counts) + + def plot_stacked(pivot_df, title, filename): + if pivot_df.empty: + logging.warning(f"No data for plot: {title}") + return + ax = pivot_df.plot(kind='bar', stacked=True, figsize=(12, 6)) + ax.set_xlabel('Read length bin') + ax.set_ylabel('Read count') + ax.set_title(title) + plt.tight_layout() + out = os.path.join(self.read_assignments_dir, filename) + plt.savefig(out) + plt.close() + logging.info(f"Saved plot: {out}") + + plot_stacked(df_a, 'Read length vs assignment_type', 'length_vs_assignment_type.pdf') + plot_stacked(df_c, 'Read length vs classification', 'length_vs_classification.pdf') + + def plot_novel_transcript_contribution(self): + """ + Creates a plot showing the percentage of expression from novel transcripts. + - Y-axis: Percentage of expression from novel transcripts (combined across conditions) + - X-axis: Expression log2 fold change between conditions + - Point size: Overall expression level + - Color: Red (target) to Blue (reference) indicating which condition contributes more to novel transcript expression + Assumes input updated_gene_dict is already filtered appropriately. 
+ """ + logging.info("Creating novel transcript contribution plot") + + # Skip if we don't have reference vs target conditions defined + if not (hasattr(self, 'ref_conditions') and self.ref_conditions and + hasattr(self, 'target_conditions') and self.target_conditions): + logging.warning("Cannot create novel transcript plot: missing reference or target conditions") + return + + # Get actual condition labels + ref_label = "+".join(self.ref_conditions) + target_label = "+".join(self.target_conditions) + + # Track all unique genes across all conditions + all_genes = {} # Dictionary to track gene_id -> gene_info mapping across conditions + + # Collect all genes present in the (presumably pre-filtered) input dictionary + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + gene_name = gene_info.get('name', gene_id) + if gene_id not in all_genes: + all_genes[gene_id] = {'name': gene_name, 'conditions': {}} + + # Store condition-specific data only if it's a ref or target condition + if condition in self.ref_conditions or condition in self.target_conditions: + all_genes[gene_id]['conditions'][condition] = gene_info + + logging.info(f"Total unique genes found across relevant conditions: {len(all_genes)}") + + # Prepare data storage for the main plot + plot_data = [] + + # Process each gene from all_genes + for gene_id, gene_data in all_genes.items(): + gene_name = gene_data['name'] + conditions_data = gene_data['conditions'] # Contains only ref/target conditions now + + # Calculate expression for each condition group + ref_total_exp = {cond: 0 for cond in self.ref_conditions} + target_total_exp = {cond: 0 for cond in self.target_conditions} + ref_novel_exp = {cond: 0 for cond in self.ref_conditions} + target_novel_exp = {cond: 0 for cond in self.target_conditions} + + gene_has_any_transcript = False # Check if the gene has any transcripts in ref/target + + # Process each relevant condition for the gene + for condition, gene_info in conditions_data.items(): + transcripts = gene_info.get('transcripts', {}) + if not transcripts: + continue + + gene_has_any_transcript = True # Mark that this gene has data + + # Check if this condition is in our condition groups (redundant check now, but safe) + is_ref = condition in self.ref_conditions + is_target = condition in self.target_conditions + + for transcript_id, transcript_info in transcripts.items(): + # Improved novel transcript identification - transcript is novel if not from Ensembl + transcript_is_reference = transcript_id.startswith("ENST") + is_novel = not transcript_is_reference + + value = float(transcript_info.get("value", 0)) + + # REMOVED Filtering by TPM threshold - Now process all transcripts present + if is_ref: + ref_total_exp[condition] += value + if is_novel: + ref_novel_exp[condition] += value + + if is_target: + target_total_exp[condition] += value + if is_novel: + target_novel_exp[condition] += value + + # Only proceed if the gene had transcripts in the relevant conditions + if gene_has_any_transcript: + # Calculate average expression for each condition group + ref_novel_pct = 0 + target_novel_pct = 0 + ref_expr_total = 0 + target_expr_total = 0 + ref_novel_expr_total = 0 + target_novel_expr_total = 0 + + # Sum up expression values across conditions + num_ref_conditions_with_expr = 0 + for cond in self.ref_conditions: + cond_total_exp = ref_total_exp.get(cond, 0) + cond_novel_exp = ref_novel_exp.get(cond, 0) + ref_expr_total += cond_total_exp + ref_novel_expr_total += cond_novel_exp + if 
cond_total_exp > 0: + ref_novel_pct += (cond_novel_exp / cond_total_exp) * 100 + num_ref_conditions_with_expr += 1 + + num_target_conditions_with_expr = 0 + for cond in self.target_conditions: + cond_total_exp = target_total_exp.get(cond, 0) + cond_novel_exp = target_novel_exp.get(cond, 0) + target_expr_total += cond_total_exp + target_novel_expr_total += cond_novel_exp + if cond_total_exp > 0: + target_novel_pct += (cond_novel_exp / cond_total_exp) * 100 + num_target_conditions_with_expr += 1 + + # Average the condition-specific percentages (for color coding only) + ref_novel_pct /= num_ref_conditions_with_expr or 1 + target_novel_pct /= num_target_conditions_with_expr or 1 + + # Calculate overall novel percentage (for y-axis) + combined_expr_total = ref_expr_total + target_expr_total + combined_novel_expr_total = ref_novel_expr_total + target_novel_expr_total + + # Check for non-zero total expression before calculating percentages and fold change + if combined_expr_total > 0: + # Calculate log2 fold change using the total expression values + # Add pseudocount to avoid division by zero or log(0) + pseudocount = 1e-6 # Small value to add + log2fc = np.log2((target_expr_total + pseudocount) / (ref_expr_total + pseudocount)) + + # Calculate novel transcript contribution difference (for color) + novel_pct_diff = target_novel_pct - ref_novel_pct + + # Calculate overall novel percentage (for y-axis) + overall_novel_pct = (combined_novel_expr_total / combined_expr_total) * 100 + + # Add data point + plot_data.append({ + 'gene_id': gene_id, + 'gene_name': gene_name, + 'ref_novel_pct': ref_novel_pct, + 'target_novel_pct': target_novel_pct, + 'novel_pct_diff': novel_pct_diff, + 'overall_novel_pct': overall_novel_pct, + 'log2fc': log2fc, + 'total_expr': combined_expr_total + }) + + # Create dataframe + df = pd.DataFrame(plot_data) + + if df.empty: + logging.warning("No data available for novel transcript plot after processing.") # Adjusted warning + return + + # Get the parent directory of gene_visualizations_dir + parent_dir = os.path.dirname(self.gene_visualizations_dir) + + # Save the CSV to parent directory instead of gene_visualizations_dir + csv_path = os.path.join(parent_dir, "novel_transcript_expression_data.csv") + df.to_csv(csv_path, index=False) + logging.info(f"Novel transcript expression data saved to {csv_path}") + + # Log the number of genes used in the plot + logging.info(f"Number of genes included in novel transcript plot: {len(df)}") + + # Create the plot with more space on right for legend + plt.figure(figsize=(16, 10)) # Increased width from 14 to 16 + + # Define red-blue colormap + norm = Normalize(vmin=-50, vmax=50) # Normalize based on difference range + cmap = cm.get_cmap('coolwarm') # Red-Blue colormap + + # More dramatic scaling for point sizes + min_size = 30 + max_size = 800 # Much larger maximum size + + # Use np.power for more dramatic scaling differences + expression_values = df['total_expr'].values + # Handle case where expression_values might be empty or all zero + if len(expression_values) == 0 or expression_values.max() == expression_values.min(): + max_expr = 1 + min_expr = 0 + logging.warning("Cannot determine expression range for point scaling; using default [0, 1].") + else: + max_expr = expression_values.max() + min_expr = expression_values.min() + + # Log the actual min and max expression values for reference + logging.debug(f"Expression range in data: min={min_expr}, max={max_expr}") + + # Define the scaling function that will be used for both data points 
and legend + def scale_point_size(expr_value, min_expr, max_expr, min_size, max_size, power=0.3): + """Scale expression values to point sizes using the same formula for data and legend""" + # Normalize the expression value to [0,1] range + if max_expr == min_expr: # Avoid division by zero or invalid range + normalized = 0.5 # Default to middle size if range is zero + else: + # Clamp value to range before normalizing to handle potential outliers from pseudocounts + clamped_value = np.clip(expr_value, min_expr, max_expr) + normalized = (clamped_value - min_expr) / (max_expr - min_expr) + # Apply power scaling and convert to point size + return min_size + (max_size - min_size) * (normalized ** power) + + # Apply scaling to actual data points + scaled_sizes = [scale_point_size(val, min_expr, max_expr, min_size, max_size) for val in expression_values] + + # Plot points with scaled sizes + sc = plt.scatter(df['log2fc'], df['overall_novel_pct'], + s=scaled_sizes, + c=df['novel_pct_diff'], + cmap=cmap, + norm=norm, + alpha=0.8, + edgecolors='black') + + # Add color legend on the right + cbar = plt.colorbar(sc, orientation='vertical', pad=0.02) + cbar.set_label('Novel transcript usage difference (%)', size=12) + cbar.ax.tick_params(labelsize=10) + + # Use red and blue blocks to explain the colormap + plt.figtext(0.92, 0.72, f'Blue = higher (%) in {ref_label}', fontsize=12, ha='center') + plt.figtext(0.92, 0.75, f'Red = higher (%) in {target_label}', fontsize=12, ha='center') + + # Add size legend directly to the plot + # Create legend elements for different sizes with new values: 50, 500, 5000 + size_legend_values = [50, 500, 5000] + size_legend_elements = [] + + # Calculate sizes for legend using EXACTLY the same scaling function as for the data points + for val in size_legend_values: + # Use the same scaling function defined above + size = scale_point_size(val, min_expr, max_expr, min_size, max_size) + + # Log the actual size being used for the legend point + logging.debug(f"Legend point {val} TPM scaled to size {size}") + + # Convert area to diameter for Line2D (sqrt of area * 2) + marker_diameter = 2 * np.sqrt(size / np.pi) + + size_legend_elements.append( + plt.Line2D([0], [0], marker='o', color='w', + markerfacecolor='gray', markersize=marker_diameter, + label=f'{val:.0f} TPM') + ) + + # Position legend + plt.legend(handles=size_legend_elements, + title="Expression Level", + loc='center left', + bbox_to_anchor=(1.15, 0.5), + frameon=False, + title_fontsize=12, + fontsize=12) + + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + + # Add labels and title with actual condition names + plt.xlabel('Log2 Fold Change', fontsize=12) + plt.ylabel('Total expression from novel transcripts (%)', fontsize=12) + plt.title('Novel Transcript Usage vs Expression Change between High Risk and Low Risk Phenotypes', fontsize=18) + + plt.grid(True, alpha=0.3) + + # Use tighter layout settings + plt.tight_layout() + + # Save figure to parent directory instead of gene_visualizations_dir + output_path = os.path.join(parent_dir, "novel_transcript_expression_plot.pdf") + plt.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0.5) + plt.close() + + logging.debug(f"Novel transcript expression plot saved to {output_path}") + + +class ExpressionVisualizer: + def __init__(self, output_path): + """Initialize with output path for plots.""" + self.output_path = Path(output_path) + self.output_path.mkdir(parents=True, exist_ok=True) + self.logger = logging.getLogger(__name__) # Logger for this class + # Suppress 
matplotlib font debug messages + logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING) + + def create_volcano_plot( + self, + df: pd.DataFrame, + target_label: str, + reference_label: str, + padj_threshold: float = 0.05, + lfc_threshold: float = 1, + top_n: int = 60, # Increased from 10 to 20 + feature_type: str = "genes", + ) -> None: + """Create volcano plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df["padj"] = df["padj"].replace(0, 1e-300) + df = df[df["padj"] > 0] + df = df.copy() # Create a copy to avoid the warning + df.loc[:, "-log10(padj)"] = -np.log10(df["padj"]) + + # Define significant genes + significant = (df["padj"] < padj_threshold) & ( + abs(df["log2FoldChange"]) > lfc_threshold + ) + up_regulated = significant & (df["log2FoldChange"] > lfc_threshold) + down_regulated = significant & (df["log2FoldChange"] < -lfc_threshold) + + # Plot points + plt.scatter( + df.loc[~significant, "log2FoldChange"], + df.loc[~significant, "-log10(padj)"], + color="grey", + alpha=0.5, + label="Not Significant", + ) + plt.scatter( + df.loc[up_regulated, "log2FoldChange"], + df.loc[up_regulated, "-log10(padj)"], + color="red", + alpha=0.7, + label=f"Up-regulated in ({target_label})", + ) + plt.scatter( + df.loc[down_regulated, "log2FoldChange"], + df.loc[down_regulated, "-log10(padj)"], + color="blue", + alpha=0.7, + label=f"Down-regulated in ({target_label})", + ) + + # Add threshold lines and labels + plt.axhline(-np.log10(padj_threshold), color="grey", linestyle="--") + plt.axvline(lfc_threshold, color="grey", linestyle="--") + plt.axvline(-lfc_threshold, color="grey", linestyle="--") + + plt.xlabel("log2 Fold Change") + plt.ylabel("-log10(adjusted p-value)") + plt.title(f"Volcano Plot: {target_label} vs {reference_label}") + plt.legend() + + # Add labels for top significant features + sig_df = df.loc[significant].nsmallest(top_n, "padj") + for _, row in sig_df.iterrows(): + if feature_type == "genes": + symbol = row["gene_name"] if pd.notnull(row["gene_name"]) else row["feature_id"] + elif feature_type == "transcripts": + symbol = row["transcript_symbol"] if pd.notnull(row["transcript_symbol"]) else row["feature_id"] + else: # Fallback to feature_id if feature_type is not recognized + symbol = row["feature_id"] + plt.text( + row["log2FoldChange"], + row["-log10(padj)"], + symbol, + fontsize=8, + ha="center", + va="bottom", + ) + + plt.tight_layout() + plot_path = ( + self.output_path / f"volcano_plot_{feature_type}.pdf" + ) + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"Volcano plot saved to {plot_path}") + + def create_ma_plot( + self, + df: pd.DataFrame, + target_label: str, + reference_label: str, + feature_type: str = "genes", + ) -> None: + """Create MA plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df = df[df["baseMean"] > 0] + df["log10(baseMean)"] = np.log10(df["baseMean"]) + + # Create plot + plt.scatter( + df["log10(baseMean)"], df["log2FoldChange"], alpha=0.5, color="grey" + ) + plt.axhline(y=0, color="red", linestyle="--") + + plt.xlabel("log10(Base Mean)") + plt.ylabel("log2 Fold Change") + plt.title(f"MA Plot: {target_label} vs {reference_label}") + + plt.tight_layout() + plot_path = self.output_path / f"ma_plot_{feature_type}.pdf" # Changed from .png to .pdf + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"MA plot saved to {plot_path}") + + def create_summary( + self, + res_df: pd.DataFrame, + target_label: str, + reference_label: str, 
+ min_count: int, + feature_type: str, + ) -> None: + """ + Create and save analysis summary with correct filtering criteria reporting. + + Args: + res_df: Results DataFrame + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + total_features = len(res_df) + sig_features = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"].abs() > 1) + ).sum() + up_regulated = ((res_df["padj"] < 0.05) & (res_df["log2FoldChange"] > 1)).sum() + down_regulated = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"] < -1) + ).sum() + + # Incorporate feature_type into the summary filename + summary_filename = f"analysis_summary_{feature_type}.txt" + summary_path = self.output_path / summary_filename + + with summary_path.open("w") as f: + f.write(f"Analysis Summary: {target_label} vs {reference_label}\n") + f.write("================================\n") + + # Different filtering description based on feature type + if feature_type == "genes": + f.write( + f"{feature_type.capitalize()} after filtering " + f"(mean count >= {min_count} in either condition group): {total_features}\n" + ) + else: # transcripts + f.write( + f"{feature_type.capitalize()} after filtering " + f"(count >= {min_count} in at least half of all samples): {total_features}\n" + ) + + f.write(f"Significantly differential {feature_type}: {sig_features}\n") + f.write(f"Up-regulated {feature_type}: {up_regulated}\n") + f.write(f"Down-regulated {feature_type}: {down_regulated}\n") + + logging.info(f"Analysis summary saved to {summary_path}") + + def visualize_results( + self, + results: pd.DataFrame, + target_label: str, + reference_label: str, + min_count: int, + feature_type: str, + ) -> None: + """ + Create all visualizations and summary for the analysis results. 
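+
+        A hedged usage sketch (the output directory, condition labels, count
+        threshold and results_df below are placeholders, not values produced
+        by the pipeline):
+
+            visualizer = ExpressionVisualizer("diffexp_output")
+            visualizer.visualize_results(results_df, target_label="tumor",
+                                         reference_label="normal",
+                                         min_count=10, feature_type="genes")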
+ + Args: + results: DataFrame containing differential expression results + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + try: + self.create_volcano_plot( + results, target_label, reference_label, feature_type=feature_type + ) + self.create_ma_plot( + results, target_label, reference_label, feature_type=feature_type + ) + self.create_summary( + results, + target_label, + reference_label, + min_count, + feature_type=feature_type, + ) + except Exception as e: + logging.exception("Failed to create visualizations") + raise + + + def plot_pca( + self, + pca_df: pd.DataFrame, + title: str, + output_prefix: str, + explained_variance: np.ndarray, + loadings: np.ndarray, + feature_names: List[str] + ) -> Path: + """Plot PCA scatter plot, scree plot, and loadings.""" + plt.figure(figsize=(8, 6)) + + # Extract variance info from title for axis labels only + pc1_var = title.split("PC1 (")[1].split("%)")[0] + pc2_var = title.split("PC2 (")[1].split("%)")[0] + + # Get clean title without PCs and variance + base_title = title.split(' Level PCA: ')[0] + comparison = title.split(': ')[1].split('PC1')[0].strip() + clean_title = f"{base_title} Level PCA: {comparison}" + + # Update group labels in the DataFrame + condition_mapping = {'Target': title.split(": ")[1].split(" vs ")[0], + 'Reference': title.split(" vs ")[1].split("PC1")[0].strip()} + pca_df['group'] = pca_df['group'].map(condition_mapping) + + # Create scatter plot + sns.scatterplot(x='PC1', y='PC2', hue='group', data=pca_df, s=100) + plt.xlabel(f'PC1 ({pc1_var}%)') + plt.ylabel(f'PC2 ({pc2_var}%)') + plt.title(clean_title) + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.tight_layout() + + scatter_plot_path = self.output_path / f"{output_prefix}.pdf" # Changed from .png to .pdf + plt.savefig(scatter_plot_path) + plt.close() + + # --- Scree Plot --- + self._plot_scree(explained_variance, output_prefix) + + # --- Loadings --- + self._output_loadings(loadings, feature_names, output_prefix) + + return scatter_plot_path # Return path to scatter plot + + def _plot_scree(self, explained_variance: np.ndarray, output_prefix: str) -> Path: + """Plot scree plot of explained variance.""" + plt.figure(figsize=(8, 6)) + num_components = len(explained_variance) + component_numbers = range(1, num_components + 1) + + plt.bar(component_numbers, explained_variance * 100) + plt.xlabel('Principal Component') + plt.ylabel('Percentage of Explained Variance') + plt.title('Scree Plot') + plt.xticks(component_numbers) # Ensure all component numbers are labeled + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.tight_layout() + + scree_plot_path = self.output_path / f"scree_{output_prefix}.pdf" # Changed from .png to .pdf + plt.savefig(scree_plot_path) + plt.close() + return scree_plot_path + + def _output_loadings(self, loadings: np.ndarray, feature_names: List[str], output_prefix: str, top_n: int = 10) -> Path: + """Output top N loadings for PC1 and PC2.""" + # Generate column names dynamically based on the number of components + num_components = loadings.shape[0] # Get the number of components from loadings shape + pc_columns = [f'PC{i+1}' for i in range(num_components)] + + loadings_df = pd.DataFrame(loadings.T, index=feature_names, columns=pc_columns) # Use dynamic column names + + output_path = 
self.output_path / f"loadings_{output_prefix}.txt" + with open(output_path, 'w') as f: + f.write("PCA Loadings (Top {} Features for PC1 and PC2):\n\n".format(top_n)) + for pc_name in ['PC1', 'PC2']: + f.write(f"\n--- {pc_name} ---\n") + # Sort by absolute value of loading + top_loadings = loadings_df.sort_values(by=pc_name, key=lambda x: x.abs(), ascending=False).head(top_n) + for gene, loading in top_loadings[pc_name].items(): # Iterate over series items + f.write(f"{gene}:\t{loading:.4f}\n") # Tab-separated for readability + return output_path + diff --git a/src/visualization_read_assignment_io.py b/src/visualization_read_assignment_io.py new file mode 100644 index 00000000..c2d9202c --- /dev/null +++ b/src/visualization_read_assignment_io.py @@ -0,0 +1,269 @@ +import logging +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union +import numpy as np + +from src.visualization_cache_utils import ( + build_read_assignment_cache_file, + build_length_effects_cache_file, + build_length_hist_cache_file, + save_cache, + load_cache, + validate_read_assignment_data, + validate_length_effects_data, + validate_length_hist_data, +) + + +def _smart_open(path_str: str): + import gzip + try: + with open(path_str, 'rb') as bf: + if bf.read(2) == b'\x1f\x8b': + return gzip.open(path_str, 'rt') + except Exception: + pass + return open(path_str, 'rt') + + +def _calc_length_bp(exons_str: str) -> int: + if not isinstance(exons_str, str) or not exons_str: + return 0 + total = 0 + for part in exons_str.split(','): + if '-' not in part: + continue + try: + s, e = part.split('-') + total += int(e) - int(s) + 1 + except Exception: + continue + return total + + +def get_read_assignment_counts(config, cache_dir: Path): + """ + Returns read-assignment classification and assignment_type counts, using cache. 
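+    Column layout is inferred from the parsing code below rather than from a
+    format specification: column 6 (1-based) is read as the assignment type
+    and column 9 as an attribute string containing "Classification=...".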
+ Return format mirrors previous behavior in DictionaryBuilder: + - If config.read_assignments is a list: ({sample: class_counts}, {sample: assign_type_counts}) + - Else: (class_counts, assign_type_counts) + """ + logger = logging.getLogger('IsoQuant.visualization.read_assignment_io') + if not config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + cache_file = build_read_assignment_cache_file( + config.read_assignments, config.ref_only, cache_dir + ) + + if cache_file.exists(): + cached_data = load_cache(cache_file) + if cached_data and validate_read_assignment_data(cached_data, config.read_assignments): + logger.info("Using cached read assignment data.") + if isinstance(config.read_assignments, list): + return ( + cached_data["classification_counts"], + cached_data["assignment_type_counts"], + ) + return cached_data + + logger.info("Building read assignment data from scratch.") + + def process_file(file_path: str): + classification_counts: Dict[str, int] = {} + assignment_type_counts: Dict[str, int] = {} + with _smart_open(file_path) as fh: + # Skip header lines starting with '#' + while True: + pos = fh.tell() + line = fh.readline() + if not line: + break + if not line.startswith('#'): + fh.seek(pos) + break + for line in fh: + parts = line.strip().split('\t') + if len(parts) < 9: + continue + additional_info = parts[8] + assignment_type = parts[5] + classification = ( + additional_info.split('Classification=')[-1].split(';')[0].strip() + if 'Classification=' in additional_info else 'Unknown' + ) + classification_counts[classification] = classification_counts.get(classification, 0) + 1 + assignment_type_counts[assignment_type] = assignment_type_counts.get(assignment_type, 0) + 1 + return classification_counts, assignment_type_counts + + if isinstance(config.read_assignments, list): + classification_counts_dict: Dict[str, Dict[str, int]] = {} + assignment_type_counts_dict: Dict[str, Dict[str, int]] = {} + for sample_name, file_path in config.read_assignments: + c_counts, a_counts = process_file(file_path) + classification_counts_dict[sample_name] = c_counts + assignment_type_counts_dict[sample_name] = a_counts + to_cache = { + "classification_counts": classification_counts_dict, + "assignment_type_counts": assignment_type_counts_dict, + } + save_cache(cache_file, to_cache) + return classification_counts_dict, assignment_type_counts_dict + else: + counts = process_file(config.read_assignments) + save_cache(cache_file, counts) + return counts + + +def get_read_length_effects(config, cache_dir: Path) -> Dict[str, Any]: + """ + Compute and cache read-length effects aggregates: + - by length bin vs assignment_type and vs classification + - dynamic keys for observed categories + Returns dict with keys: bins, by_bin_assignment, by_bin_classification, assignment_keys, classification_keys, totals + """ + logger = logging.getLogger('IsoQuant.visualization.read_assignment_io') + if not config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + # Fixed bin order focused on 0-15 kb + bin_order = ['<1kb','1-2kb','2-3kb','3-4kb','4-5kb','5-6kb','6-7kb','7-8kb','8-9kb','9-10kb','10-12kb','12-15kb','>15kb'] + cache_file = build_length_effects_cache_file( + config.read_assignments, config.ref_only, cache_dir, bin_order + ) + + if cache_file.exists(): + cached = load_cache(cache_file) + if cached and validate_length_effects_data(cached, expected_bins=bin_order): + logger.info("Using cached read length effects.") + return cached + + from 
collections import defaultdict + by_bin_assignment: Dict[str, Dict[str, int]] = {b: defaultdict(int) for b in bin_order} + by_bin_classification: Dict[str, Dict[str, int]] = {b: defaultdict(int) for b in bin_order} + assignment_keys = set() + classification_keys = set() + totals: Dict[str, int] = {b: 0 for b in bin_order} + + def assign_bin(length_bp: int) -> str: + if length_bp < 1000: return '<1kb' + if length_bp < 2000: return '1-2kb' + if length_bp < 3000: return '2-3kb' + if length_bp < 4000: return '3-4kb' + if length_bp < 5000: return '4-5kb' + if length_bp < 6000: return '5-6kb' + if length_bp < 7000: return '6-7kb' + if length_bp < 8000: return '7-8kb' + if length_bp < 9000: return '8-9kb' + if length_bp < 10000: return '9-10kb' + if length_bp < 12000: return '10-12kb' + if length_bp < 15000: return '12-15kb' + return '>15kb' + + def process_file(file_path: str): + with _smart_open(file_path) as fh: + # Skip header lines + while True: + pos = fh.tell() + line = fh.readline() + if not line: + break + if not line.startswith('#'): + fh.seek(pos) + break + for line in fh: + parts = line.strip().split('\t') + if len(parts) < 9: + continue + assignment_type = parts[5] + exons_str = parts[7] + addi = parts[8] + classification = ( + addi.split('Classification=')[-1].split(';')[0].strip() + if 'Classification=' in addi else 'unknown' + ) + length_bp = _calc_length_bp(exons_str) + b = assign_bin(length_bp) + totals[b] += 1 + by_bin_assignment[b][assignment_type] += 1 + by_bin_classification[b][classification] += 1 + assignment_keys.add(assignment_type) + classification_keys.add(classification) + + if isinstance(config.read_assignments, list): + for _sample, file_path in config.read_assignments: + process_file(file_path) + else: + process_file(config.read_assignments) + + # Convert defaultdicts to dicts for safer pickling/validation + by_bin_assignment = {b: dict(d) for b, d in by_bin_assignment.items()} + by_bin_classification = {b: dict(d) for b, d in by_bin_classification.items()} + + result = { + 'bins': bin_order, + 'by_bin_assignment': by_bin_assignment, + 'by_bin_classification': by_bin_classification, + 'assignment_keys': sorted(list(assignment_keys)), + 'classification_keys': sorted(list(classification_keys)), + 'totals': totals, + } + save_cache(cache_file, result) + return result + + +def get_read_length_histogram(config, cache_dir: Path, bin_edges: List[int] = None) -> Dict[str, Any]: + """ + Compute and cache a histogram of read lengths derived from the exons column. 
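+    Read length is derived from the exon coordinate column, which the helper
+    below parses as comma-separated inclusive ranges such as "100-200,300-450";
+    that example would contribute (200 - 100 + 1) + (450 - 300 + 1) = 252 bp.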
+ Returns dict with keys: edges, counts, total + """ + logger = logging.getLogger('IsoQuant.visualization.read_assignment_io') + if not config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + if bin_edges is None: + bin_edges = [ + 0, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, + 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, + 12000, 15000, + ] + + cache_file = build_length_hist_cache_file( + config.read_assignments, config.ref_only, cache_dir, bin_edges + ) + + if cache_file.exists(): + cached = load_cache(cache_file) + if cached and validate_length_hist_data(cached, expected_edges=bin_edges): + logger.info("Using cached read length histogram.") + return cached + + lengths: List[int] = [] + + def process_file(file_path: str): + with _smart_open(file_path) as fh: + for line in fh: + if not line or line.startswith('#'): + continue + parts = line.rstrip('\n').split('\t') + if len(parts) < 8: + continue + exons = parts[7] + lengths.append(_calc_length_bp(exons)) + + if isinstance(config.read_assignments, list): + for _sample, file_path in config.read_assignments: + process_file(file_path) + else: + process_file(config.read_assignments) + + counts, edges = np.histogram(np.array(lengths, dtype=np.int64), bins=np.array(bin_edges)) + result = { + 'edges': edges.tolist(), + 'counts': counts.tolist(), + 'total': int(len(lengths)), + } + save_cache(cache_file, result) + return result + + diff --git a/src/visualization_simple_ranker.py b/src/visualization_simple_ranker.py new file mode 100644 index 00000000..1486af45 --- /dev/null +++ b/src/visualization_simple_ranker.py @@ -0,0 +1,762 @@ +""" +Enhanced gene ranking utility for long‑read sequencing experiments without replicates. + +This module refines the original SimpleGeneRanker by incorporating best practices +from recent long‑read isoform switching studies. It defines "interesting" genes +as those that are highly expressed, exhibit bona‑fide isoform switching (at +least two isoforms changing in opposite directions), and potentially show +functional consequences (e.g., gain or loss of coding potential). Genes with +extreme overall expression changes or very complex isoform architectures are +penalised to reduce false positives. + +Key features: + +1. **Isoform count filter** – Genes with fewer than two transcripts are + excluded, as isoform usage cannot change. Genes with excessive numbers of + isoforms can be down‑weighted via a complexity penalty. +2. **Bidirectional isoform switching** – For a gene to be considered a + candidate isoform switcher, at least one transcript must increase in usage + while another decreases between reference and target conditions. This helps + distinguish true isoform switches from uniform scaling of all isoforms. +3. **Functional impact assessment** – When transcript annotations include + attributes such as coding status, ORF length or predicted NMD sensitivity, + the ranker rewards genes where isoform switches change these properties. + Lacking such annotations, this component defaults to zero influence. +4. **Adaptive thresholds** – Gating thresholds for expression level, fold + change and usage delta are derived from quantiles of the observed + distributions, making the algorithm robust across datasets with different + scales. +5. **Extreme change penalty** – Genes with very large gene‑level fold changes + are down‑weighted to prioritise isoform regulation over conventional + differential expression. +6. 
**Categorised output** – The ranker labels the top genes according to + whether they are isoform switchers, high expressers or conventional DEGs. + +Example usage: + + from enhanced_gene_ranker import EnhancedGeneRanker + ranker = EnhancedGeneRanker(output_dir="./out", + ref_conditions=["ref"], + target_conditions=["tgt"], + updated_gene_dict=gene_dict) + top_genes = ranker.rank(top_n=50) + # top_genes is a list of gene names + +Note: This implementation assumes that `updated_gene_dict` follows the same +structure as in the original SimpleGeneRanker. Transcript annotations may +contain keys such as ``coding`` (bool), ``orf_length`` (int) or +``functional_consequence`` (str). Missing annotations are handled gracefully. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd + +# Configure module‑level logger +logger = logging.getLogger("IsoQuant.visualization.enhanced_ranker") +logger.setLevel(logging.INFO) + + +class SimpleGeneRanker: + """Rank genes by integrating expression level, fold change, isoform switching and functional impact. + + Parameters + ---------- + output_dir : str or Path + Directory where intermediate results could be written (not used here but kept + for compatibility). + ref_conditions : List[str] + List of keys in ``updated_gene_dict`` corresponding to reference + conditions. + target_conditions : List[str] + List of keys in ``updated_gene_dict`` corresponding to target + conditions. + ref_only : bool, optional + If True, only reference conditions will be considered (fold change + computation disabled). Defaults to False. + updated_gene_dict : Dict, optional + Nested dictionary with expression and transcript information. See + SimpleGeneRanker for expected format. + """ + + def __init__( + self, + output_dir: str | Path, + ref_conditions: List[str], + target_conditions: List[str], + ref_only: bool = False, + updated_gene_dict: Dict | None = None, + ) -> None: + self.output_dir = Path(output_dir) + self.ref_conditions = list(ref_conditions) + self.target_conditions = list(target_conditions) + self.ref_only = ref_only + self.updated_gene_dict: Dict[str, Dict] = updated_gene_dict or {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def rank(self, top_n: int = 100) -> List[str]: + """Return a ranked list of gene names based on the enhanced scoring algorithm. + + Parameters + ---------- + top_n : int + Maximum number of genes to return. If fewer genes meet the + thresholds, the returned list may be shorter. + + Returns + ------- + List[str] + List of gene names (uppercase) ranked by decreasing composite score. + """ + logger.info("Running EnhancedGeneRanker") + + if not self.updated_gene_dict: + logger.warning("No updated_gene_dict provided; returning empty list") + return [] + + # Scan transcript ID patterns in updated_gene_dict + self._scan_transcript_id_patterns() + + # 1. Extract gene‑level TPMs + gene_expr_ref, gene_expr_tgt = self._extract_gene_tpms_from_dict() + if gene_expr_ref.empty or gene_expr_tgt.empty: + logger.warning("No gene expression data found; returning empty list") + return [] + + # 2. 
Compute gene list intersection + common_genes = gene_expr_ref.index.intersection(gene_expr_tgt.index) + gene_expr_ref = gene_expr_ref.loc[common_genes] + gene_expr_tgt = gene_expr_tgt.loc[common_genes] + logger.info(f"{len(common_genes)} genes present in both conditions") + + # 3. Compute isoform usage deltas and switching flags + usage_delta, switch_flags, func_flags, isoform_counts = self._compute_isoform_usage_metrics(common_genes) + + # 4. Normalize features + abs_log2fc = np.abs(np.log2(gene_expr_tgt + 1) - np.log2(gene_expr_ref + 1)) + geom_expr = np.sqrt(gene_expr_ref * gene_expr_tgt) + + norm_expr = self._normalize_feature(geom_expr, name="Expression") + norm_change = self._normalize_feature(abs_log2fc, name="FoldChange") + norm_usage = self._normalize_feature(usage_delta, name="UsageDelta") + # Functional impact does not need normalization (0/1), but we convert to series + func_series = pd.Series(func_flags, index=common_genes, dtype=float) + + # 5. Derive adaptive thresholds + expr_gate = norm_expr.quantile(0.5) # median + change_gate = norm_change.quantile(0.75) # upper quartile + usage_gate = norm_usage.quantile(0.75) + logger.info( + f"Adaptive gates – expression: {expr_gate:.3f}, fold change: {change_gate:.3f}, usage delta: {usage_gate:.3f}" + ) + + # 6. Compute composite scores + scores = pd.Series(0.0, index=common_genes) + categories = {} + for gene in common_genes: + expr_val = norm_expr.at[gene] + change_val = norm_change.at[gene] + usage_val = norm_usage.at[gene] + is_switch = switch_flags.get(gene, False) + func_val = func_series.at[gene] + iso_count = isoform_counts.get(gene, 0) + + # Handle single-transcript genes differently + if iso_count == 1: + # Single-transcript genes: focus on expression change, require higher thresholds + single_transcript_expr_gate = norm_expr.quantile(0.7) # Higher than multi-isoform + single_transcript_change_gate = norm_change.quantile(0.85) # Much higher fold change required + + passes_expr = expr_val > single_transcript_expr_gate + passes_change = change_val > single_transcript_change_gate + + # Single-transcript penalty: they need to work harder to compete + single_transcript_penalty = 0.6 + + # Score based only on expression and fold change (no isoform switching possible) + base = 0.0 + if passes_expr and passes_change: + base = expr_val * 0.4 + change_val * 0.6 # Weight fold change more heavily + if func_val > 0: + base += 0.3 # Smaller functional bonus than multi-isoform + + # Apply single-transcript penalty + scores.at[gene] = base * single_transcript_penalty + + # Assign category + if base == 0: + categories[gene] = "LOW_EXPR" + else: + categories[gene] = "SINGLE_TRANSCRIPT_DE" # New category + + elif iso_count < 2: + # Skip genes with 0 transcripts (shouldn't happen but safety check) + continue + else: + # Multi-transcript genes: original logic with isoform switching + passes_expr = expr_val > expr_gate + passes_change = change_val > change_gate + passes_usage = usage_val > usage_gate and is_switch + + # Penalise extreme expression changes (>90th percentile) + penalty = 1.0 + if change_val > norm_change.quantile(0.9): + penalty *= 0.5 + + # Complexity penalty: down‑weight genes with many isoforms (top 10%) + if iso_count > np.quantile(list(isoform_counts.values()), 0.9): + penalty *= 0.7 + + # Compute base score; weight usage more heavily when switching + base = 0.0 + if passes_expr and (passes_change or passes_usage): + base = expr_val * 0.3 + change_val * 0.3 + usage_val * 1.2 + if func_val > 0: + base += 0.5 # Functional 
impact bonus + + # Final score after penalty + scores.at[gene] = base * penalty + + # Assign category for top genes later + if base == 0: + categories[gene] = "LOW_EXPR" + elif passes_usage and is_switch: + categories[gene] = "ISOFORM_SWITCHER" + elif passes_expr and passes_change: + categories[gene] = "DIFFERENTIAL_EXPRESSION" + else: + categories[gene] = "HIGH_EXPRESSION" + + # 7. Sort and select top genes + ranked = scores[scores > 0].sort_values(ascending=False) + ranked_gene_ids = ranked.head(top_n).index.tolist() + + # 8. Map gene IDs to names using updated_gene_dict + ranked_gene_names: List[str] = [] + for gene_id in ranked_gene_ids: + gene_name = None + for cond in self.updated_gene_dict.values(): + if gene_id in cond: + gene_info = cond[gene_id] + gene_name = gene_info.get("name") + break + ranked_gene_names.append(gene_name.upper() if gene_name else gene_id) + + # Log top entries with categories + for gene_id in ranked_gene_ids[:10]: + cat = categories.get(gene_id, "UNKNOWN") + score = scores.at[gene_id] + isoform_count = isoform_counts.get(gene_id, 0) + logger.info(f"Top gene {gene_id}: score={score:.3f}, category={cat}, isoforms={isoform_count}") + + # Log single-transcript gene statistics + self._log_single_transcript_statistics(categories, scores, isoform_counts) + + # Add lncRNA-specific statistics and biotype distribution + self._log_biotype_distribution() + self._log_lncrna_statistics(ranked_gene_ids, categories, scores) + + return ranked_gene_names + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _scan_transcript_id_patterns(self) -> None: + """Scan updated_gene_dict for transcript ID patterns and report findings.""" + transcript_generic = 0 # Count of IDs starting with "transcript" + transcript_ensembl = 0 # Count of proper Ensembl IDs (ENSMUST, ENST, etc.) 
+ transcript_other = 0 # Count of other patterns + + generic_examples = [] + ensembl_examples = [] + other_examples = [] + + # Scan all conditions and genes + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + transcripts_dict = gene_info.get("transcripts", {}) + for tx_id in transcripts_dict.keys(): + if tx_id.lower().startswith("transcript"): + transcript_generic += 1 + if len(generic_examples) < 5: + generic_examples.append(f"{gene_id}:{tx_id}") + elif tx_id.startswith("ENSMUST") or tx_id.startswith("ENST") or tx_id.startswith("ENS"): + transcript_ensembl += 1 + if len(ensembl_examples) < 5: + ensembl_examples.append(f"{gene_id}:{tx_id}") + else: + transcript_other += 1 + if len(other_examples) < 5: + other_examples.append(f"{gene_id}:{tx_id}") + + total_transcripts = transcript_generic + transcript_ensembl + transcript_other + + logger.info(f"TRANSCRIPT_SCAN: Found {total_transcripts} total transcript entries") + + if total_transcripts > 0: + logger.info(f"TRANSCRIPT_SCAN: {transcript_generic} transcripts start with 'transcript' ({transcript_generic/total_transcripts*100:.1f}%)") + logger.info(f"TRANSCRIPT_SCAN: {transcript_ensembl} transcripts are Ensembl IDs ({transcript_ensembl/total_transcripts*100:.1f}%)") + logger.info(f"TRANSCRIPT_SCAN: {transcript_other} transcripts have other patterns ({transcript_other/total_transcripts*100:.1f}%)") + else: + logger.warning("TRANSCRIPT_SCAN: No transcripts found in updated_gene_dict") + + if generic_examples: + logger.info(f"TRANSCRIPT_SCAN: Generic transcript examples: {generic_examples}") + if ensembl_examples: + logger.info(f"TRANSCRIPT_SCAN: Ensembl transcript examples: {ensembl_examples}") + if other_examples: + logger.info(f"TRANSCRIPT_SCAN: Other transcript examples: {other_examples}") + + # Warn if too many generic transcripts + if transcript_generic > transcript_ensembl: + logger.warning(f"TRANSCRIPT_SCAN: More generic 'transcript' IDs ({transcript_generic}) than Ensembl IDs ({transcript_ensembl}) - this may indicate annotation issues") + + def _normalize_feature(self, feature: pd.Series, name: str) -> pd.Series: + """Normalize a numeric Series to the 0–1 range. If all values are equal, + return zeros. + + Parameters + ---------- + feature : pd.Series + The numeric data to normalize. + name : str + Name of the feature for logging. + + Returns + ------- + pd.Series + Normalized series with the same index. + """ + if feature.empty: + return pd.Series(dtype=float) + fmin, fmax = feature.min(), feature.max() + if fmax == fmin: + logger.warning(f"{name} values are identical; returning zeros") + return pd.Series(0.0, index=feature.index) + norm = (feature - fmin) / (fmax - fmin) + return norm.fillna(0.0) + + def _extract_gene_tpms_from_dict(self) -> Tuple[pd.Series, pd.Series]: + """Extract average gene TPMs for reference and target conditions. + + This method sums transcript TPMs within each gene and averages across + multiple conditions. It mirrors the corresponding method in the + SimpleGeneRanker but returns pandas Series indexed by gene ID. 
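+
+        Worked example (hypothetical numbers): with two reference conditions
+        in which a gene's transcript TPMs sum to 10 and 30 respectively, the
+        returned reference series holds (10 + 30) / 2 = 20.0 for that gene.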
+ """ + ref_gene_tpms: Dict[str, float] = {} + tgt_gene_tpms: Dict[str, float] = {} + ref_count = 0 + tgt_count = 0 + + # Process reference conditions + for cond in self.ref_conditions: + if cond not in self.updated_gene_dict: + continue + ref_count += 1 + genes = self.updated_gene_dict[cond] + for gene_id, gene_info in genes.items(): + tpm = 0.0 + for tx_info in gene_info.get("transcripts", {}).values(): + if isinstance(tx_info, dict) and "value" in tx_info: + tpm += tx_info["value"] + ref_gene_tpms[gene_id] = ref_gene_tpms.get(gene_id, 0.0) + tpm + + # Average across reference conditions + for gene_id in ref_gene_tpms: + if ref_count > 0: + ref_gene_tpms[gene_id] /= ref_count + + # Process target conditions + for cond in self.target_conditions: + if cond not in self.updated_gene_dict: + continue + tgt_count += 1 + genes = self.updated_gene_dict[cond] + for gene_id, gene_info in genes.items(): + tpm = 0.0 + for tx_info in gene_info.get("transcripts", {}).values(): + if isinstance(tx_info, dict) and "value" in tx_info: + tpm += tx_info["value"] + tgt_gene_tpms[gene_id] = tgt_gene_tpms.get(gene_id, 0.0) + tpm + + # Average across target conditions + for gene_id in tgt_gene_tpms: + if tgt_count > 0: + tgt_gene_tpms[gene_id] /= tgt_count + + ref_series = pd.Series(ref_gene_tpms, name="ref_tpm") + tgt_series = pd.Series(tgt_gene_tpms, name="tgt_tpm") + return ref_series, tgt_series + + def _compute_isoform_usage_metrics(self, genes: List[str]) -> Tuple[pd.Series, Dict[str, bool], Dict[str, bool], Dict[str, int]]: + """Compute isoform usage delta, switching flag, functional impact flag and isoform count. + + Parameters + ---------- + genes : list of str + Gene identifiers to process. + + Returns + ------- + usage_delta : pd.Series + Maximum absolute change in isoform usage per gene. + switch_flags : dict + Dictionary mapping gene IDs to True if at least one transcript + increases and another decreases in usage between reference and target. + func_flags : dict + Dictionary mapping gene IDs to True if the isoform switching implies + a change in coding status or functional consequence. + isoform_counts : dict + Dictionary mapping gene IDs to the number of isoforms detected. 
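+
+        Worked example (hypothetical usage fractions): if transcript T1 drops
+        from 70% to 40% of a gene's expression while T2 rises from 30% to 60%,
+        the gene is flagged as a switcher (one isoform up, one down) and its
+        usage_delta is max(|0.40 - 0.70|, |0.60 - 0.30|) = 0.30.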
+ """ + usage_delta = pd.Series(0.0, index=genes) + switch_flags: Dict[str, bool] = {} + func_flags: Dict[str, bool] = {} + isoform_counts: Dict[str, int] = {} + + # Build per‑condition transcript TPM dictionaries + ref_tx = self._aggregate_transcript_tpms(self.ref_conditions) + tgt_tx = self._aggregate_transcript_tpms(self.target_conditions) + + # For each gene, compute usage change + for gene in genes: + # Collect transcripts and counts + tx_ids = set() + for cond in self.ref_conditions: + cond_dict = self.updated_gene_dict.get(cond, {}) + if gene in cond_dict: + tx_ids.update(cond_dict[gene].get("transcripts", {}).keys()) + for cond in self.target_conditions: + cond_dict = self.updated_gene_dict.get(cond, {}) + if gene in cond_dict: + tx_ids.update(cond_dict[gene].get("transcripts", {}).keys()) + isoform_counts[gene] = len(tx_ids) + if len(tx_ids) < 2: + switch_flags[gene] = False + func_flags[gene] = False + usage_delta.at[gene] = 0.0 + continue + + # Compute usage per condition + ref_total = 0.0 + tgt_total = 0.0 + ref_usages: Dict[str, float] = {} + tgt_usages: Dict[str, float] = {} + for tx_id in tx_ids: + r_tpm = ref_tx.get(tx_id, 0.0) + t_tpm = tgt_tx.get(tx_id, 0.0) + ref_total += r_tpm + tgt_total += t_tpm + ref_usages[tx_id] = r_tpm + tgt_usages[tx_id] = t_tpm + ref_total += 1e-6 # avoid zero division + tgt_total += 1e-6 + + # Compute usage fractions + deltas = [] + directions = [] + func_change = False + for tx_id in tx_ids: + ref_u = ref_usages[tx_id] / ref_total + tgt_u = tgt_usages[tx_id] / tgt_total + delta = tgt_u - ref_u + deltas.append(abs(delta)) + directions.append(np.sign(delta)) + + # Assess functional impact if annotation exists + # We check across any condition; assume annotation consistent + for cond in self.updated_gene_dict: + cond_dict = self.updated_gene_dict[cond] + if gene in cond_dict: + tx_info = cond_dict[gene].get("transcripts", {}).get(tx_id, {}) + # Compare coding status and ORF length relative to other transcripts + coding = tx_info.get("coding") + orf_len = tx_info.get("orf_length") + func = tx_info.get("functional_consequence") + break + # Simple heuristic: if any transcript has non‑zero functional_consequence + if func is not None: + func_change = True + # Determine switching: at least one positive and one negative change + switch_flags[gene] = (1 in directions) and (-1 in directions) + func_flags[gene] = func_change + usage_delta.at[gene] = max(deltas) if deltas else 0.0 + + return usage_delta, switch_flags, func_flags, isoform_counts + + def _aggregate_transcript_tpms(self, conditions: List[str]) -> Dict[str, float]: + """Aggregate transcript TPMs across a list of conditions, averaging across + conditions. + + Parameters + ---------- + conditions : list of str + Conditions to aggregate. + + Returns + ------- + Dict[str, float] + Mapping from transcript ID to averaged TPM value. 
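+
+        Example (hypothetical values): a transcript reported at 5 TPM in one
+        of two aggregated conditions and 15 TPM in the other is returned as
+        (5 + 15) / 2 = 10.0.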
+ """ + tx_totals: Dict[str, float] = {} + count = 0 + for cond in conditions: + if cond not in self.updated_gene_dict: + continue + count += 1 + cond_dict = self.updated_gene_dict[cond] + for gene_info in cond_dict.values(): + for tx_id, tx_info in gene_info.get("transcripts", {}).items(): + if isinstance(tx_info, dict) and "value" in tx_info: + tx_totals[tx_id] = tx_totals.get(tx_id, 0.0) + tx_info["value"] + # Average + if count > 0: + for tx in tx_totals: + tx_totals[tx] /= count + return tx_totals + + def _log_single_transcript_statistics(self, categories: Dict[str, str], scores: pd.Series, isoform_counts: Dict[str, int]) -> None: + """Log statistics about single-transcript genes and how they performed.""" + single_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count == 1] + multi_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count > 1] + + single_transcript_scored = [gene_id for gene_id in single_transcript_genes if scores.get(gene_id, 0) > 0] + single_transcript_in_categories = [gene_id for gene_id in single_transcript_genes if categories.get(gene_id) == "SINGLE_TRANSCRIPT_DE"] + + # Get top single-transcript genes by score + single_transcript_scores = {gene_id: scores.get(gene_id, 0) for gene_id in single_transcript_genes} + top_single_transcript = sorted(single_transcript_scores.items(), key=lambda x: x[1], reverse=True)[:5] + + # Get expression info for top single-transcript genes + top_single_examples = [] + for gene_id, score in top_single_transcript[:3]: + if score > 0: + # Find gene name and expression values + gene_name = gene_id + ref_tpm = 0 + tgt_tpm = 0 + + for condition, genes in self.updated_gene_dict.items(): + if gene_id in genes: + gene_info = genes[gene_id] + gene_name = gene_info.get("name", gene_id) + + # Get the single transcript's TPM + transcripts = gene_info.get("transcripts", {}) + if transcripts: + tx_id, tx_info = next(iter(transcripts.items())) + tpm = tx_info.get("value", 0) + + if condition in self.ref_conditions: + ref_tpm += tpm / len(self.ref_conditions) + elif condition in self.target_conditions: + tgt_tpm += tpm / len(self.target_conditions) + + fold_change = (tgt_tpm + 0.1) / (ref_tpm + 0.1) # Add pseudocount + top_single_examples.append({ + "gene_id": gene_id, + "gene_name": gene_name, + "score": score, + "ref_tpm": ref_tpm, + "tgt_tpm": tgt_tpm, + "fold_change": fold_change + }) + + logger.info("=== SINGLE-TRANSCRIPT GENE ANALYSIS ===") + logger.info(f"Total single-transcript genes: {len(single_transcript_genes)}") + logger.info(f"Total multi-transcript genes: {len(multi_transcript_genes)}") + logger.info(f"Single-transcript genes with scores > 0: {len(single_transcript_scored)}") + logger.info(f"Single-transcript genes passing high thresholds: {len(single_transcript_in_categories)}") + + if len(single_transcript_genes) > 0: + pass_rate = (len(single_transcript_in_categories) / len(single_transcript_genes)) * 100 + logger.info(f"Single-transcript gene pass rate: {pass_rate:.1f}% (requires 85th percentile fold change)") + + if top_single_examples: + logger.info("Top single-transcript genes by score:") + for i, example in enumerate(top_single_examples, 1): + logger.info(f" {i}. 
{example['gene_name']} ({example['gene_id']}): " + f"score={example['score']:.3f}, " + f"ref_TPM={example['ref_tpm']:.1f}, " + f"tgt_TPM={example['tgt_tpm']:.1f}, " + f"FC={example['fold_change']:.2f}x") + + # Compare to multi-transcript genes + multi_transcript_scored = [gene_id for gene_id in multi_transcript_genes if scores.get(gene_id, 0) > 0] + if len(multi_transcript_genes) > 0: + multi_pass_rate = (len(multi_transcript_scored) / len(multi_transcript_genes)) * 100 + logger.info(f"Multi-transcript gene pass rate: {multi_pass_rate:.1f}% (for comparison)") + + def _log_biotype_distribution(self) -> None: + """Log the distribution of gene biotypes in the dataset.""" + biotype_counts = {} + total_genes = 0 + + # Count biotypes across all genes (avoiding duplicates by using first condition) + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + biotype = gene_info.get("biotype", "unknown") + biotype_counts[biotype] = biotype_counts.get(biotype, 0) + 1 + total_genes += 1 + break # Only count from one condition to avoid duplicates + + if total_genes == 0: + logger.warning("No genes found for biotype distribution analysis") + return + + # Sort biotypes by count + sorted_biotypes = sorted(biotype_counts.items(), key=lambda x: x[1], reverse=True) + + logger.info("=== GENE BIOTYPE DISTRIBUTION ===") + logger.info(f"Total genes analyzed: {total_genes}") + + for biotype, count in sorted_biotypes: + percentage = (count / total_genes) * 100 + logger.info(f"{biotype}: {count} genes ({percentage:.1f}%)") + + # Highlight key biotypes + lncrna_count = biotype_counts.get("lncRNA", 0) + biotype_counts.get("long_noncoding_rna", 0) + protein_coding_count = biotype_counts.get("protein_coding", 0) + pseudogene_counts = sum(count for biotype, count in biotype_counts.items() + if "pseudogene" in biotype.lower()) + + logger.info("=== KEY BIOTYPE SUMMARY ===") + logger.info(f"Protein coding genes: {protein_coding_count} ({(protein_coding_count/total_genes)*100:.1f}%)") + logger.info(f"lncRNA genes: {lncrna_count} ({(lncrna_count/total_genes)*100:.1f}%)") + logger.info(f"Pseudogenes (all types): {pseudogene_counts} ({(pseudogene_counts/total_genes)*100:.1f}%)") + + def _log_lncrna_statistics(self, ranked_gene_ids: List[str], categories: Dict[str, str], scores: pd.Series) -> None: + """Analyze and log statistics about lncRNAs in the ranked gene list.""" + lncrna_stats = { + "total_lncrnas": 0, + "top_50_lncrnas": 0, + "top_10_lncrnas": 0, + "isoform_switcher_lncrnas": 0, + "high_expr_lncrnas": 0, + "lncrna_examples": [] + } + + all_lncrnas = [] + + # Scan all genes to find lncRNAs + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + gene_biotype = gene_info.get("biotype", "").lower() + if gene_biotype in ["lncrna", "long_noncoding_rna", "lincrna"]: + lncrna_stats["total_lncrnas"] += 1 + gene_name = gene_info.get("name", gene_id) + score = scores.get(gene_id, 0.0) + category = categories.get(gene_id, "LOW_EXPR") + + all_lncrnas.append({ + "gene_id": gene_id, + "gene_name": gene_name, + "score": score, + "category": category, + "biotype": gene_biotype + }) + + # Count categories + if category == "ISOFORM_SWITCHER": + lncrna_stats["isoform_switcher_lncrnas"] += 1 + elif category == "HIGH_EXPRESSION": + lncrna_stats["high_expr_lncrnas"] += 1 + + # Check if in top rankings + if gene_id in ranked_gene_ids[:50]: + lncrna_stats["top_50_lncrnas"] += 1 + if gene_id in ranked_gene_ids[:10]: + lncrna_stats["top_10_lncrnas"] += 1 + 
break # Only check one condition to avoid duplicates + + # Sort lncRNAs by score and get top examples + all_lncrnas.sort(key=lambda x: x["score"], reverse=True) + lncrna_stats["lncrna_examples"] = all_lncrnas[:5] + + # Log comprehensive lncRNA statistics + logger.info("=== lncRNA ANALYSIS ===") + logger.info(f"Total lncRNAs detected: {lncrna_stats['total_lncrnas']}") + logger.info(f"lncRNAs in top 50 genes: {lncrna_stats['top_50_lncrnas']}") + logger.info(f"lncRNAs in top 10 genes: {lncrna_stats['top_10_lncrnas']}") + logger.info(f"lncRNA isoform switchers: {lncrna_stats['isoform_switcher_lncrnas']}") + logger.info(f"High-expression lncRNAs: {lncrna_stats['high_expr_lncrnas']}") + + if lncrna_stats["total_lncrnas"] > 0: + top_50_pct = (lncrna_stats["top_50_lncrnas"] / lncrna_stats["total_lncrnas"]) * 100 + logger.info(f"Percentage of lncRNAs in top 50: {top_50_pct:.1f}%") + + # Log top lncRNA examples + if lncrna_stats["lncrna_examples"]: + logger.info("Top scoring lncRNAs:") + for i, lncrna in enumerate(lncrna_stats["lncrna_examples"], 1): + logger.info(f" {i}. {lncrna['gene_name']} ({lncrna['gene_id']}): " + f"score={lncrna['score']:.3f}, category={lncrna['category']}") + + # Analyze transcript complexity for top lncRNAs + self._analyze_lncrna_transcript_complexity(lncrna_stats["lncrna_examples"][:3]) + + def _analyze_lncrna_transcript_complexity(self, top_lncrnas: List[Dict]) -> None: + """Analyze transcript complexity and biotype diversity for top lncRNAs.""" + if not top_lncrnas: + return + + logger.info("=== lncRNA TRANSCRIPT COMPLEXITY ===") + + for lncrna in top_lncrnas: + gene_id = lncrna["gene_id"] + gene_name = lncrna["gene_name"] + + # Find transcript information across conditions + transcript_info = {} + for condition, genes in self.updated_gene_dict.items(): + if gene_id in genes: + transcripts = genes[gene_id].get("transcripts", {}) + for tx_id, tx_info in transcripts.items(): + tx_biotype = tx_info.get("biotype", "unknown") + tx_value = tx_info.get("value", 0.0) + tx_name = tx_info.get("name", tx_id) + + if tx_id not in transcript_info: + transcript_info[tx_id] = { + "name": tx_name, + "biotype": tx_biotype, + "values": [] + } + transcript_info[tx_id]["values"].append(tx_value) + + # Calculate transcript statistics + transcript_count = len(transcript_info) + biotype_counts = {} + active_transcripts = 0 + + for tx_id, tx_data in transcript_info.items(): + biotype = tx_data["biotype"] + biotype_counts[biotype] = biotype_counts.get(biotype, 0) + 1 + + avg_value = sum(tx_data["values"]) / len(tx_data["values"]) if tx_data["values"] else 0 + if avg_value > 1.0: # Consider active if TPM > 1 + active_transcripts += 1 + + logger.info(f"lncRNA {gene_name} ({gene_id}):") + logger.info(f" Total transcripts: {transcript_count}") + logger.info(f" Active transcripts (TPM>1): {active_transcripts}") + logger.info(f" Transcript biotypes: {dict(biotype_counts)}") + + # Show top 3 most expressed transcripts + if transcript_info: + sorted_transcripts = sorted( + transcript_info.items(), + key=lambda x: sum(x[1]["values"]) / len(x[1]["values"]) if x[1]["values"] else 0, + reverse=True + ) + logger.info(" Top transcripts by expression:") + for i, (tx_id, tx_data) in enumerate(sorted_transcripts[:3], 1): + avg_expr = sum(tx_data["values"]) / len(tx_data["values"]) if tx_data["values"] else 0 + logger.info(f" {i}. 
{tx_data['name']}: {avg_expr:.2f} TPM ({tx_data['biotype']})") diff --git a/visualize.py b/visualize.py index 799c2c37..c14e2d8f 100755 --- a/visualize.py +++ b/visualize.py @@ -1,24 +1,75 @@ #!/usr/bin/env python3 -from src.post_process import OutputConfig, DictionaryBuilder -from src.plot_output import PlotOutput import argparse -from src.process_dict import simplify_and_sum_transcripts -from src.gene_model import rank_and_visualize_genes +import sys +import logging +from src.visualization_output_config import OutputConfig +from src.visualization_dictionary_builder import DictionaryBuilder +from src.visualization_plotter import PlotOutput +from src.visualization_differential_exp import DifferentialAnalysis +from src.visualization_gsea import GSEAAnalysis +from src.visualization_simple_ranker import SimpleGeneRanker +from pathlib import Path + + +def setup_logging(viz_output_dir: Path) -> None: + """Configure centralized logging for all visualization processes.""" + log_file = viz_output_dir / "visualize.log" + + # Create formatters + file_formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(module)s - %(funcName)s - %(levelname)s - %(message)s' + ) + console_formatter = logging.Formatter('%(levelname)s: %(message)s') + + # File handler - detailed logging + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(file_formatter) + + # Console handler - less detailed + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) # Console output at INFO level + console_handler.setFormatter(console_formatter) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) # Root logger at DEBUG level + root_logger.handlers = [] # Clear existing handlers + root_logger.addHandler(file_handler) + root_logger.addHandler(console_handler) + + + logging.info("Initialized centralized logging system") + logging.debug(f"Log file location: {log_file}") + + +def setup_viz_output(output_directory: str, viz_output: str = None) -> Path: + """Set up visualization output directory.""" + if viz_output: + viz_output_dir = Path(viz_output) + else: + viz_output_dir = Path(output_directory) / "visualization" + viz_output_dir.mkdir(parents=True, exist_ok=True) + return viz_output_dir class FindGenesAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): if values is None: - values = 100 # Default value when the flag is used without a value + values = 100 # Default if flag used without value setattr(namespace, self.dest, values) def parse_arguments(): parser = argparse.ArgumentParser(description="Visualize your IsoQuant output.") + + # Positional Argument parser.add_argument( "output_directory", type=str, help="Directory containing IsoQuant output files." ) + + # Optional Arguments parser.add_argument( "--viz_output", type=str, @@ -31,9 +82,6 @@ def parse_arguments(): help="Optional path to a GTF file if unable to be extracted from IsoQuant log", default=None, ) - parser.add_argument( - "--counts", action="store_true", help="Use counts instead of TPM files." - ) parser.add_argument( "--ref_only", action="store_true", @@ -46,138 +94,369 @@ def parse_arguments(): default=None, ) parser.add_argument( + "--gsea", + action="store_true", + help="Perform GSEA analysis on differential expression results", + ) + parser.add_argument( + "--technical_replicates", + type=str, + help="Technical replicate specification. 
Can be a file path (.txt/.csv) with 'sample,group' format, or inline format 'sample1:group1,sample2:group1,sample3:group2'", + default=None, + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( "--gene_list", type=str, - required=True, - help="Path to a .txt file containing a list of genes, each on its own line.", + help="Path to a .txt file containing a list of genes to evaluate.", ) - parser.add_argument( + group.add_argument( "--find_genes", nargs="?", const=100, type=int, - help="Find genes with the highest combined rank and visualize them. Optionally specify the number of top genes to evaluate (default is 100).", + help="Find top genes with highest combined rank (default 100).", ) - parser.add_argument( - "--known_genes_path", - type=str, - help="Path to a CSV file containing known target genes.", - default=None, + + args = parser.parse_args() + + if args.find_genes is not None: + output = OutputConfig( + args.output_directory, + ref_only=args.ref_only, + gtf=args.gtf, + ) + if output.conditions: + gene_file = output.transcript_grouped_tpm + else: + gene_file = output.transcript_tpm + + if not gene_file or not Path(gene_file).is_file(): + print(f"Error: Grouped TPM/Counts file not found at {gene_file}.") + sys.exit(1) + + with open(gene_file, "r") as f: + header = f.readline().strip().split("\t") + + if len(header) < 2: + print( + "Error: The grouped TPM/Counts file does not contain condition information." + ) + sys.exit(1) + + available_conditions = header[1:] + if not available_conditions: + print("Error: No conditions found in the grouped TPM/Counts file.") + sys.exit(1) + + args.available_conditions = available_conditions + + return args + + +def select_conditions_interactively(args): + print("\nAvailable conditions:") + for idx, condition in enumerate(args.available_conditions, 1): + print(f"{idx}. {condition}") + + def get_selection(prompt, max_selection, exclude=[]): + while True: + try: + choices = input(prompt) + choice_indices = [int(x.strip()) for x in choices.split(",")] + if all(1 <= idx <= max_selection for idx in choice_indices): + selected = [ + args.available_conditions[idx - 1] + for idx in choice_indices + if args.available_conditions[idx - 1] not in exclude + ] + if not selected: + print("No valid conditions selected. Please try again.") + continue + return selected + else: + print(f"Please enter numbers between 1 and {max_selection}.") + except ValueError: + print("Invalid input. 
Please enter numbers separated by commas.") + + max_idx = len(args.available_conditions) + args.reference_conditions = get_selection( + "\nEnter refs (comma-separated): ", max_idx + ) + selected_refs = set(args.reference_conditions) + args.target_conditions = get_selection( + "\nEnter targets (comma-separated): ", max_idx, exclude=selected_refs ) - return parser.parse_args() + + print("\nSelected Reference Conditions:", ", ".join(args.reference_conditions)) + print("Selected Target Conditions:", ", ".join(args.target_conditions), "\n") + + + def main(): + # First, parse just the output directory argument to set up logging + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("output_directory", type=str, nargs='?') + parser.add_argument("--viz_output", type=str, default=None) + + # Parse just these arguments first + first_args, _ = parser.parse_known_args() + + # Initialize output directory early + if not first_args.output_directory: + print("Error: Output directory is required.") + sys.exit(1) + + # Set up visualization directory early + viz_output_dir = setup_viz_output(first_args.output_directory, first_args.viz_output) + + # Set up logging immediately to capture all operations + setup_logging(viz_output_dir) + logging.info("Starting IsoQuant visualization pipeline") + + # Now parse the full arguments with the real parser args = parse_arguments() + + # If find_genes is specified, get conditions interactively + if args.find_genes is not None: + select_conditions_interactively(args) + + logging.info("Reading IsoQuant parameters.") output = OutputConfig( args.output_directory, - use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf, + technical_replicates=args.technical_replicates, ) dictionary_builder = DictionaryBuilder(output) - gene_list = dictionary_builder.read_gene_list(args.gene_list) - update_names = not all(gene.startswith("ENS") for gene in gene_list) - gene_dict = dictionary_builder.build_gene_transcript_exon_dictionaries() - reads_and_class = ( - dictionary_builder.build_read_assignment_and_classification_dictionaries() - ) - - if output.conditions: - gene_file = ( - output.gene_grouped_tpm - if not output.use_counts - else output.gene_grouped_counts + logging.debug("OutputConfig details:") + logging.debug(f"Output directory: {output.output_directory}") + logging.debug(f"Reference only: {output.ref_only}") + + # Ask user about read assignments (optional) + use_read_assignments = ( + input("Do you want to look at read_assignment data? 
(y/n): ") + .strip() + .lower() + .startswith("y") + ) + + # If gene_list was given, read it; might use later for some optional steps + if args.gene_list: + logging.info(f"Reading gene list from {args.gene_list}") + gene_list = dictionary_builder.read_gene_list(args.gene_list) + # Decide if you need to rename Genes -> Symbol + update_names = not all(gene.startswith("ENS") for gene in gene_list) + else: + gene_list = None + update_names = True + + min_val = args.filter_transcripts if args.filter_transcripts is not None else 1.0 + logging.debug(f"Building updated_gene_dict with:") + logging.debug(f" min_value: {min_val}") + logging.debug(f" reference_conditions: {getattr(args, 'reference_conditions', None)}") + logging.debug(f" target_conditions: {getattr(args, 'target_conditions', None)}") + + updated_gene_dict = dictionary_builder.build_gene_dict_with_expression_and_filter( + min_value=min_val, + reference_conditions=getattr(args, 'reference_conditions', None), + target_conditions=getattr(args, 'target_conditions', None) + ) + + logging.debug(f"updated_gene_dict created:") + logging.debug(f" type: {type(updated_gene_dict)}") + logging.debug(f" keys (conditions): {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + if updated_gene_dict: + for condition, genes in updated_gene_dict.items(): + logging.info(f" condition '{condition}': {len(genes)} genes") + sample_genes = list(genes.keys())[:3] + if sample_genes: + for gene_id in sample_genes: + gene_info = genes[gene_id] + logging.debug(f" gene '{gene_id}': name='{gene_info.get('name', 'MISSING')}', keys={list(gene_info.keys())}") + if 'transcripts' in gene_info: + logging.debug(f" transcripts: {len(gene_info['transcripts'])} items") + break # Only show details for first condition + + # Debug: log whether gene_dict keys are Ensembl IDs or gene names + if updated_gene_dict: + sample_condition = next(iter(updated_gene_dict)) + sample_keys = list(updated_gene_dict[sample_condition].keys())[:5] + logging.debug( + "Sample gene_dict keys for condition '%s': %s", sample_condition, sample_keys + ) + + # 2. If read assignments are desired, build those as well (cached) + if use_read_assignments: + logging.info("Building read assignment and classification dictionaries.") + reads_and_class = ( + dictionary_builder.build_read_assignment_and_classification_dictionaries() ) + # New: build read length effects aggregates + logging.debug("Building read length effects aggregates.") + try: + length_effects = dictionary_builder.build_read_length_effects() + except Exception as e: + logging.error(f"Failed to compute read length effects: {e}") + length_effects = None + # New: build read length histogram + try: + length_hist = dictionary_builder.build_read_length_histogram() + except Exception as e: + logging.error(f"Failed to compute read length histogram: {e}") + length_hist = None else: - gene_file = output.gene_tpm if not output.use_counts else output.gene_counts + reads_and_class = None + length_effects = None + length_hist = None - updated_gene_dict = dictionary_builder.update_gene_dict(gene_dict, gene_file) - if update_names: - print("Updating Ensembl IDs to gene symbols.") - updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) + # 3. 
If user wants to find top genes (--find_genes), choose method based on replicate availability + if args.find_genes is not None: + ref_str = "_".join(x.upper().replace(" ", "_") for x in args.reference_conditions) + target_str = "_".join(x.upper().replace(" ", "_") for x in args.target_conditions) + main_dir_name = f"find_genes_{ref_str}_vs_{target_str}" + base_dir = viz_output_dir / main_dir_name if not args.viz_output else viz_output_dir + base_dir.mkdir(exist_ok=True) - if output.ref_only or not output.extended_annotation: - print("Using reference-only based quantification.") - if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_grouped_tpm - if not output.use_counts - else output.transcript_grouped_counts - ), + tech_rep_dict = output.technical_replicates_dict + replicate_ok = output.check_biological_replicates_for_conditions( + args.reference_conditions, args.target_conditions + ) + + if replicate_ok: + logging.info("Finding genes via DESeq2 (replicates detected).") + diff_analysis = DifferentialAnalysis( + output_dir=output.output_directory, + viz_output=base_dir, + ref_conditions=args.reference_conditions, + target_conditions=args.target_conditions, + updated_gene_dict=updated_gene_dict, + ref_only=args.ref_only, + dictionary_builder=dictionary_builder, + tech_rep_dict=tech_rep_dict, ) + gene_results, transcript_results, _, deseq2_df = diff_analysis.run_complete_analysis() + + if args.gsea: + gsea = GSEAAnalysis(output_path=base_dir) + target_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + gsea.run_gsea_analysis(deseq2_df, target_label) + + # Path to DESeq2-derived top genes + top_n = args.find_genes + contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + top_genes_filename = f"genes_of_top_{top_n}_DE_transcripts_{contrast_label}.txt" + find_genes_list_path = gene_results.parent / top_genes_filename + logging.info(f"Reading gene list generated by differential analysis from: {find_genes_list_path}") + + logging.info(f"FLOW_DEBUG: DESeq2 path - reading from file: {find_genes_list_path}") + if find_genes_list_path.exists(): + with open(find_genes_list_path, 'r') as f: + file_contents = f.read().strip().split('\n') + logging.info(f"FLOW_DEBUG: DESeq2 file has {len(file_contents)} lines, first 5: {file_contents[:5]}") + else: + logging.error(f"FLOW_DEBUG: DESeq2 gene list file does not exist: {find_genes_list_path}") + + gene_list = dictionary_builder.read_gene_list(find_genes_list_path) + logging.info(f"FLOW_DEBUG: DESeq2 gene_list after dictionary_builder.read_gene_list:") + logging.info(f" type: {type(gene_list)}") + logging.info(f" length: {len(gene_list) if gene_list else 'None'}") + logging.info(f" content (first 10): {gene_list[:10] if gene_list else 'None'}") else: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_tpm - if not output.use_counts - else output.transcript_counts - ), + logging.info("No biological replicates detected – using SimpleGeneRanker.") + logging.info(f"FLOW_DEBUG: Creating SimpleGeneRanker with:") + logging.info(f" output_dir: {output.output_directory}") + logging.info(f" ref_conditions: {args.reference_conditions}") + logging.info(f" target_conditions: {args.target_conditions}") + logging.info(f" ref_only: {args.ref_only}") + logging.info(f" updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + 
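+            # No-replicate fallback: DESeq2 needs biological replicates to estimate
+            # dispersion, so genes are ranked with SimpleGeneRanker instead. It is
+            # given the filtered expression dictionary built above
+            # (condition -> gene_id -> {"name", "transcripts": {tx_id: {"name",
+            # "biotype", "value"}}}) and the selected reference/target conditions;
+            # the exact scoring is implemented in src/visualization_simple_ranker.py
+            # (not shown here). The resulting top-N gene list is written to a file
+            # below for reproducibility.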
+ simple_ranker = SimpleGeneRanker( + output_dir=output.output_directory, + ref_conditions=args.reference_conditions, + target_conditions=args.target_conditions, + ref_only=args.ref_only, + updated_gene_dict=updated_gene_dict, ) + + logging.info(f"FLOW_DEBUG: Calling simple_ranker.rank(top_n={args.find_genes})") + gene_list = simple_ranker.rank(top_n=args.find_genes) + logging.info(f"FLOW_DEBUG: SimpleGeneRanker returned gene_list with {len(gene_list)} genes") + logging.info(f"FLOW_DEBUG: Gene list type: {type(gene_list)}") + logging.info(f"FLOW_DEBUG: Gene list content (first 10): {gene_list[:10] if gene_list else 'EMPTY'}") + + # Write gene list to file for reproducibility + contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + top_genes_filename = f"genes_of_top_{args.find_genes}_simple_{contrast_label}.txt" + simple_list_path = base_dir / top_genes_filename + import pandas as _pd + _pd.Series(gene_list).to_csv(simple_list_path, index=False, header=False) + logging.info(f"Simple gene list written to {simple_list_path}") + logging.info(f"FLOW_DEBUG: File contents verification:") + try: + with open(simple_list_path, 'r') as f: + file_contents = f.read().strip().split('\n') + logging.info(f"FLOW_DEBUG: File has {len(file_contents)} lines, first 5: {file_contents[:5]}") + except Exception as e: + logging.error(f"FLOW_DEBUG: Error reading written file: {e}") else: - print("Using transcript model quantification.") - if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_model_grouped_tpm - if not output.use_counts - else output.transcript_model_grouped_counts - ), - ) - else: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_model_tpm - if not output.use_counts - else output.transcript_model_counts - ), - ) + base_dir = viz_output_dir - if args.filter_transcripts is not None: - print( - f"Filtering transcripts with minimum value {args.filter_transcripts} in at least one condition." - ) - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( - updated_gene_dict, min_value=args.filter_transcripts - ) - else: - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( - updated_gene_dict - ) + # 5. Set up output directories + gene_visualizations_dir = base_dir / "gene_visualizations" + gene_visualizations_dir.mkdir(exist_ok=True) - # Visualization output directory decision - viz_output_directory = args.viz_output if args.viz_output else args.output_directory - - if args.find_genes: - print("Finding genes.") - simple_gene_dict = simplify_and_sum_transcripts(updated_gene_dict) - path = rank_and_visualize_genes( - simple_gene_dict, - viz_output_directory, - args.find_genes, - known_genes_path=args.known_genes_path, - ) - gene_list = dictionary_builder.read_gene_list(path) + if use_read_assignments: + read_assignments_dir = base_dir / "read_assignments" + read_assignments_dir.mkdir(exist_ok=True) + else: + read_assignments_dir = None # Set to None if not used - # dictionary_builder.save_gene_dict_to_json(updated_gene_dict, viz_output_directory) + # 6. 
Plotting with PlotOutput + logging.debug(f"Creating PlotOutput with:") + logging.debug(f" gene_names type: {type(gene_list)}") + logging.debug(f" gene_names length: {len(gene_list) if gene_list else 'None'}") + logging.debug(f" gene_names content (first 10): {gene_list[:10] if gene_list else 'None'}") + logging.debug(f" updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + logging.debug(f" conditions: {output.conditions}") + logging.debug(f" filter_transcripts: {min_val}") + logging.debug(f" ref_only: {args.ref_only}") + plot_output = PlotOutput( - updated_gene_dict, - gene_list, - viz_output_directory, - create_visualization_subdir=(viz_output_directory == args.output_directory), + updated_gene_dict=updated_gene_dict, + gene_names=gene_list, + gene_visualizations_dir=str(gene_visualizations_dir), + read_assignments_dir=str(read_assignments_dir) if read_assignments_dir else None, reads_and_class=reads_and_class, - filter_transcripts=args.filter_transcripts, + filter_transcripts=min_val, conditions=output.conditions, - use_counts=args.counts, + ref_only=args.ref_only, + ref_conditions=args.reference_conditions if hasattr(args, "reference_conditions") else None, + target_conditions=args.target_conditions if hasattr(args, "target_conditions") else None, ) + + plot_output.plot_transcript_map() + plot_output.plot_transcript_usage() - plot_output.make_pie_charts() + + + if use_read_assignments: + plot_output.make_pie_charts() + # New: plot read length effects (assignment uniqueness and FSM/ISM/Mono) + if length_effects: + plot_output.plot_read_length_effects(length_effects) + # Also dynamic stacked charts for assignment/classification + plot_output.plot_read_length_vs_assignment({ + 'bins': length_effects['bins'], + 'assignment': { (b, k): v for b in length_effects['bins'] for k, v in length_effects['by_bin_assignment'][b].items() }, + 'classification': { (b, k): v for b in length_effects['bins'] for k, v in length_effects['by_bin_classification'][b].items() }, + }) + # New: plot histogram + if length_hist: + plot_output.plot_read_length_histogram(length_hist) if __name__ == "__main__":