diff --git a/khmer/_oxli/graphs.pxd b/khmer/_oxli/graphs.pxd index 55339a2ec1..71d7802afa 100644 --- a/khmer/_oxli/graphs.pxd +++ b/khmer/_oxli/graphs.pxd @@ -139,7 +139,7 @@ cdef extern from "oxli/hashgraph.hh" namespace "oxli" nogil: void consume_seqfile_and_tag[SeqIO](const string &, unsigned int, - unsigned long long) + unsigned long long) # Ugly workaround. For some reason, Cython doesn't like *just this* # templated overload -- it chooses whichever was defined last, breaking @@ -147,7 +147,7 @@ cdef extern from "oxli/hashgraph.hh" namespace "oxli" nogil: # the Cython side and give it a real name substitution for code gen. void consume_seqfile_and_tag_readparser "consume_seqfile_and_tag" [SeqIO](shared_ptr[CpReadParser[SeqIO]], unsigned int, - unsigned long long) + unsigned long long) void consume_sequence_and_tag(const string &, unsigned long long &) @@ -160,7 +160,7 @@ cdef extern from "oxli/hashgraph.hh" namespace "oxli" nogil: unsigned int &, unsigned long long &) except +oxli_raise_py_error - uintptr_t trim_on_stoptags(string) + uintptr_t trim_on_stoptags(string) unsigned int traverse_from_kmer(CpKmer, uint32_t, @@ -177,7 +177,7 @@ cdef extern from "oxli/hashgraph.hh" namespace "oxli" nogil: void load_stop_tags(string, bool) except +oxli_raise_py_error void extract_unique_paths(string, uint32_t, float, vector[string]) void calc_connected_graph_size(CpKmer, uint64_t&, KmerSet&, - const uint64_t, bool) + const uint64_t, bool) uint32_t kmer_degree(HashIntoType, HashIntoType) uint32_t kmer_degree(const char *) void find_high_degree_nodes(const char *, set[HashIntoType] &) const diff --git a/khmer/_oxli/graphs.pyx b/khmer/_oxli/graphs.pyx index 14c2d2de96..28fa621dae 100644 --- a/khmer/_oxli/graphs.pyx +++ b/khmer/_oxli/graphs.pyx @@ -128,7 +128,7 @@ cdef class Hashtable: return deref(self._ht_this).get_count( kmer) else: self._kmer_type_error(kmer) - + def ksize(self): """k-mer size""" @@ -211,18 +211,6 @@ cdef class Hashtable: max_count)) return posns - def 
consume_seqfile_with_reads_parser(self, read_parser): - """Count all k-mers from read_parser.""" - cdef unsigned long long n_consumed = 0 - cdef unsigned int total_reads = 0 - - cdef CPyReadParser_Object* parser = read_parser - - deref(self._ht_this).consume_seqfile[CpFastxReader](parser.parser, - total_reads, - n_consumed) - return total_reads, n_consumed - def consume_seqfile(self, file_name): """Count all k-mers from file_name.""" cdef unsigned long long n_consumed = 0 @@ -273,6 +261,53 @@ cdef class Hashtable: n_consumed) return total_reads, n_consumed + def consume_seqfile_with_parser(self, object read_parser): + """Count all k-mers from read_parser.""" + cdef unsigned long long n_consumed = 0 + cdef unsigned int total_reads = 0 + cdef CPyReadParser_Object* parser = read_parser + + deref(self._ht_this).consume_seqfile[CpFastxReader]( + parser.parser, total_reads, n_consumed + ) + return total_reads, n_consumed + + def consume_seqfile_with_mask_with_parser(self, object read_parser, + Hashtable mask, int threshold=0): + cdef unsigned long long n_consumed = 0 + cdef unsigned int total_reads = 0 + cdef CPyReadParser_Object* parser = read_parser + cdef CpHashtable * cmask = mask._ht_this.get() + deref(self._ht_this).consume_seqfile_with_mask[CpFastxReader]( + parser.parser, cmask, threshold, total_reads, n_consumed + ) + return total_reads, n_consumed + + def consume_seqfile_banding_with_parser(self, object read_parser, num_bands, + band): + """Count all k-mers from read_parser.""" + cdef unsigned long long n_consumed = 0 + cdef unsigned int total_reads = 0 + cdef CPyReadParser_Object* parser = read_parser + deref(self._ht_this).consume_seqfile_banding[CpFastxReader]( + parser.parser, num_bands, band, total_reads, n_consumed + ) + return total_reads, n_consumed + + def consume_seqfile_banding_with_mask_with_parser(self, object read_parser, + num_bands, band, + Hashtable mask, + int threshold=0): + cdef unsigned long long n_consumed = 0 + cdef unsigned int total_reads 
= 0 + cdef CPyReadParser_Object* parser = read_parser + cdef CpHashtable * cmask = mask._ht_this.get() + deref(self._ht_this).consume_seqfile_banding_with_mask[CpFastxReader]( + parser.parser, num_bands, band, cmask, threshold, total_reads, + n_consumed + ) + return total_reads, n_consumed + def abundance_distribution(self, file_name, Hashtable tracking): """Calculate the k-mer abundance distribution over reads in file_name.""" cdef FastxParserPtr parser = get_parser[CpFastxReader](_bstring(file_name)) @@ -284,7 +319,7 @@ cdef class Hashtable: abunds.append(x[i]) return abunds - def abundance_distribution_with_reads_parser(self, object read_parser, Hashtable tracking): + def abundance_distribution_with_parser(self, object read_parser, Hashtable tracking): """Calculate the k-mer abundance distribution over reads.""" cdef CpHashtable * cptracking = tracking._ht_this.get() @@ -486,12 +521,12 @@ cdef class Hashgraph(Hashtable): list; used in graph contraction.''' cdef HashSet hdns = HashSet(self.ksize()) _sequence = self._valid_sequence(sequence) - deref(self._hg_this).find_high_degree_nodes(_sequence, + deref(self._hg_this).find_high_degree_nodes(_sequence, hdns.hs) return hdns - def traverse_linear_path(self, object kmer, HashSet hdns, + def traverse_linear_path(self, object kmer, HashSet hdns, Nodegraph stop_filter=None): '''Traverse the path through the graph starting with the given k-mer and avoiding high-degree nodes, finding (and returning) @@ -539,7 +574,7 @@ cdef class Hashgraph(Hashtable): cdef HashSet hs = HashSet(self.ksize()) deref(self._hg_this).get_tags_for_sequence(_sequence, hs.hs) return hs - + def find_all_tags_list(self, object kmer): '''Find all tags within range of the given k-mer, return as list''' cdef CpKmer _kmer = self._build_kmer(kmer) @@ -548,7 +583,7 @@ cdef class Hashgraph(Hashtable): cdef shared_ptr[CpHashgraph] this = self._hg_this with nogil: - deref(deref(self._hg_this).partition).find_all_tags(_kmer, deref(tags), + 
deref(deref(self._hg_this).partition).find_all_tags(_kmer, deref(tags), deref(this).all_tags) return result @@ -564,16 +599,16 @@ cdef class Hashgraph(Hashtable): total_reads, n_consumed) return total_reads, n_consumed - + def print_tagset(self, str filename): '''Print out all of the tags.''' deref(self._hg_this).print_tagset(_bstring(filename)) - + def add_tag(self, object kmer): '''Add a k-mer to the tagset.''' cdef HashIntoType _kmer = self.sanitize_hash_kmer(kmer) deref(self._hg_this).add_tag(_kmer) - + def get_tagset(self): '''Get all tagged k-mers as DNA strings.''' cdef HashIntoType st @@ -591,16 +626,16 @@ cdef class Hashgraph(Hashtable): def load_tagset(self, str filename, clear_tags=True): '''Load tags from a file.''' deref(self._hg_this).load_tagset(_bstring(filename), clear_tags) - + def save_tagset(self, str filename): '''Save tags to a file.''' deref(self._hg_this).save_tagset(_bstring(filename)) - + @property def n_tags(self): '''Return the count of all tags.''' return deref(self._hg_this).n_tags() - + def divide_tags_into_subsets(self, int subset_size=0): '''Divide tags equally up into subsets of given size.''' cdef set[HashIntoType] divvy @@ -608,12 +643,12 @@ cdef class Hashgraph(Hashtable): cdef HashSet hs = HashSet(self.ksize()) hs.hs = divvy return hs - + @property def tag_density(self): '''Get the tagging density.''' return deref(self._hg_this)._get_tag_density() - + @tag_density.setter def tag_density(self, int density): '''Set the tagging density.''' @@ -630,7 +665,7 @@ cdef class Hashgraph(Hashtable): cdef HashIntoType end = self.sanitize_hash_kmer(end_kmer) cdef bool cbreak = break_on_stoptags cdef bool cstop = stop_big_traversals - + with nogil: deref(subset_ptr).do_partition(start, end, cbreak, cstop) @@ -650,7 +685,7 @@ cdef class Hashgraph(Hashtable): return ppi - + def assign_partition_id(self, PrePartitionInfo ppi): '''Assign a partition ID to a given tag.''' cdef cp_pre_partition_info * cppi = ppi._this.get() @@ -658,7 +693,7 @@ 
cdef class Hashgraph(Hashtable): pi = deref(deref(self._hg_this).partition).assign_partition_id(deref(cppi).kmer, deref(cppi).tagged_kmers) return pi - + def output_partitions(self, str filename, str output, bool output_unassigned=False): '''Write out sequences in given filename to another file, annotating ''' @@ -668,7 +703,7 @@ cdef class Hashgraph(Hashtable): _bstring(output), output_unassigned) return n_partitions - + def load_partitionmap(self, str filename): '''Load a partitionmap for the master subset.''' deref(deref(self._hg_this).partition).load_partitionmap(_bstring(filename)) @@ -676,12 +711,12 @@ cdef class Hashgraph(Hashtable): def save_partitionmap(self, str filename): '''Save a partitionmap for the master subset.''' deref(deref(self._hg_this).partition).save_partitionmap(_bstring(filename)) - + def _validate_partitionmap(self): '''Run internal validation checks.''' deref(deref(self._hg_this).partition)._validate_pmap() - - def consume_seqfile_and_tag_with_reads_parser(self, object read_parser): + + def consume_seqfile_and_tag_with_parser(self, object read_parser): '''Count all k-mers using the given reads parser''' cdef unsigned long long n_consumed = 0 cdef unsigned int total_reads = 0 @@ -693,7 +728,7 @@ cdef class Hashgraph(Hashtable): total_reads, n_consumed) return total_reads, n_consumed - + def consume_partitioned_fasta(self, filename): '''Count all k-mers in a given file''' cdef unsigned long long n_consumed = 0 @@ -703,7 +738,7 @@ cdef class Hashgraph(Hashtable): total_reads, n_consumed) return total_reads, n_consumed - + def merge_subset(self, SubsetPartition subset): '''Merge the given subset into this one.''' deref(deref(self._hg_this).partition).merge(subset._this.get()) @@ -711,11 +746,11 @@ cdef class Hashgraph(Hashtable): def merge_subset_from_disk(self, str filename): '''Merge the given subset (filename) into this one.''' deref(deref(self._hg_this).partition).merge_from_disk(_bstring(filename)) - + def count_partitions(self): 
'''Count the number of partitions in the master partitionmap.''' return self.partition.count_partitions() - + def set_partition_id(self, object kmer, PartitionID pid): '''Set the partition ID for this tag.''' cdef string start = self.sanitize_kmer(kmer) @@ -729,7 +764,7 @@ cdef class Hashgraph(Hashtable): '''Get the partition ID of this tag.''' cdef string _kmer = self.sanitize_kmer(kmer) return deref(deref(self._hg_this).partition).get_partition_id(_kmer) - + def repartition_largest_partition(self, Countgraph counts not None, unsigned int distance, unsigned int threshold, @@ -754,7 +789,7 @@ cdef class Hashgraph(Hashtable): def load_stop_tags(self, object filename, clear_tags=False): '''Load the set of stop tags.''' deref(self._hg_this).load_stop_tags(_bstring(filename), clear_tags) - + def save_stop_tags(self, object filename): '''Save the set of stop tags.''' deref(self._hg_this).save_stop_tags(_bstring(filename)) @@ -762,7 +797,7 @@ cdef class Hashgraph(Hashtable): def print_stop_tags(self, filename): '''Print out the set of stop tags.''' deref(self._hg_this).print_stop_tags(_bstring(filename)) - + def trim_on_stoptags(self, str sequence): '''Trim the reads on the given stop tags.''' cdef size_t trim_at @@ -776,7 +811,7 @@ cdef class Hashgraph(Hashtable): '''Add this k-mer as a stop tag.''' cdef HashIntoType _kmer = self.sanitize_hash_kmer(kmer) deref(self._hg_this).add_stop_tag(_kmer) - + def get_stop_tags(self): '''Return a DNA list of all of the stop tags.''' cdef HashIntoType st diff --git a/oxli/functions.py b/oxli/functions.py index e3608f66da..60f72f6b8a 100755 --- a/oxli/functions.py +++ b/oxli/functions.py @@ -50,9 +50,9 @@ def build_graph(ifilenames, graph, num_threads=1, tags=False): - tags: should there be tags """ if tags: - eat = graph.consume_seqfile_and_tag_with_reads_parser + eat = graph.consume_seqfile_and_tag_with_parser else: - eat = graph.consume_seqfile_with_reads_parser + eat = graph.consume_seqfile_with_parser for _, ifile in 
enumerate(ifilenames): rparser = khmer.ReadParser(ifile) diff --git a/sandbox/count-kmers-single.py b/sandbox/count-kmers-single.py index aca0d7be2c..5f37842fbe 100755 --- a/sandbox/count-kmers-single.py +++ b/sandbox/count-kmers-single.py @@ -103,7 +103,7 @@ def main(): for _ in range(args.threads): thread = \ threading.Thread( - target=countgraph.consume_seqfile_with_reads_parser, + target=countgraph.consume_seqfile_with_parser, args=(rparser, ) ) threads.append(thread) diff --git a/sandbox/optimal_args_hashbits.py b/sandbox/optimal_args_hashbits.py index 794ad26db5..ec39c8186f 100755 --- a/sandbox/optimal_args_hashbits.py +++ b/sandbox/optimal_args_hashbits.py @@ -81,7 +81,7 @@ def main(): file=sys.stderr) htable = khmer.new_nodegraph(args.ksize, args.max_tablesize, args.n_tables) - target_method = htable.consume_seqfile_with_reads_parser + target_method = htable.consume_seqfile_with_parser for _, filename in enumerate(filenames): rparser = khmer.ReadParser(filename) diff --git a/scripts/abundance-dist-single.py b/scripts/abundance-dist-single.py index 56278cbfa1..866fdd5a46 100755 --- a/scripts/abundance-dist-single.py +++ b/scripts/abundance-dist-single.py @@ -148,7 +148,7 @@ def main(): # pylint: disable=too-many-locals,too-many-branches for _ in range(args.threads): thread = \ threading.Thread( - target=countgraph.consume_seqfile_with_reads_parser, + target=countgraph.consume_seqfile_with_parser, args=(rparser, ) ) threads.append(thread) @@ -163,7 +163,7 @@ def main(): # pylint: disable=too-many-locals,too-many-branches abundance_lists = [] def __do_abundance_dist__(read_parser): - abundances = countgraph.abundance_distribution_with_reads_parser( + abundances = countgraph.abundance_distribution_with_parser( read_parser, tracking) abundance_lists.append(abundances) diff --git a/scripts/filter-abund-single.py b/scripts/filter-abund-single.py index a2e44c37c6..2dccec776c 100755 --- a/scripts/filter-abund-single.py +++ b/scripts/filter-abund-single.py @@ -141,7 
+141,7 @@ def main(): for _ in range(args.threads): cur_thread = \ threading.Thread( - target=graph.consume_seqfile_with_reads_parser, + target=graph.consume_seqfile_with_parser, args=(rparser, ) ) threads.append(cur_thread) diff --git a/scripts/load-into-counting.py b/scripts/load-into-counting.py index 8164f4e84a..89eddb0c24 100755 --- a/scripts/load-into-counting.py +++ b/scripts/load-into-counting.py @@ -149,7 +149,7 @@ def main(): for _ in range(args.threads): cur_thrd = \ threading.Thread( - target=countgraph.consume_seqfile_with_reads_parser, + target=countgraph.consume_seqfile_with_parser, args=(rparser, ) ) threads.append(cur_thrd) diff --git a/tests/test_countgraph.py b/tests/test_countgraph.py index 2c8409ca5a..d0fb88c178 100644 --- a/tests/test_countgraph.py +++ b/tests/test_countgraph.py @@ -1188,16 +1188,16 @@ def test_consume_absentfasta(): print(str(err)) -def test_consume_absentfasta_with_reads_parser(): +def test_consume_absentfasta_with_parser(): countgraph = khmer.Countgraph(4, 4 ** 4, 4) try: - countgraph.consume_seqfile_with_reads_parser() + countgraph.consume_seqfile_with_parser() assert 0, "this should fail" except TypeError as err: print(str(err)) try: readparser = ReadParser(utils.get_test_data('empty-file')) - countgraph.consume_seqfile_with_reads_parser(readparser) + countgraph.consume_seqfile_with_parser(readparser) assert 0, "this should fail" except OSError as err: print(str(err)) diff --git a/tests/test_counttable.py b/tests/test_counttable.py index 15b7808a0b..3a833f4ff6 100644 --- a/tests/test_counttable.py +++ b/tests/test_counttable.py @@ -174,3 +174,31 @@ def test_consume_with_mask_threshold(): assert ct.get('ATTTGAGAAAAAA') == 1 assert ct.get('TTTGAGAAAAAAG') == 1 assert ct.get('TTGAGAAAAAAGT') == 1 + + +def test_consume_with_all_the_parsers(): + """Test "_with_parser" variant of "consume_seqfile" methods.""" + maskfile = utils.get_test_data('seq-a.fa') + mask = khmer.Counttable(13, 1e3, 4) + mask.consume_seqfile(maskfile) + + 
infile = utils.get_test_data('seq-b.fa') + ct = khmer.Counttable(13, 1e3, 4) + parser = khmer.ReadParser(infile) + nr, nk = ct.consume_seqfile_with_mask_with_parser(parser, mask) + + assert nr == 1 + assert nk == 3 + assert ct.get('GATTTGAGAAAAA') == 0 # in the mask + assert ct.get('ATTTGAGAAAAAA') == 1 + + ct = khmer.Counttable(13, 1e3, 4) + parser = khmer.ReadParser(infile) + nr, nk = ct.consume_seqfile_banding_with_mask_with_parser(parser, 4, 1, mask) + + assert nr == 1 + assert nk == 1 + assert ct.get('GATTTGAGAAAAA') == 0 # in the mask + assert ct.get('ATTTGAGAAAAAA') == 0 # out of band + assert ct.get('TTTGAGAAAAAAG') == 0 # out of band + assert ct.get('TTGAGAAAAAAGT') == 1 diff --git a/tests/test_nodegraph.py b/tests/test_nodegraph.py index 249f901acf..f5b871ba59 100644 --- a/tests/test_nodegraph.py +++ b/tests/test_nodegraph.py @@ -907,16 +907,16 @@ def test_bad_primes_list(): print(str(e)) -def test_consume_absentfasta_with_reads_parser(): +def test_consume_absentfasta_with_parser(): nodegraph = khmer.Nodegraph(31, 1, 1) try: - nodegraph.consume_seqfile_with_reads_parser() + nodegraph.consume_seqfile_with_parser() assert 0, "this should fail" except TypeError as err: print(str(err)) try: readparser = ReadParser(utils.get_test_data('empty-file')) - nodegraph.consume_seqfile_with_reads_parser(readparser) + nodegraph.consume_seqfile_with_parser(readparser) assert 0, "this should fail" except OSError as err: print(str(err)) @@ -936,7 +936,7 @@ def test_consume_seqfile_and_tag_with_badreads_parser(): nodegraph = khmer.Nodegraph(6, 1e6, 2) try: readsparser = khmer.ReadParser(utils.get_test_data("test-empty.fa")) - nodegraph.consume_seqfile_and_tag_with_reads_parser(readsparser) + nodegraph.consume_seqfile_and_tag_with_parser(readsparser) assert 0, "this should fail" except OSError as e: print(str(e)) diff --git a/tests/test_tabletype.py b/tests/test_tabletype.py index d67635d516..277a4d9d8a 100644 --- a/tests/test_tabletype.py +++ b/tests/test_tabletype.py @@ -376,7 +376,7 
@@ def test_consume_seqfile_reads_parser(AnyTabletype): kh = AnyTabletype(5) rparser = ReadParser(utils.get_test_data('test-fastq-reads.fq')) - kh.consume_seqfile_with_reads_parser(rparser) + kh.consume_seqfile_with_parser(rparser) kh2 = AnyTabletype(5) for record in screed.open(utils.get_test_data('test-fastq-reads.fq')): @@ -460,7 +460,7 @@ def test_abund_dist_A_readparser(AnyTabletype): tracking = Nodegraph(4, 1, 1, primes=PRIMES_1m) kh.consume_seqfile(A_filename) - dist = kh.abundance_distribution_with_reads_parser(rparser, tracking) + dist = kh.abundance_distribution_with_parser(rparser, tracking) print(dist[:10]) assert sum(dist) == 1