17
17
from stable_baselines3 .common .vec_env .patch_gym import _patch_env
18
18
19
19
20
- def _worker (
20
+ def _worker ( # noqa: C901
21
21
remote : mp .connection .Connection ,
22
22
parent_remote : mp .connection .Connection ,
23
23
env_fn_wrapper : CloudpickleWrapper ,
@@ -58,6 +58,12 @@ def _worker(
58
58
remote .send (method (* data [1 ], ** data [2 ]))
59
59
elif cmd == "get_attr" :
60
60
remote .send (env .get_wrapper_attr (data ))
61
+ elif cmd == "has_attr" :
62
+ try :
63
+ env .get_wrapper_attr (data )
64
+ remote .send (True )
65
+ except AttributeError :
66
+ remote .send (False )
61
67
elif cmd == "set_attr" :
62
68
remote .send (setattr (env , data [0 ], data [1 ])) # type: ignore[func-returns-value]
63
69
elif cmd == "is_wrapped" :
@@ -66,6 +72,8 @@ def _worker(
66
72
raise NotImplementedError (f"`{ cmd } ` is not implemented in the worker" )
67
73
except EOFError :
68
74
break
75
+ except KeyboardInterrupt :
76
+ break
69
77
70
78
71
79
class SubprocVecEnv (VecEnv ):
@@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
165
173
outputs = [pipe .recv () for pipe in self .remotes ]
166
174
return outputs
167
175
176
+ def has_attr (self , attr_name : str ) -> bool :
177
+ """Check if an attribute exists for a vectorized environment. (see base class)."""
178
+ target_remotes = self ._get_target_remotes (indices = None )
179
+ for remote in target_remotes :
180
+ remote .send (("has_attr" , attr_name ))
181
+ return all ([remote .recv () for remote in target_remotes ])
182
+
168
183
def get_attr (self , attr_name : str , indices : VecEnvIndices = None ) -> list [Any ]:
169
184
"""Return attribute from vectorized environment (see base class)."""
170
185
target_remotes = self ._get_target_remotes (indices )
0 commit comments