Skip to content

Commit 3a25b39

Browse files
committed
Add an option to prune out excess information in the game state
1 parent eeabe29 commit 3a25b39

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

textworld/render/graph.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from textworld.logic import Proposition
77

88

9-
def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
9+
def build_graph_from_facts(facts: Iterable[Proposition], prune:Optional[bool] = False) -> nx.DiGraph:
1010
""" Builds a graph from a collection of facts.
1111
1212
Arguments:
@@ -18,6 +18,10 @@ def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
1818
G = nx.DiGraph()
1919
labels = {}
2020
for fact in facts:
21+
# Prune out facts that we don't want in our KB representation
22+
if prune and fact.name == 'free':
23+
continue
24+
2125
# Extract relation triplet from fact (subject, object, relation)
2226
triplet = (*fact.names, fact.name)
2327
triplet = triplet if len(triplet) >= 3 else triplet + ("is",)
@@ -41,7 +45,8 @@ def build_graph_from_facts(facts: Iterable[Proposition]) -> nx.DiGraph:
4145
def show_graph(facts: Iterable[Proposition],
4246
title: str = "Knowledge Graph",
4347
renderer:Optional[str] = None,
44-
save:Optional[str] = None) -> "plotly.graph_objs._figure.Figure":
48+
save:Optional[str] = None,
49+
prune:Optional[bool] = False) -> "plotly.graph_objs._figure.Figure":
4550

4651
r""" Visualizes the graph made from a collection of facts.
4752
@@ -80,7 +85,7 @@ def show_graph(facts: Iterable[Proposition],
8085
except:
8186
raise ImportError('Visualization dependencies not installed. Try running `pip install textworld[vis]`')
8287

83-
G = build_graph_from_facts(facts)
88+
G = build_graph_from_facts(facts, prune)
8489

8590
plt.figure(figsize=(16, 9))
8691
pos = nx.drawing.nx_pydot.pydot_layout(G, prog="fdp")

0 commit comments

Comments
 (0)