66from textworld .logic import Proposition
77
88
9- def build_graph_from_facts (facts : Iterable [Proposition ]) -> nx .DiGraph :
9+ def build_graph_from_facts (facts : Iterable [Proposition ], prune : Optional [ bool ] = False ) -> nx .DiGraph :
1010 """ Builds a graph from a collection of facts.
1111
1212 Arguments:
@@ -18,6 +18,10 @@ def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
1818 G = nx .DiGraph ()
1919 labels = {}
2020 for fact in facts :
21+ # Prune out facts that we don't want in our KB representation
22+ if prune and fact .name == 'free' :
23+ continue
24+
2125 # Extract relation triplet from fact (subject, object, relation)
2226 triplet = (* fact .names , fact .name )
2327 triplet = triplet if len (triplet ) >= 3 else triplet + ("is" ,)
@@ -41,7 +45,8 @@ def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
4145def show_graph (facts : Iterable [Proposition ],
4246 title : str = "Knowledge Graph" ,
4347 renderer :Optional [str ] = None ,
44- save :Optional [str ] = None ) -> "plotly.graph_objs._figure.Figure" :
48+ save :Optional [str ] = None ,
49+ prune :Optional [bool ] = False ) -> "plotly.graph_objs._figure.Figure" :
4550
4651 r""" Visualizes the graph made from a collection of facts.
4752
@@ -80,7 +85,7 @@ def show_graph(facts: Iterable[Proposition],
8085 except :
8186 raise ImportError ('Visualization dependencies not installed. Try running `pip install textworld[vis]`' )
8287
83- G = build_graph_from_facts (facts )
88+ G = build_graph_from_facts (facts , prune )
8489
8590 plt .figure (figsize = (16 , 9 ))
8691 pos = nx .drawing .nx_pydot .pydot_layout (G , prog = "fdp" )
0 commit comments