|
18 | 18 | from __future__ import division |
19 | 19 | from __future__ import print_function |
20 | 20 |
|
21 | | -from typing import Dict, Optional, Sequence, Text |
| 21 | +from typing import Dict, Optional, Sequence, Text, Any |
22 | 22 |
|
23 | 23 | import ale_py # pylint: disable=unused-import |
24 | 24 | import gin |
@@ -84,13 +84,15 @@ def load( |
84 | 84 | ] = DEFAULT_ATARI_GYM_WRAPPERS, |
85 | 85 | env_wrappers: Sequence[types.PyEnvWrapper] = (), |
86 | 86 | spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None, |
| 87 | + gym_kwargs: Optional[Dict[str, Any]] = None, |
87 | 88 | ) -> py_environment.PyEnvironment: |
88 | 89 | """Loads the selected environment and wraps it with the specified wrappers.""" |
89 | 90 | if spec_dtype_map is None: |
90 | 91 | spec_dtype_map = {gym.spaces.Box: np.uint8} |
91 | 92 |
|
| 93 | + gym_kwargs = gym_kwargs if gym_kwargs else {} |
92 | 94 | gym_spec = gym.spec(environment_name) |
93 | | - gym_env = gym_spec.make() |
| 95 | + gym_env = gym_spec.make(**gym_kwargs) |
94 | 96 |
|
95 | 97 | if max_episode_steps is None and gym_spec.max_episode_steps is not None: |
96 | 98 | max_episode_steps = gym_spec.max_episode_steps |
|
0 commit comments