Commit 74ba96b

Author: Jan Michelfeit (committed)
#641 fix static analysis and tests
1 parent 531b353

6 files changed (+74, -25 lines)


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 26 additions & 6 deletions
@@ -1,7 +1,7 @@
 """Reward function for the PEBBLE training algorithm."""

 import enum
-from typing import Optional, Tuple
+from typing import Any, Callable, Optional, Tuple

 import gym
 import numpy as np
@@ -18,10 +18,16 @@


 class InsufficientObservations(RuntimeError):
+    """Error signifying not enough observations for entropy calculation."""
+
     pass


 class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
+    """RewardNet wrapping entropy reward function."""
+
+    __call__: Callable[..., Any]  # Needed to appease pytype
+
     def __init__(
         self,
         nearest_neighbor_k: int,
@@ -53,6 +59,9 @@ def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper)

         This method needs to be called, e.g., after unpickling.
         See also __getstate__() / __setstate__().
+
+        Args:
+            replay_buffer: replay buffer with history of observations
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
@@ -72,16 +81,18 @@ def forward(
         all_observations = self._replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
         all_observations = all_observations.reshape(
-            (-1,) + self.observation_space.shape
+            (-1,) + self.observation_space.shape,
         )

         if all_observations.shape[0] < self.nearest_neighbor_k:
             raise InsufficientObservations(
-                "Insufficient observations for entropy calculation"
+                "Insufficient observations for entropy calculation",
             )

         return util.compute_state_entropy(
-            state, all_observations, self.nearest_neighbor_k
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
         )

     def preprocess(
@@ -95,6 +106,15 @@ def preprocess(

         We also know forward() only works with state, so no need to convert
         other tensors.
+
+        Args:
+            state: The observation input.
+            action: The action input.
+            next_state: The observation input.
+            done: Whether the episode has terminated.
+
+        Returns:
+            Observations preprocessed by converting them to Tensor.
         """
         state_th = util.safe_to_tensor(state).to(self.device)
         action_th = next_state_th = done_th = th.empty(0)
@@ -172,8 +192,8 @@ def __call__(
             try:
                 return self.entropy_reward_fn(state, action, next_state, done)
             except InsufficientObservations:
-                # not enough observations to compare to, fall back to the learned function;
-                # (falling back to a constant may also be ok)
+                # not enough observations to compare to, fall back to the learned
+                # function; (falling back to a constant may also be ok)
                 return self.learned_reward_fn(state, action, next_state, done)
         else:
             return self.learned_reward_fn(state, action, next_state, done)
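
The last hunk above only rewraps a comment, but it marks the core fallback in PebbleStateEntropyReward.__call__: during pre-training the entropy reward is tried first, and the learned reward is used when there are too few observations. A minimal standalone sketch of that pattern (the helper name pebble_reward_with_fallback is illustrative; only the InsufficientObservations exception and the two reward callables come from the diff):

    from imitation.algorithms.pebble.entropy_reward import InsufficientObservations

    def pebble_reward_with_fallback(entropy_reward_fn, learned_reward_fn,
                                    state, action, next_state, done):
        """Sketch of the pre-training branch of PebbleStateEntropyReward.__call__."""
        try:
            # Prefer the entropy-based exploration reward during pre-training.
            return entropy_reward_fn(state, action, next_state, done)
        except InsufficientObservations:
            # Too few observations in the replay buffer to estimate entropy,
            # so fall back to the learned reward (a constant would also be ok).
            return learned_reward_fn(state, action, next_state, done)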

src/imitation/algorithms/preference_comparisons.py

Lines changed: 5 additions & 1 deletion
@@ -96,13 +96,17 @@ def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
         """Pre-train an agent before collecting comparisons.

         Override this behavior in subclasses that implement pre-training.
-        If not overriden, this method raises ValueError when non-zero steps are
+        If not overridden, this method raises ValueError when non-zero steps are
         allocated for pre-training.

         Args:
             steps: number of environment steps to train for.
             **kwargs: additional keyword arguments to pass on to
                 the training procedure.
+
+        Raises:
+            ValueError: Unsupervised pre-training not implemented but non-zero
+                steps are allocated for pre-training.
         """
         if steps > 0:
             raise ValueError(
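
The added Raises section documents behavior already visible at the end of the hunk: the base implementation simply rejects any request for pre-training steps. A minimal sketch of that default, assuming the enclosing trainer class and with the error message paraphrased (the hunk cuts off before the exact text):

    def unsupervised_pretrain(self, steps: int, **kwargs) -> None:
        # Default implementation: no pre-training is available, so asking for
        # a non-zero number of pre-training steps is a configuration error.
        # (Error message paraphrased; see the source for the exact wording.)
        if steps > 0:
            raise ValueError("Unsupervised pre-training not implemented.")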

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 13 additions & 9 deletions
@@ -7,6 +7,7 @@
 import pathlib
 from typing import Any, Mapping, Optional, Type, Union

+import gym
 import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
@@ -24,6 +25,7 @@
     ReplayBufferRewardWrapper,
 )
 from imitation.rewards import reward_function, reward_nets
+from imitation.rewards.reward_function import RewardFn
 from imitation.rewards.reward_nets import NormalizedRewardNet
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
@@ -80,21 +82,22 @@ def make_reward_function(
         reward_net.predict_processed,
         update_stats=False,
     )
-    observation_space = reward_net.observation_space
-    action_space = reward_net.action_space
     if pebble_enabled:
         relabel_reward_fn = create_pebble_reward_fn(
-            relabel_reward_fn,
+            relabel_reward_fn,  # type: ignore[assignment]
             pebble_nearest_neighbor_k,
-            action_space,
-            observation_space,
+            reward_net.action_space,
+            reward_net.observation_space,
         )
     return relabel_reward_fn


 def create_pebble_reward_fn(
-    relabel_reward_fn, pebble_nearest_neighbor_k, action_space, observation_space
-):
+    relabel_reward_fn: RewardFn,
+    pebble_nearest_neighbor_k: int,
+    action_space: gym.Space,
+    observation_space: gym.Space,
+) -> PebbleStateEntropyReward:
     entropy_reward_net = EntropyRewardNet(
         nearest_neighbor_k=pebble_nearest_neighbor_k,
         observation_space=observation_space,
@@ -111,13 +114,14 @@ def __call__(self, *args, **kwargs) -> np.ndarray:
             return normalized_entropy_reward_net.predict_processed(*args, **kwargs)

         def on_replay_buffer_initialized(
-            self, replay_buffer: ReplayBufferRewardWrapper
+            self,
+            replay_buffer: ReplayBufferRewardWrapper,
         ):
             entropy_reward_net.on_replay_buffer_initialized(replay_buffer)

     return PebbleStateEntropyReward(
         EntropyRewardFn(),
-        relabel_reward_fn,  # type: ignore[assignment]
+        relabel_reward_fn,
     )

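
With the annotations above, the expected call shape of create_pebble_reward_fn is explicit. Below is a usage sketch based on this signature and on the pebble_agent_trainer fixture later in this commit; build_pebble_reward is an illustrative wrapper, and reward_net, venv, and replay_buffer_wrapper stand in for objects the training script constructs elsewhere:

    from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn

    def build_pebble_reward(reward_net, venv, replay_buffer_wrapper):
        # Illustrative helper only; argument names follow the typed signature.
        pebble_reward = create_pebble_reward_fn(
            reward_net.predict_processed,  # relabel_reward_fn
            5,                             # pebble_nearest_neighbor_k
            venv.action_space,
            venv.observation_space,
        )
        # The returned PebbleStateEntropyReward still needs the replay buffer
        # before it can compute entropies (see on_replay_buffer_initialized).
        pebble_reward.on_replay_buffer_initialized(replay_buffer_wrapper)
        return pebble_reward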

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 20 additions & 5 deletions
@@ -40,7 +40,10 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining():

     np.testing.assert_allclose(reward, expected_result)
     entropy_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )


@@ -57,7 +60,10 @@ def test_pebble_entropy_reward_returns_learned_rew_on_insufficient_observations(

     np.testing.assert_allclose(reward, expected_result)
     learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )


@@ -74,7 +80,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin

     np.testing.assert_allclose(reward, expected_result)
     learned_fn.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )


@@ -97,7 +106,10 @@ def test_entropy_reward_net_returns_entropy_for_pretraining(rng):

     # Act
     reward = reward_net.predict_processed(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )

     # Assert
@@ -118,7 +130,10 @@ def test_entropy_reward_net_raises_on_insufficient_observations(rng):
     # Act
     with pytest.raises(InsufficientObservations):
         reward_net.predict_processed(
-            observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+            observations,
+            PLACEHOLDER,
+            PLACEHOLDER,
+            PLACEHOLDER,
         )


tests/algorithms/test_preference_comparisons.py

Lines changed: 8 additions & 4 deletions
@@ -18,12 +18,12 @@

 import imitation.testing.reward_nets as testing_reward_nets
 from imitation.algorithms import preference_comparisons
-from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.data.types import TrajectoryWithRew
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.regularization import regularizers, updaters
 from imitation.rewards import reward_nets
+from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn
 from imitation.util import networks, util

 UNCERTAINTY_ON = ["logit", "probability", "label"]
@@ -84,9 +84,13 @@ def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
 def pebble_agent_trainer(agent, reward_net, venv, rng, replay_buffer):
     replay_buffer_mock = Mock()
     replay_buffer_mock.buffer_view = replay_buffer
-    replay_buffer_mock.obs_shape = (4,)
-    reward_fn = PebbleStateEntropyReward(
-        reward_net.predict_processed, venv.observation_space, venv.action_space
+    replay_buffer_mock.observation_space = venv.observation_space
+    replay_buffer_mock.action_space = venv.action_space
+    reward_fn = create_pebble_reward_fn(
+        reward_net.predict_processed,
+        5,
+        venv.action_space,
+        venv.observation_space,
     )
     reward_fn.on_replay_buffer_initialized(replay_buffer_mock)
     return preference_comparisons.PebbleAgentTrainer(

tests/scripts/test_train_preference_comparisons.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Tests train_preferences_comparisons helper methods."""
+
 from unittest.mock import Mock, patch

 import numpy as np
