Skip to content
38 changes: 36 additions & 2 deletions src/poke_env/environment/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
from pettingzoo.utils.env import ParallelEnv # type: ignore[import-untyped]

from poke_env.battle.abstract_battle import AbstractBattle
from poke_env.battle.battle import Battle
from poke_env.battle.double_battle import DoubleBattle
from poke_env.battle.pokemon import Pokemon
from poke_env.concurrency import POKE_LOOP, create_in_poke_loop
from poke_env.player.battle_order import (
BattleOrder,
DoubleBattleOrder,
ForfeitBattleOrder,
_EmptyBattleOrder,
)
Expand Down Expand Up @@ -86,16 +90,46 @@ def __init__(self, *args: Any, **kwargs: Any):
self.battle: Optional[AbstractBattle] = None

def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]:
return self._env_move(battle)
return self._choose_move(battle)

async def _env_move(self, battle: AbstractBattle) -> BattleOrder:
async def _choose_move(self, battle: AbstractBattle) -> BattleOrder:
if not self.battle or self.battle.finished:
self.battle = battle
assert self.battle.battle_tag == battle.battle_tag
await self.battle_queue.async_put(battle)
order = await self.order_queue.async_get()
return order

def teampreview(self, battle: AbstractBattle) -> Awaitable[str]:
return self._teampreview(battle)

async def _teampreview(self, battle: AbstractBattle) -> str:
if isinstance(battle, Battle):
return self.random_teampreview(battle)
elif isinstance(battle, DoubleBattle):
if battle.format is None or "vgc" not in battle.format:
return self.random_teampreview(battle)
species = [p.base_species for p in battle.team.values()]
order1 = await self._choose_move(battle)
if isinstance(order1, (ForfeitBattleOrder, _EmptyBattleOrder)):
return order1.message
assert isinstance(order1, DoubleBattleOrder)
assert isinstance(order1.first_order.order, Pokemon)
assert isinstance(order1.second_order.order, Pokemon)
action1 = species.index(order1.first_order.order.base_species) + 1
action2 = species.index(order1.second_order.order.base_species) + 1
order2 = await self._choose_move(battle)
if isinstance(order2, (ForfeitBattleOrder, _EmptyBattleOrder)):
return order2.message
assert isinstance(order2, DoubleBattleOrder)
assert isinstance(order2.first_order.order, Pokemon)
assert isinstance(order2.second_order.order, Pokemon)
action3 = species.index(order2.first_order.order.base_species) + 1
action4 = species.index(order2.second_order.order.base_species) + 1
return f"/team {action1}{action2}{action3}{action4}"
else:
raise TypeError()

def _battle_finished_callback(self, battle: AbstractBattle):
asyncio.run_coroutine_threadsafe(self.battle_queue.async_put(battle), POKE_LOOP)

Expand Down
7 changes: 5 additions & 2 deletions src/poke_env/player/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,10 @@ async def _handle_battle_request(
):
message = self.choose_default_move().message
elif battle.teampreview:
message = self.teampreview(battle)
m = self.teampreview(battle)
if isinstance(m, Awaitable):
m = await m
message = m
else:
if maybe_default_order:
self._trying_again.set()
Expand Down Expand Up @@ -680,7 +683,7 @@ def reset_battles(self):
)
self._battles = {}

def teampreview(self, battle: AbstractBattle) -> str:
def teampreview(self, battle: AbstractBattle) -> Union[str, Awaitable[str]]:
"""Returns a teampreview order for the given battle.

This order must be of the form /team TEAM, where TEAM is a string defining the
Expand Down
2 changes: 1 addition & 1 deletion unit_tests/player/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def embed_battle(battle):
player = _EnvPlayer(start_listening=False)
battle = Battle("bat1", player.username, player.logger, gen=8)
player.order_queue.put(ForfeitBattleOrder())
order = asyncio.get_event_loop().run_until_complete(player._env_move(battle))
order = asyncio.get_event_loop().run_until_complete(player._choose_move(battle))
assert isinstance(order, ForfeitBattleOrder)
assert embed_battle(player.battle_queue.get()) == "battle"

Expand Down
Loading