diff --git a/pathpy/path_extraction/temporal_paths.py b/pathpy/path_extraction/temporal_paths.py index 7a13cd8..ebf6c36 100644 --- a/pathpy/path_extraction/temporal_paths.py +++ b/pathpy/path_extraction/temporal_paths.py @@ -264,13 +264,11 @@ def generate_causal_tree(dag, root, node_map): queue.append((w, depth+1)) # only consider nodes that have not already # been added to this level - if not visited[node_map[w], depth+1]: - # add edge to causal tree - y = '{0}_{1}'.format(node_map[w], depth+1) - edges.append((x, y)) - - visited[node_map[w], depth+1] = True - causal_mapping[y] = node_map[w] + y = '{0}_{1}'.format(node_map[w], depth+1) + if not visited[x, y]: + # add edge to causal tree + edges.append((x, y)) + visited[x, y] = True # Adding all edges at once is more efficient! causal_tree.add_edges(edges) diff --git a/tests/test_DAG.py b/tests/test_DAG.py index 8ae10e9..175261f 100644 --- a/tests/test_DAG.py +++ b/tests/test_DAG.py @@ -240,6 +240,77 @@ def test_dag_from_temporal_network(): assert sorted(dag.routes_to_node('a_5')) == sorted([['c_2', 'a_5'], ['a_1', 'b_4', 'a_5']]) +def test_generate_causal_tree_diamond(): + """ + + (b,d,3)-(d,a,5) + / \ + (a,b,1) (a,e,6) + \ / + (b,c,2)-(c,a,4) + + ---------------------------------> t + + (d,2) + / \ + (a,0)-(b,1) (a,3)-(e,4) + \ / + (c,2) + + -------------------------------> depth + + """ + + tn = pp.TemporalNetwork() + tn.add_edge('a', 'b', 1) + tn.add_edge('b', 'c', 2) + tn.add_edge('b', 'd', 3) + tn.add_edge('c', 'a', 4) + tn.add_edge('d', 'a', 5) + tn.add_edge('a', 'e', 6) + + delta = 10 + + dag, node_map = pp.DAG.from_temporal_network(tn, delta) + root = list(dag.roots)[0] + + causal_tree, causal_mapping = pp.path_extraction.generate_causal_tree(dag, root, node_map) + assert set(causal_tree.edges.keys()) == set([('a_0', 'b_1'), ('b_1', 'd_2'), ('b_1', 'c_2'), ('d_2', 'a_3'), ('c_2', 'a_3'), ('a_3', 'e_4')]) + + +def test_generate_causal_tree_trapezium(): + """ + + (b,c,2)-(c,b,3) + / \ + (a,b,1)-----------------(b,c,4)-(c,d,5) + + ---------------------------------> t + + (b,3)-(c,4)-(d,5) + / + (a,0)-(b,1)-(c,2)-(d,3) + + -------------------------------> depth + + """ + + tn = pp.TemporalNetwork() + tn.add_edge('a', 'b', 1) + tn.add_edge('b', 'c', 2) + tn.add_edge('c', 'b', 3) + tn.add_edge('b', 'c', 4) + tn.add_edge('c', 'd', 5) + + delta = 10 + + dag, node_map = pp.DAG.from_temporal_network(tn, delta) + root = list(dag.roots)[0] + + causal_tree, causal_mapping = pp.path_extraction.generate_causal_tree(dag, root, node_map) + assert set(causal_tree.edges.keys()) == set([('a_0', 'b_1'), ('b_1', 'c_2'), ('c_2', 'd_3'), ('c_2', 'b_3'), ('b_3', 'c_4'), ('c_4', 'd_5')]) + + @pytest.mark.networkx def test_strong_connected_components(random_network): from pathpy.classes.network import network_to_networkx