From 99dcdba624168d23aea240a3f1692330bbc7e099 Mon Sep 17 00:00:00 2001
From: "Steven H. Wang"
Date: Wed, 18 Dec 2019 12:57:14 -0800
Subject: [PATCH] Update VecNormalize normalization (#609)

* VecNormalize: Add public normalize_{obs..,rew} methods

* Update changelog

* VecNormalize: get_original_{obs,rews}

* VecNormalize: Update rewards in reset()

Note that after the _update_rews() refactor, self.ret doesn't update anymore if `not self.training`.

* update changelog

* renames

* changelog: fix indent

* changelog: nested list needs blank lines

* Add tests

* Address review, fix tests

* update tests

* More annotations

* Update stable_baselines/common/vec_env/vec_normalize.py

Co-Authored-By: Adam Gleave

* Address review comments

* Defensive copy
---
 docs/misc/changelog.rst                        |  9 ++-
 .../common/vec_env/vec_normalize.py            | 70 ++++++++++++-------
 tests/test_vec_normalize.py                    | 47 +++++++++++++
 3 files changed, 101 insertions(+), 25 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 681167e9a2..12535e7f71 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -24,7 +24,14 @@ New Features:
 - Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor
 - Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation
 - Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation
-- `VecNormalize` now supports being pickled and unpickled.
+- `VecNormalize` changes:
+
+  - Now supports being pickled and unpickled (@AdamGleave).
+  - New methods `.normalize_obs(obs)` and `.normalize_reward(rews)` apply normalization
+    to arbitrary observations or rewards without updating statistics (@shwang)
+  - `.get_original_reward()` returns the unnormalized rewards from the most recent timestep
+  - `.reset()` now collects observation statistics (used to only apply normalization)
+
 - Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)
 - Add type checking and PEP 561 compliance.
   Note: most functions are still not annotated, this will be a gradual process.
diff --git a/stable_baselines/common/vec_env/vec_normalize.py b/stable_baselines/common/vec_env/vec_normalize.py
index dc93c5ecbf..6ab308b13f 100644
--- a/stable_baselines/common/vec_env/vec_normalize.py
+++ b/stable_baselines/common/vec_env/vec_normalize.py
@@ -39,7 +39,8 @@ def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
         self.training = training
         self.norm_obs = norm_obs
         self.norm_reward = norm_reward
-        self.old_obs = np.array([])
+        self.old_obs = None
+        self.old_rews = None
 
     def __getstate__(self):
         """
@@ -88,48 +89,69 @@ def step_wait(self):
         where 'news' is a boolean vector indicating whether each element is new.
         """
         obs, rews, news, infos = self.venv.step_wait()
-        self.ret = self.ret * self.gamma + rews
         self.old_obs = obs
-        obs = self._normalize_observation(obs)
-        if self.norm_reward:
-            if self.training:
-                self.ret_rms.update(self.ret)
-            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
+        self.old_rews = rews
+
+        if self.training:
+            self.obs_rms.update(obs)
+        obs = self.normalize_obs(obs)
+
+        if self.training:
+            self._update_reward(rews)
+        rews = self.normalize_reward(rews)
+
         self.ret[news] = 0
         return obs, rews, news, infos
 
-    def _normalize_observation(self, obs):
+    def _update_reward(self, reward: np.ndarray) -> None:
+        """Update reward normalization statistics."""
+        self.ret = self.ret * self.gamma + reward
+        self.ret_rms.update(self.ret)
+
+    def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
         """
-        :param obs: (numpy tensor)
+        Normalize observations using this VecNormalize's observation statistics.
+        Calling this method does not update statistics.
         """
         if self.norm_obs:
-            if self.training:
-                self.obs_rms.update(obs)
-            obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs,
+            obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon),
+                          -self.clip_obs,
                           self.clip_obs)
-            return obs
-        else:
-            return obs
+        return obs
+
+    def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
+        """
+        Normalize rewards using this VecNormalize's reward statistics.
+        Calling this method does not update statistics.
+        """
+        if self.norm_reward:
+            reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon),
+                             -self.clip_reward, self.clip_reward)
+        return reward
 
-    def get_original_obs(self):
+    def get_original_obs(self) -> np.ndarray:
         """
-        returns the unnormalized observation
+        Returns an unnormalized version of the observations from the most recent
+        step or reset.
+        """
+        return self.old_obs.copy()
 
-        :return: (numpy float)
+    def get_original_reward(self) -> np.ndarray:
+        """
+        Returns an unnormalized version of the rewards from the most recent step.
         """
-        return self.old_obs
+        return self.old_rews.copy()
 
     def reset(self):
         """
         Reset all environments
         """
         obs = self.venv.reset()
-        if len(np.array(obs).shape) == 1:  # for when num_cpu is 1
-            self.old_obs = [obs]
-        else:
-            self.old_obs = obs
+        self.old_obs = obs
         self.ret = np.zeros(self.num_envs)
-        return self._normalize_observation(obs)
+        if self.training:
+            self._update_reward(self.ret)
+        return self.normalize_obs(obs)
 
     @staticmethod
     def load(load_path, venv):
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index 9c70482f4d..6e97ed3c7c 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -76,6 +76,53 @@ def test_vec_env(tmpdir):
     check_vec_norm_equal(norm_venv, deserialized)
 
+
+def _make_warmstart_cartpole():
+    """Warm-start VecNormalize by stepping through CartPole"""
+    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
+    venv = VecNormalize(venv)
+    venv.reset()
+    venv.get_original_obs()
+
+    for _ in range(100):
+        actions = [venv.action_space.sample()]
+        venv.step(actions)
+    return venv
+
+
+def test_get_original():
+    venv = _make_warmstart_cartpole()
+    for _ in range(3):
+        actions = [venv.action_space.sample()]
+        obs, rewards, _, _ = venv.step(actions)
+        obs = obs[0]
+        orig_obs = venv.get_original_obs()[0]
+        rewards = rewards[0]
+        orig_rewards = venv.get_original_reward()[0]
+
+        assert np.all(orig_rewards == 1)
+        assert orig_obs.shape == obs.shape
+        assert orig_rewards.dtype == rewards.dtype
+        assert not np.array_equal(orig_obs, obs)
+        assert not np.array_equal(orig_rewards, rewards)
+        np.testing.assert_allclose(venv.normalize_obs(orig_obs), obs)
+        np.testing.assert_allclose(venv.normalize_reward(orig_rewards), rewards)
+
+
+def test_normalize_external():
+    venv = _make_warmstart_cartpole()
+
+    rewards = np.array([1, 1])
+    norm_rewards = venv.normalize_reward(rewards)
+    assert norm_rewards.shape == rewards.shape
+    # Episode return is almost always >= 1 in CartPole. So reward should shrink.
+    assert np.all(norm_rewards < 1)
+
+    # Don't have any guarantees on obs normalization, except shape, really.
+    obs = np.array([0, 0, 0, 0])
+    norm_obs = venv.normalize_obs(obs)
+    assert obs.shape == norm_obs.shape
+
 
 def test_mpi_runningmeanstd():
     """Test RunningMeanStd object for MPI"""
     return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2',