
Correct usage with SB3 / Callbacks? #3

Open
emrul opened this issue Feb 2, 2023 · 1 comment

emrul commented Feb 2, 2023

Hi, this looks like a really interesting set of algorithms. I wanted to try some of them out using the SB3-zoo and was hoping for a plug-and-play approach. I wondered if I could integrate rlexplore via callbacks, so I came up with the following:

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from rlexplore import REVD

class RLeXploreCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.explorer = None
        self.buffer = None

    def init_callback(self, model: BaseAlgorithm) -> None:
        super().init_callback(model)
        env = self.training_env
        # Build the REVD module from the (vectorized) env's spaces.
        self.explorer = REVD(
            obs_shape=env.observation_space.shape,
            action_shape=env.action_space.shape,
            device=model.device,
            latent_dim=128,
            beta=1e-2,
            kappa=1e-5,
        )

        # Pick the buffer whose rewards will be augmented with intrinsic rewards.
        if isinstance(self.model, OnPolicyAlgorithm):
            self.buffer = self.model.rollout_buffer
        elif isinstance(self.model, OffPolicyAlgorithm):
            self.buffer = self.model.replay_buffer

    def _on_rollout_end(self) -> None:
        # Compute intrinsic rewards for the collected rollout and add them to the buffer.
        intrinsic_rewards = self.explorer.compute_irs(
            rollouts={'observations': self.buffer.observations},
            time_steps=self.num_timesteps,
            k=3)
        self.buffer.rewards += intrinsic_rewards[:, :, 0]

    def _on_step(self) -> bool:
        # TODO maybe log to TensorBoard?
        return True
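
(On the TensorBoard TODO above: SB3 callbacks expose self.logger, so I guess the mean intrinsic reward could be recorded right where it is computed; the metric name below is just a placeholder.)

# Hypothetical sketch: inside _on_rollout_end, after computing intrinsic_rewards.
self.logger.record("rollout/intrinsic_reward_mean", float(intrinsic_rewards.mean()))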

Then I include it in my list of callbacks and it seems to run. However, I'm still poking around without fully understanding what I'm doing (dangerous!), so does the above look correct? If it is, maybe it could be added as an example for others.

Second question: is time_steps=self.num_timesteps the right value to pass here?

Third question: the sample in the examples directory uses the rollout_buffer, but is it valid to use this with off-policy algorithms like DQN (swapping in the replay_buffer instead)?
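
For reference, I wire the callback in roughly like this (a minimal sketch; the env and hyperparameters are placeholders, since the zoo normally sets these through its config files):

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=100_000, callback=RLeXploreCallback())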

yuanmingqi (Collaborator) commented Feb 12, 2023

Hello, the repository is still under development and any attempts are welcome. You can open a PR to add more examples.

  1. REVD only supports on-policy algorithms (see the sketch after this list);
  2. time_steps=self.num_timesteps is correct.
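
Since REVD is on-policy only, the callback sketch above could also guard against off-policy models explicitly, for example:

# Sketch: inside init_callback, reusing the names from the snippet above.
if not isinstance(self.model, OnPolicyAlgorithm):
    raise ValueError("REVD only supports on-policy algorithms such as PPO/A2C")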
