Commit 9090b0c

Author: Jan Michelfeit (committed)

#625 PebbleStateEntropyReward can switch from unsupervised pretraining

Parent: d348534

File tree: 2 files changed (+63, -14 lines)

src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 22 additions & 4 deletions

@@ -9,14 +9,21 @@
     ReplayBufferView,
     ReplayBufferRewardWrapper,
 )
-from imitation.rewards.reward_function import ReplayBufferAwareRewardFn
+from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
 from imitation.util import util
 from imitation.util.networks import RunningNorm


 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     # TODO #625: get rid of the observation_space parameter
-    def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
+    # TODO #625: parametrize nearest_neighbor_k
+    def __init__(
+        self,
+        trained_reward_fn: RewardFn,
+        observation_space: spaces.Space,
+        nearest_neighbor_k: int = 5,
+    ):
+        self.trained_reward_fn = trained_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
         # TODO support n_envs > 1
         self.entropy_stats = RunningNorm(1)
@@ -25,24 +32,35 @@ def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
         self.replay_buffer_view = ReplayBufferView(
             np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
         )
+        # This indicates that the training is in the "Unsupervised exploration"
+        # phase of the Pebble algorithm, where entropy is used as reward
+        self.unsupervised_exploration_active = True

     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
         self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)

-    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape:Tuple):
+    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape

+    def on_unsupervised_exploration_finished(self):
+        self.unsupervised_exploration_active = False
+
     def __call__(
         self,
         state: np.ndarray,
         action: np.ndarray,
         next_state: np.ndarray,
         done: np.ndarray,
     ) -> np.ndarray:
+        if self.unsupervised_exploration_active:
+            return self._entropy_reward(state)
+        else:
+            return self.trained_reward_fn(state, action, next_state, done)
+
+    def _entropy_reward(self, state):
         # TODO: should this work with torch instead of numpy internally?
         # (The RewardFn protocol requires numpy)
-
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
         all_observations = all_observations.reshape(
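For orientation, here is a minimal usage sketch of the behaviour this commit introduces. It is not part of the diff: the observation space, shapes, batch values, and trained_reward_fn are made up, and the import path for ReplayBufferView is assumed (the diff only shows the imported names). Only PebbleStateEntropyReward, set_replay_buffer, on_unsupervised_exploration_finished, and the RewardFn call signature come from the code above.

import numpy as np
from gym import spaces  # assumption: gym spaces, matching the spaces.Space annotation above

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
# Assumed import path; the diff shows only the class name, not its module.
from imitation.policies.replay_buffer_wrapper import ReplayBufferView


def trained_reward_fn(state, action, next_state, done):
    """Hypothetical stand-in for the reward model learned from preferences."""
    return np.zeros(len(state))


obs_shape = (4,)
space = spaces.Box(low=-1.0, high=1.0, shape=obs_shape)
reward_fn = PebbleStateEntropyReward(trained_reward_fn, space)

# The entropy reward is computed against stored observations, so the wrapper
# needs a view into the replay buffer with shape (buffer_size, n_envs, *obs_shape).
all_observations = np.random.rand(128, 1, *obs_shape)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)), obs_shape
)

batch = (
    np.random.rand(8, *obs_shape),  # state
    np.zeros((8, 1)),               # action
    np.random.rand(8, *obs_shape),  # next_state
    np.zeros((8,)),                 # done
)

# Unsupervised exploration phase: rewards are normalized state-entropy estimates.
pretraining_rewards = reward_fn(*batch)

# After pretraining finishes, every call is delegated to trained_reward_fn.
reward_fn.on_unsupervised_exploration_finished()
training_rewards = reward_fn(*batch)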

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 41 additions & 10 deletions

@@ -1,5 +1,5 @@
 import pickle
-from unittest.mock import patch
+from unittest.mock import patch, Mock

 import numpy as np
 import torch as th
@@ -19,13 +19,14 @@
 VENVS = 2


-def test_state_entropy_reward_returns_entropy(rng):
+def test_pebble_entropy_reward_returns_entropy(rng):
     obs_shape = get_obs_shape(SPACE)
     all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))

-
-    reward_fn = PebbleStateEntropyReward(K, SPACE)
-    reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
+    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
+    reward_fn.set_replay_buffer(
+        ReplayBufferView(all_observations, lambda: slice(None)), obs_shape
+    )

     # Act
     observations = rng.random((BATCH_SIZE, *obs_shape))
@@ -41,16 +42,16 @@ def test_state_entropy_reward_returns_entropy(rng):
     np.testing.assert_allclose(reward, expected_normalized)


-def test_state_entropy_reward_returns_normalized_values():
+def test_pebble_entropy_reward_returns_normalized_values():
     with patch("imitation.util.util.compute_state_entropy") as m:
         # mock entropy computation so that we can test only stats collection in this test
         m.side_effect = lambda obs, all_obs, k: obs

-        reward_fn = PebbleStateEntropyReward(K, SPACE)
+        reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
         all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
         reward_fn.set_replay_buffer(
             ReplayBufferView(all_observations, lambda: slice(None)),
-            get_obs_shape(SPACE)
+            get_obs_shape(SPACE),
         )

         dim = 8
@@ -75,12 +76,12 @@ def test_state_entropy_reward_returns_normalized_values():
         )


-def test_state_entropy_reward_can_pickle():
+def test_pebble_entropy_reward_can_pickle():
     all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

     obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
-    reward_fn = PebbleStateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
     reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
@@ -94,3 +95,33 @@ def test_state_entropy_reward_can_pickle():
     expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
     actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
     np.testing.assert_allclose(actual_result, expected_result)
+
+
+def test_pebble_entropy_reward_function_switches_to_inner():
+    obs_shape = get_obs_shape(SPACE)
+
+    expected_reward = np.ones(1)
+    reward_fn_mock = Mock()
+    reward_fn_mock.return_value = expected_reward
+    reward_fn = PebbleStateEntropyReward(reward_fn_mock, SPACE)
+
+    # Act
+    reward_fn.on_unsupervised_exploration_finished()
+    observations = np.ones((BATCH_SIZE, *obs_shape))
+    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Assert
+    assert reward == expected_reward
+    reward_fn_mock.assert_called_once_with(
+        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+    )
+
+
+def reward_fn_stub(
+    self,
+    state: np.ndarray,
+    action: np.ndarray,
+    next_state: np.ndarray,
+    done: np.ndarray,
+) -> np.ndarray:
+    return state
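A side note on the pickle test, not part of the diff: reward_fn_stub is a module-level function rather than a Mock, presumably because the wrapped reward function is pickled together with PebbleStateEntropyReward, and the standard pickle module serializes module-level functions by reference, whereas mocks and lambdas generally do not survive pickling. A small illustration of that distinction (hypothetical names, standard library only):

import pickle


def module_level_fn(state):
    # Module-level functions are pickled by qualified name and
    # resolve back to the same object on load.
    return state


assert pickle.loads(pickle.dumps(module_level_fn)) is module_level_fn

try:
    # A lambda has no importable qualified name, so pickling it fails.
    pickle.dumps(lambda state: state)
except (pickle.PicklingError, AttributeError) as err:
    print(f"not picklable: {err}")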
