diff --git a/kaggle_environments/envs/open_spiel/open_spiel.py b/kaggle_environments/envs/open_spiel/open_spiel.py index 5a939c3d..4ed17bc4 100644 --- a/kaggle_environments/envs/open_spiel/open_spiel.py +++ b/kaggle_environments/envs/open_spiel/open_spiel.py @@ -421,8 +421,32 @@ def random_agent( return {"submission": int(action)} +def debug_agent( + observation: dict[str, Any], + configuration: dict[str, Any], + max_history_length: int = 8, +) -> int: + """A built-in random agent specifically for OpenSpiel environments.""" + del configuration + serialized_game_and_state = observation.get("serializedGameAndState") + if not serialized_game_and_state: + return None + game, state = pyspiel.deserialize_game_and_state(serialized_game_and_state) + if len(state.history()) >= max_history_length: + return { + "submission": pyspiel.INVALID_ACTION, + "status": "Max history length reached; intentionally submitting invalid action.", + } + legal_actions = observation.get("legalActions") + if not legal_actions: + return None + action = random.choice(legal_actions) + return {"submission": int(action)} + + AGENT_REGISTRY = { "random": random_agent, + "debug": debug_agent, } diff --git a/kaggle_environments/envs/open_spiel/test_open_spiel.py b/kaggle_environments/envs/open_spiel/test_open_spiel.py index 43b99cb4..024f0b81 100644 --- a/kaggle_environments/envs/open_spiel/test_open_spiel.py +++ b/kaggle_environments/envs/open_spiel/test_open_spiel.py @@ -1,5 +1,7 @@ -from absl.testing import absltest +import functools import sys + +from absl.testing import absltest from kaggle_environments import make import pyspiel from . import open_spiel as open_spiel_env @@ -91,6 +93,22 @@ def test_agent_error(self): self.assertEqual(json["rewards"], [None, None]) self.assertEqual(json["statuses"], ["ERROR", "ERROR"]) + def test_debug_agent(self): + env = make("open_spiel_chess", debug=True) + max_history_length = 5 + debug_agent = functools.partial( + open_spiel_env.debug_agent, + max_history_length=max_history_length, + ) + env.run([debug_agent, "random"]) + json = env.toJSON() + self.assertEqual(json["rewards"], [-1, 1]) + self.assertEqual(json["statuses"], ["DONE", "DONE"]) + self.assertEqual( + len(json["info"]["actionHistory"]), + max_history_length, + ) + if __name__ == '__main__': absltest.main() \ No newline at end of file