Skip to content

Commit 1370fac

Browse files
committed
Get score from game progression if available
1 parent c07c582 commit 1370fac

File tree

2 files changed

+58
-34
lines changed

2 files changed

+58
-34
lines changed

textworld/envs/glulx/git_glulx_ml.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _detect_i7_events_debug_tags(text: str) -> Tuple[List[str], str]:
8989
"""
9090
matches = []
9191
open_tags = []
92-
for match in re.findall("\[[^]]+\]\n?", text):
92+
for match in re.findall(r"\[[^]]+\]\n?", text):
9393
text = text.replace(match, "") # Remove i7 debug tags.
9494
tag_name = match.strip()[1:-1] # Strip starting '[' and trailing ']'.
9595

@@ -127,8 +127,6 @@ def __init__(self, *args, **kwargs):
127127
:param kwargs: The kwargs
128128
"""
129129
super().__init__(*args, **kwargs)
130-
self._has_won = False
131-
self._has_lost = False
132130
self.has_timeout = False
133131
self._state_tracking = False
134132
self._compute_intermediate_reward = False
@@ -153,7 +151,7 @@ def init(self, output: str, game: Game,
153151
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0
154152
self._objective = game.objective
155153
self._score = 0
156-
self._max_score = sum(quest.reward for quest in game.quests)
154+
self._max_score = self._game_progression.max_score
157155

158156
def view(self) -> "GlulxGameState":
159157
"""
@@ -218,12 +216,6 @@ def update(self, command: str, output: str) -> "GlulxGameState":
218216
# An action that affects the state of the game.
219217
game_state._game_progression.update(game_state._action)
220218

221-
if game_state._compute_intermediate_reward:
222-
if game_state._game_progression.winning_policy is None:
223-
game_state._has_lost = True
224-
elif len(game_state._game_progression.winning_policy) == 0:
225-
game_state._has_won = True
226-
227219
return game_state
228220

229221
@property
@@ -321,18 +313,22 @@ def intermediate_reward(self):
321313
@property
322314
def score(self):
323315
if not hasattr(self, "_score"):
324-
# Check if there was any Inform7 events.
325-
if self._feedback == self._raw:
326-
self._score = self.previous_state.score
316+
if self._state_tracking:
317+
self._score = self._game_progression.score
327318
else:
328-
output = self._raw
329-
if not self.game_ended:
330-
output = self._env._send("score")
331319

332-
match = re.search("scored (?P<score>[0-9]+) out of a possible (?P<max_score>[0-9]+),", output)
333-
self._score = 0
334-
if match:
335-
self._score = int(match.groupdict()["score"])
320+
# Check if there was any Inform7 events.
321+
if self._feedback == self._raw:
322+
self._score = self.previous_state.score
323+
else:
324+
output = self._raw
325+
if not self.game_ended:
326+
output = self._env._send("score")
327+
328+
match = re.search("scored (?P<score>[0-9]+) out of a possible (?P<max_score>[0-9]+),", output)
329+
self._score = 0
330+
if match:
331+
self._score = int(match.groupdict()["score"])
336332

337333
return self._score
338334

@@ -342,11 +338,23 @@ def max_score(self):
342338

343339
@property
344340
def has_won(self):
345-
return self._has_won or '*** The End ***' in self.feedback
341+
if not hasattr(self, "_has_won"):
342+
if self._compute_intermediate_reward:
343+
self._has_won = self._game_progression.completed
344+
else:
345+
self._has_won = '*** The End ***' in self.feedback
346+
347+
return self._has_won
346348

347349
@property
348350
def has_lost(self):
349-
return self._has_lost or '*** You lost! ***' in self.feedback
351+
if not hasattr(self, "_has_lost"):
352+
if self._compute_intermediate_reward:
353+
self._has_lost = self._game_progression.failed
354+
else:
355+
self._has_lost = '*** You lost! ***' in self.feedback
356+
357+
return self._has_lost
350358

351359
@property
352360
def game_ended(self) -> bool:

textworld/generator/game.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def __init__(self, quest: Quest) -> None:
564564
Args:
565565
quest: The quest to keep track of its completion.
566566
"""
567-
self._quest = quest
567+
self.quest = quest
568568
self._completed = False
569569
self._failed = False
570570
self._unfinishable = False
@@ -617,12 +617,12 @@ def update(self, action: Optional[Action] = None, state: Optional[State] = None)
617617

618618
if state is not None:
619619
# Check if quest is completed.
620-
if self._quest.win_action is not None:
621-
self._completed = state.is_applicable(self._quest.win_action)
620+
if self.quest.win_action is not None:
621+
self._completed = state.is_applicable(self.quest.win_action)
622622

623623
# Check if quest has failed.
624-
if self._quest.fail_action is not None:
625-
self._failed = state.is_applicable(self._quest.fail_action)
624+
if self.quest.fail_action is not None:
625+
self._failed = state.is_applicable(self.quest.fail_action)
626626

627627
# Try compressing the winning policy given the new game state.
628628
if self.compress_winning_policy(state):
@@ -696,17 +696,33 @@ def __init__(self, game: Game, track_quests: bool = True) -> None:
696696
@property
697697
def done(self) -> bool:
698698
""" Whether all quests are completed or at least one has failed or is unfinishable. """
699+
return self.completed or self.failed
700+
701+
@property
702+
def completed(self) -> bool:
703+
""" Whether all quests are completed. """
699704
if not self.tracking_quests:
700-
return False # There is nothing to be "done".
705+
return False # There is nothing to be "completed".
701706

702-
all_completed = True
703-
for quest_progression in self.quest_progressions:
704-
if quest_progression.failed or quest_progression.unfinishable:
705-
return True
707+
return all(qp.completed for qp in self.quest_progressions)
706708

707-
all_completed &= quest_progression.completed
709+
@property
710+
def failed(self) -> bool:
711+
""" Whether at least one quest has failed or is unfinishable. """
712+
if not self.tracking_quests:
713+
return False # There is nothing to be "failed".
714+
715+
return any((qp.failed or qp.unfinishable) for qp in self.quest_progressions)
708716

709-
return all_completed
717+
@property
718+
def score(self) -> int:
719+
""" Sum of the reward of all completed quests. """
720+
return sum(qp.quest.reward for qp in self.quest_progressions if qp.completed)
721+
722+
@property
723+
def max_score(self) -> int:
724+
""" Sum of the reward of all quests. """
725+
return sum(quest.reward for quest in self.game.quests)
710726

711727
@property
712728
def tracking_quests(self) -> bool:

0 commit comments

Comments
 (0)