diff --git a/khmer/utils.py b/khmer/utils.py index d31841703d..f782326ff3 100644 --- a/khmer/utils.py +++ b/khmer/utils.py @@ -34,6 +34,8 @@ # Contact: khmer-project@idyll.org """Helpful methods for performing common argument-checking tasks in scripts.""" from __future__ import print_function, unicode_literals +import numbers +import random def print_error(msg): @@ -43,6 +45,15 @@ def print_error(msg): print(msg, file=sys.stderr) +def check_random_state(seed): + if seed is None or seed is random: + return random + if isinstance(seed, numbers.Integral): + return random.Random(seed) + if isinstance(seed, random.Random): + return seed + + def _split_left_right(name): """Split record name at the first whitespace and return both parts. diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..409d722e8a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import pytest + + +@pytest.fixture() +def seed(): + """Base seed used by all tests that require random numbers""" + return 1 diff --git a/tests/test_assembly.py b/tests/test_assembly.py index d3ba0a4d13..d30c35d3a2 100644 --- a/tests/test_assembly.py +++ b/tests/test_assembly.py @@ -39,12 +39,9 @@ from __future__ import print_function from __future__ import absolute_import -import itertools -import random - import khmer from khmer.khmer_args import estimate_optimal_with_K_and_f as optimal_fp -from khmer import ReadParser +from khmer.utils import check_random_state from khmer import reverse_complement as revcomp from . import khmer_tst_utils as utils @@ -55,6 +52,7 @@ def teardown(): utils.cleanup() + # We just define this globally rather than in a module-level fixture, # as we need it during parameterization and whatnot. K = 21 @@ -71,32 +69,35 @@ def __new__(cls, value, pos=0): return str.__new__(cls, value) -def mutate_base(base): +def mutate_base(base, rng=None): + rng = check_random_state(rng) if base in 'AT': - return random.choice('GC') + return rng.choice('GC') elif base in 'GC': - return random.choice('AT') + return rng.choice('AT') else: assert False, 'bad base' -def mutate_sequence(sequence, N=1): +def mutate_sequence(sequence, N=1, rng=None): + rng = check_random_state(rng) sequence = list(sequence) - positions = random.sample(range(len(sequence)), N) + positions = rng.sample(range(len(sequence)), N) for i in positions: - sequence[i] = mutate_base(sequence[i]) + sequence[i] = mutate_base(sequence[i], rng) return ''.join(sequence) -def mutate_position(sequence, pos): +def mutate_position(sequence, pos, rng=None): + rng = check_random_state(rng) sequence = list(sequence) - sequence[pos] = mutate_base(sequence[pos]) + sequence[pos] = mutate_base(sequence[pos], rng) return ''.join(sequence) -def get_random_sequence(length, exclude=None): +def get_random_sequence(length, exclude=None, rng=None): '''Generate a random (non-looping) nucleotide sequence. To be non-overlapping, the sequence should not include any repeated @@ -109,7 +110,7 @@ def get_random_sequence(length, exclude=None): Returns: str: A random non-looping sequence. ''' - + rng = check_random_state(rng) seen = set() def add_seen(kmer): @@ -120,11 +121,11 @@ def add_seen(kmer): for pos in range(0, len(exclude) - K): add_seen(exclude[pos:pos + K - 1]) - seq = [random.choice('ACGT') for _ in range(K - 1)] # do first K-1 bases + seq = [rng.choice('ACGT') for _ in range(K - 1)] # do first K-1 bases add_seen(''.join(seq)) while(len(seq) < length): - next_base = random.choice('ACGT') + next_base = rng.choice('ACGT') next_kmer = ''.join(seq[-K + 2:] + [next_base]) assert len(next_kmer) == K - 1 if (next_kmer) not in seen: @@ -135,10 +136,11 @@ def add_seen(kmer): return ''.join(seq) -def reads(sequence, L=100, N=100): +def reads(sequence, L=100, N=100, rng=None): + rng = check_random_state(rng) positions = list(range(len(sequence) - L)) for i in range(N): - start = random.choice(positions) + start = rng.choice(positions) yield sequence[start:start + L] @@ -147,22 +149,22 @@ def kmers(sequence): yield sequence[i:i + K] -def test_mutate_sequence(): - for _ in range(100): - assert 'A' not in mutate_sequence('A' * 10, 10) - assert 'T' not in mutate_sequence('T' * 10, 10) - assert 'C' not in mutate_sequence('C' * 10, 10) - assert 'G' not in mutate_sequence('G' * 10, 10) +def test_mutate_sequence(seed): + for i in range(100): + assert 'A' not in mutate_sequence('A' * 10, 10, rng=seed+i) + assert 'T' not in mutate_sequence('T' * 10, 10, rng=seed+i) + assert 'C' not in mutate_sequence('C' * 10, 10, rng=seed+i) + assert 'G' not in mutate_sequence('G' * 10, 10, rng=seed+i) -def test_mutate_position(): - assert mutate_position('AAAA', 2) in ['AACA', 'AAGA'] - assert mutate_position('TTTT', 2) in ['TTCT', 'TTGT'] - assert mutate_position('CCCC', 2) in ['CCAC', 'CCTC'] - assert mutate_position('GGGG', 2) in ['GGAG', 'GGTG'] +def test_mutate_position(seed): + assert mutate_position('AAAA', 2, rng=seed) in ['AACA', 'AAGA'] + assert mutate_position('TTTT', 2, rng=seed) in ['TTCT', 'TTGT'] + assert mutate_position('CCCC', 2, rng=seed) in ['CCAC', 'CCTC'] + assert mutate_position('GGGG', 2, rng=seed) in ['GGAG', 'GGTG'] -def test_reads(): +def test_reads(seed): contigfile = utils.get_test_data('simple-genome.fa') contig = list(screed.open(contigfile))[0].sequence @@ -170,7 +172,7 @@ def test_reads(): assert read in contig for read in reads(contig): - assert mutate_sequence(read) not in contig + assert mutate_sequence(read, rng=seed) not in contig ''' @@ -200,10 +202,10 @@ def known_sequence(request): @pytest.fixture(params=list(range(500, 1600, 500)), ids=lambda val: '(L={0})'.format(val)) -def random_sequence(request): - +def random_sequence(request, seed): def get(exclude=None): - return get_random_sequence(request.param, exclude=exclude) + return get_random_sequence(request.param, exclude=exclude, + rng=seed + request.param) return get @@ -338,7 +340,7 @@ def right_double_fork_structure(request, linear_structure, random_sequence): @pytest.fixture def right_triple_fork_structure(request, right_double_fork_structure, - random_sequence): + random_sequence, seed): ''' Sets up a graph structure like so: @@ -352,7 +354,7 @@ def right_triple_fork_structure(request, right_double_fork_structure, Where S is the start position of the high degreen node (HDN). ''' - + rng = check_random_state(seed) graph, core_sequence, L, HDN, R, top_sequence = right_double_fork_structure bottom_branch = random_sequence(exclude=core_sequence + top_sequence) print(len(core_sequence), len(top_sequence), len(bottom_branch)) @@ -360,7 +362,7 @@ def right_triple_fork_structure(request, right_double_fork_structure, # the branch sequence, mutated at position S+1 # choose a base not already represented at that position bases = {'A', 'C', 'G', 'T'} - mutated = random.choice(list(bases - {R[-1], top_sequence[R.pos + K - 1]})) + mutated = rng.choice(list(bases - {R[-1], top_sequence[R.pos + K - 1]})) bottom_sequence = core_sequence[:HDN.pos + K] + mutated + bottom_branch