diff --git a/tf_agents/environments/suite_atari.py b/tf_agents/environments/suite_atari.py index 8b02a5bb4..4415703f4 100644 --- a/tf_agents/environments/suite_atari.py +++ b/tf_agents/environments/suite_atari.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from typing import Dict, Optional, Sequence, Text +from typing import Dict, Optional, Sequence, Text, Any import ale_py # pylint: disable=unused-import import gin @@ -84,13 +84,15 @@ def load( ] = DEFAULT_ATARI_GYM_WRAPPERS, env_wrappers: Sequence[types.PyEnvWrapper] = (), spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None, + gym_kwargs: Optional[Dict[str, Any]] = None, ) -> py_environment.PyEnvironment: """Loads the selected environment and wraps it with the specified wrappers.""" if spec_dtype_map is None: spec_dtype_map = {gym.spaces.Box: np.uint8} + gym_kwargs = gym_kwargs if gym_kwargs else {} gym_spec = gym.spec(environment_name) - gym_env = gym_spec.make() + gym_env = gym_spec.make(**gym_kwargs) if max_episode_steps is None and gym_spec.max_episode_steps is not None: max_episode_steps = gym_spec.max_episode_steps