
Commit a0bacca

Jan Michelfeit committed
#641 fix static analysis and tests
1 parent 531b353 commit a0bacca

File tree

6 files changed: +69 -22 lines


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 26 additions & 6 deletions

@@ -1,7 +1,7 @@
"""Reward function for the PEBBLE training algorithm."""

import enum
-from typing import Optional, Tuple
+from typing import Any, Callable, Optional, Tuple

import gym
import numpy as np
@@ -18,10 +18,16 @@


class InsufficientObservations(RuntimeError):
+    """Error signifying not enough observations for entropy calculation."""
+
    pass


class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
+    """RewardNet wrapping entropy reward function."""
+
+    __call__: Callable[..., Any]  # Needed to appease pytype
+
    def __init__(
        self,
        nearest_neighbor_k: int,
@@ -53,6 +59,9 @@ def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper)

        This method needs to be called, e.g., after unpickling.
        See also __getstate__() / __setstate__().
+
+        Args:
+            replay_buffer: replay buffer with history of observations
        """
        assert self.observation_space == replay_buffer.observation_space
        assert self.action_space == replay_buffer.action_space
@@ -72,16 +81,18 @@ def forward(
        all_observations = self._replay_buffer_view.observations
        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
        all_observations = all_observations.reshape(
-            (-1,) + self.observation_space.shape
+            (-1,) + self.observation_space.shape,
        )

        if all_observations.shape[0] < self.nearest_neighbor_k:
            raise InsufficientObservations(
-                "Insufficient observations for entropy calculation"
+                "Insufficient observations for entropy calculation",
            )

        return util.compute_state_entropy(
-            state, all_observations, self.nearest_neighbor_k
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
        )

    def preprocess(
@@ -95,6 +106,15 @@ def preprocess(

        We also know forward() only works with state, so no need to convert
        other tensors.
+
+        Args:
+            state: The observation input.
+            action: The action input.
+            next_state: The observation input.
+            done: Whether the episode has terminated.
+
+        Returns:
+            Observations preprocessed by converting them to Tensor.
        """
        state_th = util.safe_to_tensor(state).to(self.device)
        action_th = next_state_th = done_th = th.empty(0)
@@ -172,8 +192,8 @@ def __call__(
            try:
                return self.entropy_reward_fn(state, action, next_state, done)
            except InsufficientObservations:
-                # not enough observations to compare to, fall back to the learned function;
-                # (falling back to a constant may also be ok)
+                # not enough observations to compare to, fall back to the learned
+                # function; (falling back to a constant may also be ok)
                return self.learned_reward_fn(state, action, next_state, done)
        else:
            return self.learned_reward_fn(state, action, next_state, done)
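
Context for the last hunk's rewrapped comment: during PEBBLE pre-training the entropy reward needs at least nearest_neighbor_k observations in the replay buffer, and until then the wrapper falls back to the learned reward. A minimal, self-contained sketch of that pattern, using hypothetical helper names rather than the imitation API:

import numpy as np


class InsufficientObservations(RuntimeError):
    """Raised when too few observations are available for entropy estimation."""


def entropy_reward(obs: np.ndarray, buffer: np.ndarray, k: int) -> np.ndarray:
    # Hypothetical k-NN state-entropy stand-in; the real code calls
    # util.compute_state_entropy on the replay buffer contents.
    if buffer.shape[0] < k:
        raise InsufficientObservations("Insufficient observations for entropy calculation")
    dists = np.linalg.norm(buffer[None, :, :] - obs[:, None, :], axis=-1)  # (batch, buffer)
    return np.log(np.sort(dists, axis=1)[:, k - 1] + 1e-8)


def learned_reward(obs: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the learned reward function.
    return np.zeros(obs.shape[0])


def pretraining_reward(obs: np.ndarray, buffer: np.ndarray, k: int = 5) -> np.ndarray:
    try:
        return entropy_reward(obs, buffer, k)
    except InsufficientObservations:
        # Not enough observations to compare to; fall back to the learned reward.
        return learned_reward(obs)


# Example: a 2-sample buffer cannot support k=5, so the fallback path is taken.
print(pretraining_reward(np.ones((3, 4)), np.zeros((2, 4))))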

src/imitation/algorithms/preference_comparisons.py

Lines changed: 5 additions & 1 deletion

@@ -96,13 +96,17 @@ def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        """Pre-train an agent before collecting comparisons.

        Override this behavior in subclasses that implement pre-training.
-        If not overriden, this method raises ValueError when non-zero steps are
+        If not overridden, this method raises ValueError when non-zero steps are
        allocated for pre-training.

        Args:
            steps: number of environment steps to train for.
            **kwargs: additional keyword arguments to pass on to
                the training procedure.
+
+        Raises:
+            ValueError: Unsupervised pre-training not implemented but non-zero
+                steps are allocated for pre-training.
        """
        if steps > 0:
            raise ValueError(
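
For reference, the newly documented Raises: section describes a contract: the base unsupervised_pretrain is a stub that rejects non-zero pre-training steps, and trainers that support pre-training override it. A rough sketch of that contract with hypothetical class names (not the imitation classes):

from typing import Any


class BaseTrainer:
    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        """Default stub: no pre-training implemented."""
        if steps > 0:
            raise ValueError("unsupervised pre-training requested but not implemented")


class ExplorationPretrainingTrainer(BaseTrainer):
    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        """Override that actually consumes the allocated steps."""
        for _ in range(steps):
            pass  # e.g. collect exploration transitions and update the entropy reward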

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 8 additions & 6 deletions

@@ -80,20 +80,21 @@ def make_reward_function(
        reward_net.predict_processed,
        update_stats=False,
    )
-    observation_space = reward_net.observation_space
-    action_space = reward_net.action_space
    if pebble_enabled:
        relabel_reward_fn = create_pebble_reward_fn(
            relabel_reward_fn,
            pebble_nearest_neighbor_k,
-            action_space,
-            observation_space,
+            reward_net.action_space,
+            reward_net.observation_space,
        )
    return relabel_reward_fn


def create_pebble_reward_fn(
-    relabel_reward_fn, pebble_nearest_neighbor_k, action_space, observation_space
+    relabel_reward_fn,
+    pebble_nearest_neighbor_k,
+    action_space,
+    observation_space,
):
    entropy_reward_net = EntropyRewardNet(
        nearest_neighbor_k=pebble_nearest_neighbor_k,
@@ -111,7 +112,8 @@ def __call__(self, *args, **kwargs) -> np.ndarray:
            return normalized_entropy_reward_net.predict_processed(*args, **kwargs)

        def on_replay_buffer_initialized(
-            self, replay_buffer: ReplayBufferRewardWrapper
+            self,
+            replay_buffer: ReplayBufferRewardWrapper,
        ):
            entropy_reward_net.on_replay_buffer_initialized(replay_buffer)

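
The reflowed signature makes the positional order explicit: relabel function, nearest-neighbor k, action space, observation space. A rough usage sketch (not executed here; the import path matches the one added to the tests below, and the spaces and reward callable are illustrative assumptions):

import gym
import numpy as np

from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn

# Illustrative spaces; in the training script they come from the reward net.
observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)


def learned_reward_fn(state, action, next_state, done):
    # Hypothetical learned reward with the (state, action, next_state, done) signature.
    return np.zeros(state.shape[0])


reward_fn = create_pebble_reward_fn(
    learned_reward_fn,
    5,  # pebble_nearest_neighbor_k
    action_space,
    observation_space,
)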

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 20 additions & 5 deletions

@@ -40,7 +40,10 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining():

    np.testing.assert_allclose(reward, expected_result)
    entropy_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
    )


@@ -57,7 +60,10 @@ def test_pebble_entropy_reward_returns_learned_rew_on_insufficient_observations(

    np.testing.assert_allclose(reward, expected_result)
    learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
    )


@@ -74,7 +80,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin

    np.testing.assert_allclose(reward, expected_result)
    learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
    )


@@ -97,7 +106,10 @@ def test_entropy_reward_net_returns_entropy_for_pretraining(rng):

    # Act
    reward = reward_net.predict_processed(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
    )

    # Assert
@@ -118,7 +130,10 @@ def test_entropy_reward_net_raises_on_insufficient_observations(rng):
    # Act
    with pytest.raises(InsufficientObservations):
        reward_net.predict_processed(
-            observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+            observations,
+            PLACEHOLDER,
+            PLACEHOLDER,
+            PLACEHOLDER,
        )

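
The last hunk exercises the new InsufficientObservations path. A condensed, self-contained version of that test shape, with hypothetical stand-ins for the fixtures and the reward net:

import numpy as np
import pytest


class InsufficientObservations(RuntimeError):
    """Stand-in for the exception defined in entropy_reward.py."""


def predict_processed(observations, buffer_size, k):
    # Simplified stand-in: the real net inspects its replay buffer view.
    if buffer_size < k:
        raise InsufficientObservations("Insufficient observations for entropy calculation")
    return np.zeros(observations.shape[0])


def test_raises_on_insufficient_observations():
    observations = np.ones((8, 4))
    with pytest.raises(InsufficientObservations):
        predict_processed(observations, buffer_size=2, k=5)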

tests/algorithms/test_preference_comparisons.py

Lines changed: 8 additions & 4 deletions

@@ -18,12 +18,12 @@

import imitation.testing.reward_nets as testing_reward_nets
from imitation.algorithms import preference_comparisons
-from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.data import types
from imitation.data.types import TrajectoryWithRew
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.regularization import regularizers, updaters
from imitation.rewards import reward_nets
+from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
from imitation.util import networks, util

UNCERTAINTY_ON = ["logit", "probability", "label"]
@@ -84,9 +84,13 @@ def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
    replay_buffer_mock = Mock()
    replay_buffer_mock.buffer_view = replay_buffer
-    replay_buffer_mock.obs_shape = (4,)
-    reward_fn = PebbleStateEntropyReward(
-        reward_net.predict_processed, venv.observation_space, venv.action_space
+    replay_buffer_mock.observation_space = venv.observation_space
+    replay_buffer_mock.action_space = venv.action_space
+    reward_fn = create_pebble_reward_fn(
+        reward_net.predict_processed,
+        5,
+        venv.action_space,
+        venv.observation_space,
    )
    reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
    return preference_comparisons.PebbleAgentTrainer(
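
The fixture now mirrors the asserts added to EntropyRewardNet.on_replay_buffer_initialized: the mocked replay buffer must expose matching observation_space and action_space attributes instead of just obs_shape. A condensed sketch of that wiring (spaces constructed with gym purely for illustration):

from unittest.mock import Mock

import gym
import numpy as np

# Illustrative spaces; in the test they come from the vectorized environment.
observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)

replay_buffer_mock = Mock()
replay_buffer_mock.buffer_view = Mock()
# on_replay_buffer_initialized() asserts these match the reward net's spaces,
# so the mock must carry them explicitly.
replay_buffer_mock.observation_space = observation_space
replay_buffer_mock.action_space = action_space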

tests/scripts/test_train_preference_comparisons.py

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+"""Tests train_preferences_comparisons helper methods."""
+
from unittest.mock import Mock, patch

import numpy as np
