From 46e34c247d6d826ca52d97d619cfc79243b86c9f Mon Sep 17 00:00:00 2001 From: AroneyS Date: Wed, 2 Oct 2024 10:50:07 +1000 Subject: [PATCH] fix suffixes not being removed in preclusters --- binchicken/workflow/scripts/target_elusive.py | 18 +++++ test/test_target_elusive.py | 81 ++++++++++++++++++- 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/binchicken/workflow/scripts/target_elusive.py b/binchicken/workflow/scripts/target_elusive.py index 07f16055..66da09b4 100755 --- a/binchicken/workflow/scripts/target_elusive.py +++ b/binchicken/workflow/scripts/target_elusive.py @@ -20,6 +20,7 @@ def get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=2, MAX_COASSEMBLY_SAMPLES=2): logging.info(f"Polars using {str(pl.thread_pool_size())} threads") @@ -31,6 +32,18 @@ def get_clusters( # Set to 2 to produce paired edges MAX_COASSEMBLY_SAMPLES = 2 + sample_distances = ( + sample_distances + .with_columns( + pl.when(pl.col("query_name").is_in(samples)) + .then(pl.col("query_name")) + .otherwise(pl.col("query_name").str.replace(SUFFIX_RE, "")), + pl.when(pl.col("match_name").is_in(samples)) + .then(pl.col("match_name")) + .otherwise(pl.col("match_name").str.replace(SUFFIX_RE, "")), + ) + ) + logging.info("Converting to sparse array") samples = np.unique(np.concatenate([ sample_distances.select("query_name").to_numpy().flatten(), @@ -185,6 +198,8 @@ def process_chunk(df): processed_chunk = process_chunk(chunk) processed_chunk.write_csv(f, separator="\t", include_header=i==0) + logging.info("Done") + return def pipeline( @@ -325,6 +340,7 @@ def process_groups(df): ) sample_preclusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=PRECLUSTER_SIZE, MAX_COASSEMBLY_SAMPLES=MAX_COASSEMBLY_SAMPLES, ) @@ -349,3 +365,5 @@ def process_groups(df): ) targets.write_csv(targets_path, separator="\t") edges.sort("style", "cluster_size", "samples").write_csv(edges_path, separator="\t") + + logging.info("Done") diff --git a/test/test_target_elusive.py b/test/test_target_elusive.py index 3c432b65..ae515348 100644 --- a/test/test_target_elusive.py +++ b/test/test_target_elusive.py @@ -56,13 +56,62 @@ def test_get_clusters(self): ["sample_1", "sample_3", 1-1], ["sample_2", "sample_3", 1-0.9], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["sample_1", "sample_2", "sample_3"]) + + expected_clusters = pl.DataFrame([ + ["sample_1,sample_2"], + ["sample_2,sample_3"], + ], orient="row", schema=CLUSTERS_COLUMNS) + + observed_clusters = get_clusters(sample_distances, samples) + self.assertDataFrameEqual(expected_clusters, observed_clusters) + + def test_get_clusters_suffix(self): + sample_distances = pl.DataFrame([ + ["sample_1.1", "sample_2.1", 1-0.5], + ["sample_1.1", "sample_3.1", 1-1], + ["sample_2.1", "sample_3.1", 1-0.9], + ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["sample_1", "sample_2", "sample_3"]) expected_clusters = pl.DataFrame([ ["sample_1,sample_2"], ["sample_2,sample_3"], ], orient="row", schema=CLUSTERS_COLUMNS) - observed_clusters = get_clusters(sample_distances) + observed_clusters = get_clusters(sample_distances, samples) + self.assertDataFrameEqual(expected_clusters, observed_clusters) + + def test_get_clusters_suffix_underscore(self): + sample_distances = pl.DataFrame([ + ["sample_1_1", "sample_2_1", 1-0.5], + ["sample_1_1", "sample_3_1", 1-1], + ["sample_2_1", "sample_3_1", 1-0.9], + ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["sample_1", "sample_2", "sample_3"]) + + expected_clusters = pl.DataFrame([ + ["sample_1,sample_2"], + ["sample_2,sample_3"], + ], orient="row", schema=CLUSTERS_COLUMNS) + + observed_clusters = get_clusters(sample_distances, samples) + self.assertDataFrameEqual(expected_clusters, observed_clusters) + + def test_get_clusters_suffix_R(self): + sample_distances = pl.DataFrame([ + ["sample_1_R1", "sample_2_R1", 1-0.5], + ["sample_1_R1", "sample_3_R1", 1-1], + ["sample_2_R1", "sample_3_R1", 1-0.9], + ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["sample_1", "sample_2", "sample_3"]) + + expected_clusters = pl.DataFrame([ + ["sample_1,sample_2"], + ["sample_2,sample_3"], + ], orient="row", schema=CLUSTERS_COLUMNS) + + observed_clusters = get_clusters(sample_distances, samples) self.assertDataFrameEqual(expected_clusters, observed_clusters) def test_get_clusters_size_three(self): @@ -71,6 +120,7 @@ def test_get_clusters_size_three(self): ["1", "3", 1-1], ["2", "3", 1-0.9], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -81,6 +131,7 @@ def test_get_clusters_size_three(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=3, MAX_COASSEMBLY_SAMPLES=3, ) @@ -95,6 +146,7 @@ def test_get_clusters_size_three_of_four(self): ["2", "4", 1-0.4], ["3", "4", 1-1], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3", "4"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -111,6 +163,7 @@ def test_get_clusters_size_three_of_four(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=3, MAX_COASSEMBLY_SAMPLES=3, ) @@ -122,6 +175,7 @@ def test_get_clusters_size_three_of_four_missing(self): ["1", "3", 1-0.2], ["2", "3", 1-0.2], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -138,6 +192,7 @@ def test_get_clusters_size_three_of_four_missing(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=3, MAX_COASSEMBLY_SAMPLES=3, ) @@ -152,6 +207,7 @@ def test_get_clusters_size_three_of_four_balanced(self): ["2", "4", 1-0.2], ["3", "4", 1-0.1], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3", "4"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -168,6 +224,7 @@ def test_get_clusters_size_three_of_four_balanced(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=3, MAX_COASSEMBLY_SAMPLES=3, ) @@ -181,6 +238,7 @@ def test_get_clusters_size_three_of_four_balanced_missing(self): ["2", "4", 1-0.2], ["3", "4", 1-0.1], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3", "4"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -197,6 +255,7 @@ def test_get_clusters_size_three_of_four_balanced_missing(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=3, MAX_COASSEMBLY_SAMPLES=3, ) @@ -215,6 +274,7 @@ def test_get_clusters_size_four_of_five(self): ["3", "5", 1-0.3], ["4", "5", 1-0.1], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3", "4", "5"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -241,6 +301,7 @@ def test_get_clusters_size_four_of_five(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=4, MAX_COASSEMBLY_SAMPLES=3, ) @@ -258,6 +319,7 @@ def test_get_clusters_size_four_of_five_missing(self): ["3", "5", 1-0.3], ["4", "5", 1-0.1], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["1", "2", "3", "4", "5"]) expected_clusters = pl.DataFrame([ ["1,2"], @@ -284,6 +346,7 @@ def test_get_clusters_size_four_of_five_missing(self): observed_clusters = get_clusters( sample_distances, + samples, PRECLUSTER_SIZE=4, MAX_COASSEMBLY_SAMPLES=3, ) @@ -301,6 +364,17 @@ def test_get_clusters_real_world(self): ["SRR4249921", "SRR5207344", 1-0.15], ["SRR6979552", "SRR6980357", 1-0.25], ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set([ + "SRR12149290", "SRR10571243", + "ERR3201415", "ERR3220216", + "ERR1414209", "ERR4804028", + "SRR6979552", "SRR15213103", + "SRR12149290", "SRR12352217", + "SRR6979552", "SRR4831657", + "ERR3201415", "SRR11784293", + "SRR4249921", "SRR5207344", + "SRR6979552", "SRR6980357", + ]) expected_clusters = pl.DataFrame([ ["SRR10571243,SRR12149290"], @@ -314,17 +388,18 @@ def test_get_clusters_real_world(self): ["SRR6979552,SRR6980357"], ], orient="row", schema=CLUSTERS_COLUMNS) - observed_clusters = get_clusters(sample_distances) + observed_clusters = get_clusters(sample_distances, samples) self.assertDataFrameEqual(expected_clusters, observed_clusters) def test_get_clusters_empty_inputs(self): sample_distances = pl.DataFrame([ ], orient="row", schema=SAMPLE_DISTANCES_COLUMNS) + samples = set(["sample_1", "sample_2"]) expected_clusters = pl.DataFrame([ ], orient="row", schema=CLUSTERS_COLUMNS) - observed_clusters = get_clusters(sample_distances) + observed_clusters = get_clusters(sample_distances, samples) self.assertDataFrameEqual(expected_clusters, observed_clusters) def test_target_elusive(self):