diff --git a/src/poke_env/environment/env.py b/src/poke_env/environment/env.py index 8c9f6535a..082a2b3e7 100644 --- a/src/poke_env/environment/env.py +++ b/src/poke_env/environment/env.py @@ -15,9 +15,13 @@ from pettingzoo.utils.env import ParallelEnv # type: ignore[import-untyped] from poke_env.battle.abstract_battle import AbstractBattle +from poke_env.battle.battle import Battle +from poke_env.battle.double_battle import DoubleBattle +from poke_env.battle.pokemon import Pokemon from poke_env.concurrency import POKE_LOOP, create_in_poke_loop from poke_env.player.battle_order import ( BattleOrder, + DoubleBattleOrder, ForfeitBattleOrder, _EmptyBattleOrder, ) @@ -86,9 +90,9 @@ def __init__(self, *args: Any, **kwargs: Any): self.battle: Optional[AbstractBattle] = None def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]: - return self._env_move(battle) + return self._choose_move(battle) - async def _env_move(self, battle: AbstractBattle) -> BattleOrder: + async def _choose_move(self, battle: AbstractBattle) -> BattleOrder: if not self.battle or self.battle.finished: self.battle = battle assert self.battle.battle_tag == battle.battle_tag @@ -96,6 +100,36 @@ async def _env_move(self, battle: AbstractBattle) -> BattleOrder: order = await self.order_queue.async_get() return order + def teampreview(self, battle: AbstractBattle) -> Awaitable[str]: + return self._teampreview(battle) + + async def _teampreview(self, battle: AbstractBattle) -> str: + if isinstance(battle, Battle): + return self.random_teampreview(battle) + elif isinstance(battle, DoubleBattle): + if battle.format is None or "vgc" not in battle.format: + return self.random_teampreview(battle) + species = [p.base_species for p in battle.team.values()] + order1 = await self._choose_move(battle) + if isinstance(order1, (ForfeitBattleOrder, _EmptyBattleOrder)): + return order1.message + assert isinstance(order1, DoubleBattleOrder) + assert isinstance(order1.first_order.order, Pokemon) + assert isinstance(order1.second_order.order, Pokemon) + action1 = species.index(order1.first_order.order.base_species) + 1 + action2 = species.index(order1.second_order.order.base_species) + 1 + order2 = await self._choose_move(battle) + if isinstance(order2, (ForfeitBattleOrder, _EmptyBattleOrder)): + return order2.message + assert isinstance(order2, DoubleBattleOrder) + assert isinstance(order2.first_order.order, Pokemon) + assert isinstance(order2.second_order.order, Pokemon) + action3 = species.index(order2.first_order.order.base_species) + 1 + action4 = species.index(order2.second_order.order.base_species) + 1 + return f"/team {action1}{action2}{action3}{action4}" + else: + raise TypeError() + def _battle_finished_callback(self, battle: AbstractBattle): asyncio.run_coroutine_threadsafe(self.battle_queue.async_put(battle), POKE_LOOP) diff --git a/src/poke_env/player/player.py b/src/poke_env/player/player.py index 5515e03e9..2bebd2902 100644 --- a/src/poke_env/player/player.py +++ b/src/poke_env/player/player.py @@ -412,7 +412,10 @@ async def _handle_battle_request( ): message = self.choose_default_move().message elif battle.teampreview: - message = self.teampreview(battle) + m = self.teampreview(battle) + if isinstance(m, Awaitable): + m = await m + message = m else: if maybe_default_order: self._trying_again.set() @@ -680,7 +683,7 @@ def reset_battles(self): ) self._battles = {} - def teampreview(self, battle: AbstractBattle) -> str: + def teampreview(self, battle: AbstractBattle) -> Union[str, Awaitable[str]]: """Returns a teampreview order for the given battle. This order must be of the form /team TEAM, where TEAM is a string defining the diff --git a/unit_tests/player/test_env.py b/unit_tests/player/test_env.py index c04db7ca6..ca2156aca 100644 --- a/unit_tests/player/test_env.py +++ b/unit_tests/player/test_env.py @@ -72,7 +72,7 @@ def embed_battle(battle): player = _EnvPlayer(start_listening=False) battle = Battle("bat1", player.username, player.logger, gen=8) player.order_queue.put(ForfeitBattleOrder()) - order = asyncio.get_event_loop().run_until_complete(player._env_move(battle)) + order = asyncio.get_event_loop().run_until_complete(player._choose_move(battle)) assert isinstance(order, ForfeitBattleOrder) assert embed_battle(player.battle_queue.get()) == "battle"