77import importlib .util
88import io
99import pathlib
10- from typing import Dict , Optional
10+ from typing import Dict
1111
1212import torch
1313from PIL import Image
1414from tensordict import TensorDict , TensorDictBase
15- from torchrl .data import Bounded , Categorical , Composite , NonTensor , Unbounded
15+ from torchrl .data import Binary , Bounded , Categorical , Composite , NonTensor , Unbounded
1616
1717from torchrl .envs import EnvBase
1818from torchrl .envs .common import _EnvPostInit
1919
2020from torchrl .envs .utils import _classproperty
2121
2222
23- class _HashMeta (_EnvPostInit ):
23+ class _ChessMeta (_EnvPostInit ):
2424 def __call__ (cls , * args , ** kwargs ):
2525 instance = super ().__call__ (* args , ** kwargs )
2626 if kwargs .get ("include_hash" ):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
3737 if instance .include_pgn :
3838 in_keys .append ("pgn" )
3939 out_keys .append ("pgn_hash" )
40- return instance .append_transform (Hash (in_keys , out_keys ))
40+ instance = instance .append_transform (Hash (in_keys , out_keys ))
41+ if kwargs .get ("mask_actions" , True ):
42+ from torchrl .envs import ActionMask
43+
44+ instance = instance .append_transform (ActionMask ())
4145 return instance
4246
4347
44- class ChessEnv (EnvBase , metaclass = _HashMeta ):
48+ class ChessEnv (EnvBase , metaclass = _ChessMeta ):
4549 r"""A chess environment that follows the TorchRL API.
4650
4751 This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
6367 include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
6468 include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
6569 include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70+ mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
71+ to the env to make sure that the actions are properly masked. Default: ``True``.
6672 pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
6773
6874 .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -202,16 +208,15 @@ def _legal_moves_to_index(
202208 ) -> torch .Tensor :
203209 if not self .stateful :
204210 if tensordict is None :
205- raise RuntimeError (
206- "rand_action requires a tensordict when stateful is False."
207- )
208- if self .include_fen :
209- fen = self ._get_fen (tensordict )
211+ # trust the board
212+ pass
213+ elif self .include_fen :
214+ fen = tensordict .get ("fen" , None )
210215 fen = fen .data
211216 self .board .set_fen (fen )
212217 board = self .board
213218 elif self .include_pgn :
214- pgn = self . _get_pgn ( tensordict )
219+ pgn = tensordict . get ( "pgn" )
215220 pgn = pgn .data
216221 board = self ._pgn_to_board (pgn , self .board )
217222
@@ -224,15 +229,19 @@ def _legal_moves_to_index(
224229 )
225230
226231 if return_mask :
227- return torch .zeros (len (self .san_moves ), dtype = torch .bool ).index_fill_ (
228- 0 , indices , True
229- )
232+ return self ._move_index_to_mask (indices )
230233 if pad :
231234 indices = torch .nn .functional .pad (
232235 indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
233236 )
234237 return indices
235238
239+ @classmethod
240+ def _move_index_to_mask (cls , indices : torch .Tensor ) -> torch .Tensor :
241+ return torch .zeros (len (cls .san_moves ), dtype = torch .bool ).index_fill_ (
242+ 0 , indices , True
243+ )
244+
236245 def __init__ (
237246 self ,
238247 * ,
@@ -242,6 +251,7 @@ def __init__(
242251 include_pgn : bool = False ,
243252 include_legal_moves : bool = False ,
244253 include_hash : bool = False ,
254+ mask_actions : bool = True ,
245255 pixels : bool = False ,
246256 ):
247257 chess = self .lib
@@ -252,6 +262,7 @@ def __init__(
252262 self .include_san = include_san
253263 self .include_fen = include_fen
254264 self .include_pgn = include_pgn
265+ self .mask_actions = mask_actions
255266 self .include_legal_moves = include_legal_moves
256267 if include_legal_moves :
257268 # 218 max possible legal moves per chess board position
@@ -276,8 +287,10 @@ def __init__(
276287
277288 self .stateful = stateful
278289
279- if not self .stateful :
280- self .full_state_spec = self .full_observation_spec .clone ()
290+ # state_spec is loosely defined as such - it's not really an issue that extra keys
291+ # can go missing but it allows us to reset the env using fen passed to the reset
292+ # method.
293+ self .full_state_spec = self .full_observation_spec .clone ()
281294
282295 self .pixels = pixels
283296 if pixels :
@@ -297,16 +310,16 @@ def __init__(
297310 self .full_reward_spec = Composite (
298311 reward = Unbounded (shape = (1 ,), dtype = torch .float32 )
299312 )
313+ if self .mask_actions :
314+ self .full_observation_spec ["action_mask" ] = Binary (
315+ n = len (self .san_moves ), dtype = torch .bool
316+ )
317+
300318 # done spec generated automatically
301319 self .board = chess .Board ()
302320 if self .stateful :
303321 self .action_spec .set_provisional_n (len (list (self .board .legal_moves )))
304322
305- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
306- mask = self ._legal_moves_to_index (tensordict , return_mask = True )
307- self .action_spec .update_mask (mask )
308- return super ().rand_action (tensordict )
309-
310323 def _is_done (self , board ):
311324 return board .is_game_over () | board .is_fifty_moves ()
312325
@@ -316,11 +329,11 @@ def _reset(self, tensordict=None):
316329 if tensordict is not None :
317330 dest = tensordict .empty ()
318331 if self .include_fen :
319- fen = self . _get_fen ( tensordict )
332+ fen = tensordict . get ( "fen" , None )
320333 if fen is not None :
321334 fen = fen .data
322335 elif self .include_pgn :
323- pgn = self . _get_pgn ( tensordict )
336+ pgn = tensordict . get ( "pgn" , None )
324337 if pgn is not None :
325338 pgn = pgn .data
326339 else :
@@ -360,13 +373,18 @@ def _reset(self, tensordict=None):
360373 if self .include_legal_moves :
361374 moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
362375 dest .set ("legal_moves" , moves_idx )
376+ if self .mask_actions :
377+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
378+ elif self .mask_actions :
379+ dest .set (
380+ "action_mask" ,
381+ self ._legal_moves_to_index (
382+ board = self .board , pad = True , return_mask = True
383+ ),
384+ )
385+
363386 if self .pixels :
364387 dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
365-
366- if self .stateful :
367- mask = self ._legal_moves_to_index (dest , return_mask = True )
368- self .action_spec .update_mask (mask )
369-
370388 return dest
371389
372390 _cairosvg_lib = None
@@ -437,16 +455,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
437455 pgn_string = str (game )
438456 return pgn_string
439457
440- @classmethod
441- def _get_fen (cls , tensordict ):
442- fen = tensordict .get ("fen" , None )
443- return fen
444-
445- @classmethod
446- def _get_pgn (cls , tensordict ):
447- pgn = tensordict .get ("pgn" , None )
448- return pgn
449-
450458 def get_legal_moves (self , tensordict = None , uci = False ):
451459 """List the legal moves in a position.
452460
@@ -470,7 +478,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
470478 raise ValueError (
471479 "tensordict must be given since this env is not stateful"
472480 )
473- fen = self . _get_fen ( tensordict ).data
481+ fen = tensordict . get ( "fen" ).data
474482 board .set_fen (fen )
475483 moves = board .legal_moves
476484
@@ -488,10 +496,10 @@ def _step(self, tensordict):
488496 fen = None
489497 if not self .stateful :
490498 if self .include_fen :
491- fen = self . _get_fen ( tensordict ).data
499+ fen = tensordict . get ( "fen" ).data
492500 board .set_fen (fen )
493501 elif self .include_pgn :
494- pgn = self . _get_pgn ( tensordict ).data
502+ pgn = tensordict . get ( "pgn" ).data
495503 board = self ._pgn_to_board (pgn , board )
496504 else :
497505 raise RuntimeError (
@@ -521,6 +529,15 @@ def _step(self, tensordict):
521529 if self .include_legal_moves :
522530 moves_idx = self ._legal_moves_to_index (board = board , pad = True )
523531 dest .set ("legal_moves" , moves_idx )
532+ if self .mask_actions :
533+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
534+ elif self .mask_actions :
535+ dest .set (
536+ "action_mask" ,
537+ self ._legal_moves_to_index (
538+ board = self .board , pad = True , return_mask = True
539+ ),
540+ )
524541
525542 turn = torch .tensor (board .turn )
526543 done = self ._is_done (board )
@@ -540,11 +557,6 @@ def _step(self, tensordict):
540557 dest .set ("terminated" , [done ])
541558 if self .pixels :
542559 dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
543-
544- if self .stateful :
545- mask = self ._legal_moves_to_index (dest , return_mask = True )
546- self .action_spec .update_mask (mask )
547-
548560 return dest
549561
550562 def _set_seed (self , * args , ** kwargs ):
0 commit comments