import warnings
from abc import ABC, abstractmethod
from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import Any, ClassVar, Optional, TypeVar, Union

import gymnasium as gym
import numpy as np
@@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
    """

    # Policy aliases (see _get_policy_from_name())
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
    policy: BasePolicy
    observation_space: spaces.Space
    action_space: spaces.Space
@@ -104,10 +105,10 @@ class BaseAlgorithm(ABC):

    def __init__(
        self,
-        policy: Union[str, Type[BasePolicy]],
+        policy: Union[str, type[BasePolicy]],
        env: Union[GymEnv, str, None],
        learning_rate: Union[float, Schedule],
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
@@ -117,7 +118,7 @@ def __init__(
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
    ) -> None:
        if isinstance(policy, str):
            self.policy_class = self._get_policy_from_name(policy)
@@ -141,10 +142,10 @@ def __init__(
        self.start_time = 0.0
        self.learning_rate = learning_rate
        self.tensorboard_log = tensorboard_log
-        self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
        self._last_episode_starts = None  # type: Optional[np.ndarray]
        # When using VecNormalize:
-        self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_original_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
        self._episode_num = 0
        # Used for gSDE only
        self.use_sde = use_sde
@@ -283,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps
        """
        self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

-    def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+    def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
        """
        Update the optimizers learning rate using the current learning rate schedule
        and the current progress remaining (from 1 to 0).
@@ -299,7 +300,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
        for optimizer in optimizers:
            update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))

-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
        """
        Returns the names of the parameters that should be excluded from being
        saved by pickling. E.g. replay buffers are skipped by default
@@ -320,7 +321,7 @@ def _excluded_save_params(self) -> List[str]:
            "_custom_logger",
        ]

-    def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
+    def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
        """
        Get a policy class from its name representation.

@@ -337,7 +338,7 @@ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
        else:
            raise ValueError(f"Policy {policy_name} unknown")

-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """
        Get the name of the torch variables that will be saved with
        PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
@@ -387,7 +388,7 @@ def _setup_learn(
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
        progress_bar: bool = False,
-    ) -> Tuple[int, BaseCallback]:
+    ) -> tuple[int, BaseCallback]:
        """
        Initialize different variables needed for training.

@@ -435,7 +436,7 @@ def _setup_learn(

        return total_timesteps, callback

-    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+    def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
        """
        Retrieve reward, episode length, episode success and update the buffer
        if using Monitor wrapper or a GoalEnv.
@@ -535,11 +536,11 @@ def learn(

    def predict(
        self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -640,11 +641,11 @@ def set_parameters(

    @classmethod
    def load(  # noqa: C901
-        cls: Type[SelfBaseAlgorithm],
+        cls: type[SelfBaseAlgorithm],
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[th.device, str] = "auto",
-        custom_objects: Optional[Dict[str, Any]] = None,
+        custom_objects: Optional[dict[str, Any]] = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs,
@@ -800,7 +801,7 @@ def load( # noqa: C901
            model.policy.reset_noise()  # type: ignore[operator]
        return model

-    def get_parameters(self) -> Dict[str, Dict]:
+    def get_parameters(self) -> dict[str, dict]:
        """
        Return the parameters of the agent. This includes parameters from different networks, e.g.
        critics (value functions) and policies (pi functions).
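For reference only, and not part of the diff above: the replacement annotations are the built-in generics standardized by PEP 585 (Python 3.9+), so dict, list, tuple, and type are subscripted directly, Optional and Union still come from typing, and Iterable now comes from collections.abc. Below is a minimal, self-contained sketch with hypothetical helper names (not from the library) showing the same annotation style.

from collections.abc import Iterable
from typing import Any, Optional, Union


# Illustrative helpers only: built-in dict/list/tuple/type are used directly
# as generic annotations, matching the style adopted in the diff above.
def mean_of_key(infos: list[dict[str, Any]], key: str) -> Optional[float]:
    """Average a numeric entry across a list of info dicts, or None if absent."""
    values = [info[key] for info in infos if key in info]
    return sum(values) / len(values) if values else None


def first_and_rest(items: Iterable[int]) -> tuple[int, list[int]]:
    """Split an iterable into its first element and the remaining ones."""
    first, *rest = items
    return first, rest


def class_name(cls: type[object]) -> Union[str, None]:
    """Return the name of a class object."""
    return cls.__name__


if __name__ == "__main__":
    print(mean_of_key([{"r": 1.0}, {"r": 3.0}, {}], "r"))  # 2.0
    print(first_and_rest(range(4)))                        # (0, [1, 2, 3])
    print(class_name(dict))                                # dict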