diff --git a/souporcell_pipeline.py b/souporcell_pipeline.py
index f001574..0e3fe73 100755
--- a/souporcell_pipeline.py
+++ b/souporcell_pipeline.py
@@ -145,17 +145,32 @@ def get_fasta_regions(fastaname, threads):
 
     return regions
 
-def get_bam_regions(bamname, threads):
+def get_bam_regions(bamname, threads, known_chroms=None):
+    """
+    Build the list of regions over the BAM's references, sized from the
+    total reference length divided by the number of threads.
+
+    When known_chroms is given, only BAM references whose names appear in it
+    are considered; otherwise every reference in the BAM header is used.
+    Raises ValueError if that filter leaves no references.
+    """
     bam = pysam.AlignmentFile(bamname)
+    refs = list(bam.references)
+    if known_chroms is not None:
+        refs = [r for r in refs if r in known_chroms]
+        # Explicit raise instead of assert: asserts are stripped under
+        # `python -O`, which would silently leave step_length at 0 below.
+        if not refs:
+            raise ValueError("No BAM references matched the chromosomes from the provided VCF(s).")
+
     total_reference_length = 0
-    for chrom in bam.references:
+    for chrom in refs:
         total_reference_length += bam.get_reference_length(chrom)
     step_length = int(math.ceil(total_reference_length / threads))
     regions = []
     region = []
     region_so_far = 0
     chrom_so_far = 0
-    for chrom in bam.references:
+    for chrom in refs:
         chrom_length = bam.get_reference_length(chrom)
         #print(chrom+" size "+str(chrom_length)+" and step size "+str(step_length))
         while True:
@@ -180,6 +195,32 @@ def get_bam_regions(bamname, threads):
 
     return regions
 
+def read_vcf_chroms(paths):
+    """
+    Return a set of chromosome names seen in one or more VCF files.
+    Only inspects body lines (ignores headers entirely).
+
+    paths may be a single path string or an iterable of paths. Blank lines
+    are skipped so they cannot add an empty or newline-only name.
+    """
+    if isinstance(paths, str):
+        paths = [paths]
+
+    contigs = set()
+
+    for path in paths:
+        with open_function(path) as fh:
+            for line in fh:
+                if line.startswith("#"):
+                    continue
+                # CHROM is the first tab-separated column; strip so a line
+                # without tabs does not carry its trailing newline along.
+                chrom = line.split("\t", 1)[0].strip()
+                if chrom:
+                    contigs.add(chrom)
+
+    return contigs
+
 def make_fastqs(args):
     if not os.path.isfile(args.bam + ".bai"):
         print("no bam index found, creating")
@@ -350,9 +391,12 @@ def freebayes(args, bam, fasta):
         if not(args.known_genotypes == None):
             print("using known genotypes")
             args.common_variants = args.known_genotypes
+
+        known_chroms = read_vcf_chroms(args.common_variants)
+        print(f"Restricting BAM regions to {len(known_chroms)} chromosomes from provided VCF(s).")
 
         # parallelize the samtools depth call. It takes too long
-        regions = get_bam_regions(bam, int(args.threads))
+        regions = get_bam_regions(bam, int(args.threads), known_chroms=known_chroms)
         depth_files = []
         depth_procs = []
         print(len(regions))