@@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
7676 being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
7777
7878 Examples:
79+ >>> import torch
80+ >>> from torchrl.envs import ChessEnv
81+ >>> _ = torch.manual_seed(0)
7982 >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
83+ >>> print(env)
84+ TransformedEnv(
85+ env=ChessEnv(),
86+ transform=ActionMask(keys=['action', 'action_mask']))
8087 >>> r = env.reset()
81- >>> env.rand_step(r)
88+ >>> print( env.rand_step(r) )
8289 TensorDict(
8390 fields={
8491 action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
92+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
8593 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
8694 fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
8795 legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
8896 next: TensorDict(
8997 fields={
98+ action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
9099 done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
91- fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
100+ fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP /RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
92101 legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
93102 pgn: NonTensorData(data=[Event "?"]
94103 [Site "?"]
@@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
97106 [White "?"]
98107 [Black "?"]
99108 [Result "*"]
100- 1. b3 *, batch_size=torch.Size([]), device=None),
109+
110+ 1. f4 *, batch_size=torch.Size([]), device=None),
101111 reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
102- san: NonTensorData(data=b3 , batch_size=torch.Size([]), device=None),
112+ san: NonTensorData(data=f4 , batch_size=torch.Size([]), device=None),
103113 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
104114 turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
105115 batch_size=torch.Size([]),
@@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
112122 [White "?"]
113123 [Black "?"]
114124 [Result "*"]
125+
115126 *, batch_size=torch.Size([]), device=None),
116- san: NonTensorData(data=[SAN][START] , batch_size=torch.Size([]), device=None),
127+ san: NonTensorData(data=<start> , batch_size=torch.Size([]), device=None),
117128 terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
118129 turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
119130 batch_size=torch.Size([]),
120131 device=None,
121132 is_shared=False)
122- >>> env.rollout(1000)
133+ >>> print( env.rollout(1000) )
123134 TensorDict(
124135 fields={
125- action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
126- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
136+ action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
137+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
138+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
127139 fen: NonTensorStack(
128140 ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
129- batch_size=torch.Size([352 ]),
141+ batch_size=torch.Size([96 ]),
130142 device=None),
131- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
143+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
132144 next: TensorDict(
133145 fields={
134- done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
146+ action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
147+ done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
135148 fen: NonTensorStack(
136- ['rnbqkbnr/pppppppp/8/8/8/N7 /PPPPPPPP/R1BQKBNR b K ...,
137- batch_size=torch.Size([352 ]),
149+ ['rnbqkbnr/pppppppp/8/8/8/5N2 /PPPPPPPP/RNBQKB1R b ...,
150+ batch_size=torch.Size([96 ]),
138151 device=None),
139- legal_moves: Tensor(shape=torch.Size([352 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
152+ legal_moves: Tensor(shape=torch.Size([96 , 219]), device=cpu, dtype=torch.int64, is_shared=False),
140153 pgn: NonTensorStack(
141154 ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
142- batch_size=torch.Size([352 ]),
155+ batch_size=torch.Size([96 ]),
143156 device=None),
144- reward: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
157+ reward: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.float32, is_shared=False),
145158 san: NonTensorStack(
146- ['Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', 'g6 ', 'd4 ', 'd6' ...,
147- batch_size=torch.Size([352 ]),
159+ ['Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8 ', 'Na3 ', 'Ra ...,
160+ batch_size=torch.Size([96 ]),
148161 device=None),
149- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
150- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
151- batch_size=torch.Size([352 ]),
162+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164+ batch_size=torch.Size([96 ]),
152165 device=None,
153166 is_shared=False),
154167 pgn: NonTensorStack(
155168 ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
156- batch_size=torch.Size([352 ]),
169+ batch_size=torch.Size([96 ]),
157170 device=None),
158171 san: NonTensorStack(
159- ['[SAN][START] ', 'Na3 ', 'a5 ', 'Nb1 ', 'Nc6 ', 'a3 ', ...,
160- batch_size=torch.Size([352 ]),
172+ ['<start> ', 'Nf3 ', 'Na6 ', 'c4 ', 'f6 ', 'h4 ', 'Rb8', ...,
173+ batch_size=torch.Size([96 ]),
161174 device=None),
162- terminated: Tensor(shape=torch.Size([352 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
163- turn: Tensor(shape=torch.Size([352 ]), device=cpu, dtype=torch.bool, is_shared=False)},
164- batch_size=torch.Size([352 ]),
175+ terminated: Tensor(shape=torch.Size([96 , 1]), device=cpu, dtype=torch.bool, is_shared=False),
176+ turn: Tensor(shape=torch.Size([96 ]), device=cpu, dtype=torch.bool, is_shared=False)},
177+ batch_size=torch.Size([96 ]),
165178 device=None,
166179 is_shared=False)
167180
@@ -227,13 +240,15 @@ def _legal_moves_to_index(
227240 [self ._san_moves .index (board .san (m )) for m in board .legal_moves ],
228241 dtype = torch .int64 ,
229242 )
230-
243+ mask = None
231244 if return_mask :
232- return self ._move_index_to_mask (indices )
245+ mask = self ._move_index_to_mask (indices )
233246 if pad :
234247 indices = torch .nn .functional .pad (
235248 indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
236249 )
250+ if return_mask :
251+ return indices , mask
237252 return indices
238253
239254 @classmethod
@@ -371,16 +386,19 @@ def _reset(self, tensordict=None):
371386 dest .set ("pgn" , pgn )
372387 dest .set ("turn" , turn )
373388 if self .include_legal_moves :
374- moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
375- dest .set ("legal_moves" , moves_idx )
389+ moves_idx = self ._legal_moves_to_index (
390+ board = self .board , pad = True , return_mask = self .mask_actions
391+ )
376392 if self .mask_actions :
377- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
393+ moves_idx , mask = moves_idx
394+ dest .set ("action_mask" , mask )
395+ dest .set ("legal_moves" , moves_idx )
378396 elif self .mask_actions :
379397 dest .set (
380398 "action_mask" ,
381399 self ._legal_moves_to_index (
382400 board = self .board , pad = True , return_mask = True
383- ),
401+ )[ 1 ] ,
384402 )
385403
386404 if self .pixels :
@@ -527,16 +545,19 @@ def _step(self, tensordict):
527545 dest .set ("san" , san )
528546
529547 if self .include_legal_moves :
530- moves_idx = self ._legal_moves_to_index (board = board , pad = True )
531- dest .set ("legal_moves" , moves_idx )
548+ moves_idx = self ._legal_moves_to_index (
549+ board = board , pad = True , return_mask = self .mask_actions
550+ )
532551 if self .mask_actions :
533- dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
552+ moves_idx , mask = moves_idx
553+ dest .set ("action_mask" , mask )
554+ dest .set ("legal_moves" , moves_idx )
534555 elif self .mask_actions :
535556 dest .set (
536557 "action_mask" ,
537558 self ._legal_moves_to_index (
538559 board = self .board , pad = True , return_mask = True
539- ),
560+ )[ 1 ] ,
540561 )
541562
542563 turn = torch .tensor (board .turn )
0 commit comments