Skip to content

Commit

Permalink
Merge branch 'master' into feat/crossq
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Oct 18, 2024
2 parents 244b930 + 3d9a975 commit 497ea7e
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,13 @@ To run tests with `pytest`:
make pytest
```

Type checking with `pytype` and `mypy`:
Type checking with `mypy`:

```
make type
```

Codestyle check with `black`, `isort` and `flake8`:
Codestyle check with `black` and `ruff`:

```
make check-codestyle
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a8 (WIP)
Release 2.4.0a9 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -20,6 +20,7 @@ Bug Fixes:
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Updated QR-DQN paper link in docs (@corentinlger)
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 4 additions & 5 deletions sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ def _setup_model(self) -> None:
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions

if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
Expand All @@ -164,8 +163,6 @@ def _setup_model(self) -> None:
f"which corresponds to {self.n_envs} steps."
)

self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

def _create_aliases(self) -> None:
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
Expand All @@ -177,7 +174,9 @@ def _on_step(self) -> None:
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a8
2.4.0a9
12 changes: 12 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
Expand Down Expand Up @@ -488,3 +489,14 @@ def test_save_load_pytorch_var(tmp_path):
assert model.log_ent_coef is None
# Check that the entropy coefficient is still the same
assert th.allclose(ent_coef_before, ent_coef_after)


def test_dqn_target_update_interval(tmp_path):
# `target_update_interval` should not change when reloading the model. See GH Issue #258.
env = make_vec_env(env_id="CartPole-v1", n_envs=2)
model = QRDQN("MlpPolicy", env, verbose=1, target_update_interval=100)
model.save(tmp_path / "dqn_cartpole")
del model
model = QRDQN.load(tmp_path / "dqn_cartpole")
os.remove(tmp_path / "dqn_cartpole.zip")
assert model.target_update_interval == 100

0 comments on commit 497ea7e

Please sign in to comment.