diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 5972fa364..3051723c3 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -832,6 +832,8 @@ def __init__(
         normalize_images: bool = True,
         n_critics: int = 2,
         share_features_extractor: bool = True,
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -846,18 +848,21 @@ def __init__(
         self.n_critics = n_critics
         self.q_networks = []
         for idx in range(n_critics):
-            q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
+            q_net = create_mlp(
+                features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm
+            )
             q_net = nn.Sequential(*q_net)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
 
-    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+    def forward(self, obs: th.Tensor, actions: th.Tensor, q_networks=None) -> Tuple[th.Tensor, ...]:
+        q_networks = q_networks or self.q_networks
         # Learn the features extractor using the policy loss only
         # when the features_extractor is shared with the actor
         with th.set_grad_enabled(not self.share_features_extractor):
             features = self.extract_features(obs)
         qvalue_input = th.cat([features, actions], dim=1)
-        return tuple(q_net(qvalue_input) for q_net in self.q_networks)
+        return tuple(q_net(qvalue_input) for q_net in q_networks)
 
     def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
         """
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index f87337c62..525ffd410 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -99,6 +99,8 @@ def create_mlp(
     net_arch: List[int],
     activation_fn: Type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
+    dropout_rate: float = 0.0,
+    layer_norm: bool = False,
 ) -> List[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
@@ -117,12 +119,22 @@ def create_mlp(
     """
 
     if len(net_arch) > 0:
-        modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
+        additional_modules = []
+        if dropout_rate > 0.0:
+            additional_modules.append(nn.Dropout(p=dropout_rate))
+        if layer_norm:
+            additional_modules.append(nn.LayerNorm(net_arch[0]))
+        modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()]
+
     else:
         modules = []
 
     for idx in range(len(net_arch) - 1):
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
+        if dropout_rate > 0.0:
+            modules.append(nn.Dropout(p=dropout_rate))
+        if layer_norm:
+            modules.append(nn.LayerNorm(net_arch[idx + 1]))
         modules.append(activation_fn())
 
     if output_dim > 0:
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index ac9324925..297ccc65f 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -223,6 +223,9 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        # For the critic only
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -263,6 +266,8 @@ def __init__(
                 "n_critics": n_critics,
                 "net_arch": critic_arch,
                 "share_features_extractor": share_features_extractor,
+                "dropout_rate": dropout_rate,
+                "layer_norm": layer_norm,
             }
         )
 
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index 85bdf7897..3b8983309 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -96,6 +96,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        policy_delay: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -144,6 +145,7 @@ def __init__(
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
         self.ent_coef_optimizer = None
+        self.policy_delay = policy_delay
 
         if _init_setup_model:
             self._setup_model()
@@ -203,6 +205,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         actor_losses, critic_losses = [], []
 
         for gradient_step in range(gradient_steps):
+            self._n_updates += 1
+            update_actor = self._n_updates % self.policy_delay == 0
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
 
@@ -211,8 +215,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 self.actor.reset_noise()
 
             # Action by the current actor for the sampled state
-            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
-            log_prob = log_prob.reshape(-1, 1)
+            if update_actor:
+                actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
+                log_prob = log_prob.reshape(-1, 1)
 
             ent_coef_loss = None
             if self.ent_coef_optimizer is not None:
@@ -220,8 +225,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # so we don't change it with other losses
                 # see https://github.com/rail-berkeley/softlearning/issues/60
                 ent_coef = th.exp(self.log_ent_coef.detach())
-                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
-                ent_coef_losses.append(ent_coef_loss.item())
+                if update_actor:
+                    ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+                    ent_coef_losses.append(ent_coef_loss.item())
             else:
                 ent_coef = self.ent_coef_tensor
 
@@ -237,6 +243,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             with th.no_grad():
                 # Select action according to policy
                 next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
+                # For REDQ, sample the q networks to be used
+                # q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
+                # q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
                 # Compute the next Q values: min over all critics targets
                 next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                 next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
@@ -260,16 +269,18 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
 
             # Compute actor loss
             # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
-            # Min over all critic networks
-            q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
-            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
-            actor_losses.append(actor_loss.item())
-
-            # Optimize the actor
-            self.actor.optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor.optimizer.step()
+            if update_actor:
+                q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
+                # Note: REDQ and DroQ take the mean here instead of the min
+                # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
+                mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - mean_qf_pi).mean()
+                actor_losses.append(actor_loss.item())
+
+                # Optimize the actor
+                self.actor.optimizer.zero_grad()
+                actor_loss.backward()
+                self.actor.optimizer.step()
 
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
@@ -277,8 +288,6 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # Copy running stats, see GH issue #996
                 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
 
-        self._n_updates += gradient_steps
-
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))
diff --git a/tests/test_run.py b/tests/test_run.py
index 71236a3d9..fef2da29e 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -86,6 +86,17 @@ def test_sac(ent_coef):
     model.learn(total_timesteps=200)
 
 
+def test_dropq():
+    model = SAC(
+        "MlpPolicy",
+        "Pendulum-v1",
+        policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
+        verbose=1,
+        buffer_size=250,
+    )
+    model.learn(total_timesteps=300)
+
+
 @pytest.mark.parametrize("n_critics", [1, 3])
 def test_n_critics(n_critics):
     # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG
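
For reference, a DroQ-style configuration that exercises the options added in this patch could look like the sketch below. It assumes the patched branch is installed; the hyperparameter values (dropout rate, policy delay, number of gradient steps) are illustrative only and not tuned.

```python
from stable_baselines3 import SAC

# Sketch of a DroQ-style setup: `dropout_rate`, `layer_norm` and `policy_delay`
# are the options introduced by this patch, the remaining kwargs already exist in SB3.
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=[256, 256], dropout_rate=0.01, layer_norm=True),
    policy_delay=20,     # actor and entropy coefficient updated once every 20 critic updates
    gradient_steps=20,   # high update-to-data ratio
    train_freq=1,
    learning_starts=100,
    verbose=1,
)
model.learn(total_timesteps=1_000)
```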
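The commented-out REDQ lines in `train()` rely on the new `q_networks` argument of `ContinuousCritic.forward`. The snippet below is a hypothetical helper (the name `redq_target_q` and the `n_sampled` default are assumptions, not part of the patch) showing how that hook could build the TD target from a random subset of target critics.

```python
import numpy as np
import torch as th


def redq_target_q(critic_target, next_obs: th.Tensor, next_actions: th.Tensor, n_sampled: int = 2) -> th.Tensor:
    """Min over a random subset of the target critics (REDQ-style),
    using the `q_networks` argument added to ContinuousCritic.forward."""
    with th.no_grad():
        # Sample `n_sampled` target critics without replacement
        indices = np.random.permutation(len(critic_target.q_networks))[:n_sampled]
        subset = [q_net for idx, q_net in enumerate(critic_target.q_networks) if idx in indices]
        # Concatenate the sampled critics' outputs and take the element-wise min
        next_q_values = th.cat(critic_target(next_obs, next_actions, q_networks=subset), dim=1)
        next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
    return next_q_values
```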