Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IQN #347

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft

IQN #347

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- Tuned hyperparameters for PPO on Swimmer
- Added ``-tags/--wandb-tags`` argument to ``train.py`` to add tags to the wandb run
- Added a sb3 version tag to the wandb run
- Added support for `IQN` from SB3 contrib

### Bug fixes
- Allow `python -m rl_zoo3.cli` to be called directly
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ List and videos of trained agents can be found on our Huggingface page: https://
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| IQN | | | | | | | |
| QR-DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

Additional Atari Games (to be completed):
Expand All @@ -426,6 +427,7 @@ Additional Atari Games (to be completed):
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| IQN | | | |
| QR-DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |


Expand All @@ -437,6 +439,7 @@ Additional Atari Games (to be completed):
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | N/A |
| IQN | | | | N/A | N/A |
| QR-DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | N/A |
| DDPG | N/A | N/A | N/A | :heavy_check_mark: | :heavy_check_mark: |
| SAC | N/A | N/A | N/A | :heavy_check_mark: | :heavy_check_mark: |
Expand All @@ -453,6 +456,7 @@ Additional Atari Games (to be completed):
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
| DQN | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| IQN | N/A | | N/A | N/A | N/A |
| QR-DQN | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| DDPG | :heavy_check_mark: | N/A | :heavy_check_mark: | | |
| SAC | :heavy_check_mark: | N/A | :heavy_check_mark: | :heavy_check_mark: | |
Expand Down
10 changes: 10 additions & 0 deletions hyperparams/iqn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Tuned hyperparameters for IQN on Atari environments (shared across games).
atari:
  env_wrapper:
    # Standard Atari preprocessing: NoopReset, frame-skip, resize to 84x84,
    # grayscale, clip rewards, episodic life (SB3 AtariWrapper defaults).
    - stable_baselines3.common.atari_wrappers.AtariWrapper
  # Stack the last 4 frames so the policy can infer velocity from pixels.
  frame_stack: 4
  policy: 'CnnPolicy'
  n_timesteps: !!float 1e7
  exploration_fraction: 0.025 # explore 250k steps = 10M * 0.025
  # If True, you need to deactivate handle_timeout_termination
  # in the replay_buffer_kwargs
  optimize_memory_usage: False
2 changes: 1 addition & 1 deletion rl_zoo3/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
if algo == "her":
continue

if algo in ["dqn", "qrdqn", "ddpg", "sac", "td3", "tqc"]:
if algo in ["dqn", "iqn", "qrdqn", "ddpg", "sac", "td3", "tqc"]:
n_envs = 1
n_timesteps *= args.n_envs

Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def enjoy() -> None: # noqa: C901
print(f"Loading {model_path}")

# Off-policy algorithm only support one env for now
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
off_policy_algos = ["iqn", "qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]

if algo in off_policy_algos:
args.n_envs = 1
Expand Down
19 changes: 18 additions & 1 deletion rl_zoo3/hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def sample_qrdqn_params(trial: optuna.Trial) -> Dict[str, Any]:
:param trial:
:return:
"""
# TQC is DQN + Distributional RL
# QR-DQN is DQN + Distributional RL
hyperparams = sample_dqn_params(trial)

n_quantiles = trial.suggest_int("n_quantiles", 5, 200)
Expand All @@ -480,6 +480,22 @@ def sample_qrdqn_params(trial: optuna.Trial) -> Dict[str, Any]:
return hyperparams


def sample_iqn_params(trial: optuna.Trial) -> Dict[str, Any]:
    """
    Sampler for IQN hyperparams.

    :param trial: Optuna trial used to draw hyperparameter values.
    :return: Sampled hyperparameters for an IQN run.
    """
    # IQN (Implicit Quantile Network) builds on QR-DQN, so start from the
    # QR-DQN sampler and layer the IQN-specific quantile sample counts on top.
    hyperparams = sample_qrdqn_params(trial)

    # Number of quantile samples for the online and target networks.
    for param_name in ("num_tau_samples", "num_tau_prime_samples"):
        hyperparams[param_name] = trial.suggest_int(param_name, 5, 200)

    return hyperparams


def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
"""
Sampler for ARS hyperparams.
Expand Down Expand Up @@ -524,6 +540,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
"ars": sample_ars_params,
"ddpg": sample_ddpg_params,
"dqn": sample_dqn_params,
"iqn": sample_iqn_params,
"qrdqn": sample_qrdqn_params,
"sac": sample_sac_params,
"tqc": sample_tqc_params,
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/push_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def package_to_hub(
print(f"Loading {model_path}")

# Off-policy algorithm only support one env for now
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
off_policy_algos = ["iqn", "qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]

if algo in off_policy_algos:
args.n_envs = 1
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)

print(f"Loading {model_path}")
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
off_policy_algos = ["iqn", "qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]

set_random_seed(args.seed)

Expand Down
3 changes: 2 additions & 1 deletion rl_zoo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gym import spaces
from huggingface_hub import HfApi
from huggingface_sb3 import EnvironmentName, ModelName
from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from sb3_contrib import ARS, IQN, QRDQN, TQC, TRPO, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
Expand All @@ -31,6 +31,7 @@
"td3": TD3,
# SB3 Contrib,
"ars": ARS,
"iqn": IQN,
"qrdqn": QRDQN,
"tqc": TQC,
"trpo": TRPO,
Expand Down