Commit f3decf1

Author: Jan Michelfeit
Message: #625 fix pre-commit errors
Parent: 2ab0780

File tree

8 files changed: +88 additions, -52 deletions

src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 24 additions & 17 deletions
@@ -1,29 +1,31 @@
+"""Reward function for the PEBBLE training algorithm."""
+
 from enum import Enum, auto
-from typing import Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch as th
 
 from imitation.policies.replay_buffer_wrapper import (
-    ReplayBufferView,
     ReplayBufferRewardWrapper,
+    ReplayBufferView,
 )
 from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
 from imitation.util import util
 from imitation.util.networks import RunningNorm
 
 
 class PebbleRewardPhase(Enum):
-    """States representing different behaviors for PebbleStateEntropyReward"""
+    """States representing different behaviors for PebbleStateEntropyReward."""
 
     UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
     POLICY_AND_REWARD_LEARNING = auto()  # Learned reward
 
 
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
-    """
-    Reward function for implementation of the PEBBLE learning algorithm
-    (https://arxiv.org/pdf/2106.05091.pdf).
+    """Reward function for implementation of the PEBBLE learning algorithm.
+
+    See https://arxiv.org/pdf/2106.05091.pdf .
 
     The rewards returned by this function go through the three phases:
     1. Before enough samples are collected for entropy calculation, the
@@ -38,33 +40,38 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     supplied with set_replay_buffer() or on_replay_buffer_initialized().
     To transition to the last phase, unsupervised_exploration_finish() needs
     to be called.
-
-    Args:
-        learned_reward_fn: The learned reward function used after unsupervised
-            exploration is finished
-        nearest_neighbor_k: Parameter for entropy computation (see
-            compute_state_entropy())
     """
 
-    # TODO #625: parametrize nearest_neighbor_k
     def __init__(
         self,
         learned_reward_fn: RewardFn,
         nearest_neighbor_k: int = 5,
     ):
+        """Builds this class.
+
+        Args:
+            learned_reward_fn: The learned reward function used after unsupervised
+                exploration is finished
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
+        """
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
         self.entropy_stats = RunningNorm(1)
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
         # These two need to be set with set_replay_buffer():
-        self.replay_buffer_view = None
-        self.obs_shape = None
+        self.replay_buffer_view: Optional[ReplayBufferView] = None
+        self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None
 
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
        self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
+    def set_replay_buffer(
+        self,
+        replay_buffer: ReplayBufferView,
+        obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
+    ):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape
 
@@ -87,7 +94,7 @@ def __call__(
     def _entropy_reward(self, state, action, next_state, done):
         if self.replay_buffer_view is None:
             raise ValueError(
-                "Replay buffer must be supplied before entropy reward can be used"
+                "Replay buffer must be supplied before entropy reward can be used",
             )
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
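
Taken together with the test changes further down, the intended call pattern looks roughly like this. A minimal sketch, assuming plain numpy inputs and illustrative shapes; the learned-reward stub and buffer contents are placeholders, not code from this commit:

import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView


def learned_reward_fn(state, action, next_state, done):
    # Hypothetical stand-in for the learned reward used after pre-training.
    return np.zeros(len(state), dtype=np.float32)


obs_shape = (1,)
# Placeholder buffer contents: (buffer size, n_envs, *obs_shape).
all_observations = np.random.rand(20, 3, *obs_shape)

reward_fn = PebbleStateEntropyReward(learned_reward_fn, nearest_neighbor_k=5)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)

batch = np.random.rand(10, *obs_shape)
placeholder = np.empty(obs_shape)

# Unsupervised exploration phase: normalized state-entropy reward.
entropy_reward = reward_fn(batch, placeholder, placeholder, placeholder)

# After pre-training, switch to the learned reward function.
reward_fn.unsupervised_exploration_finish()
learned_reward = reward_fn(batch, placeholder, placeholder, placeholder)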

src/imitation/algorithms/preference_comparisons.py

Lines changed: 18 additions & 8 deletions
@@ -77,8 +77,7 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
         """  # noqa: DAR202
 
     def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
-        """Pre-train an agent if the trajectory generator uses one that
-        needs pre-training.
+        """Pre-train an agent before collecting comparisons.
 
         By default, this method does nothing and doesn't need
         to be overridden in subclasses that don't require pre-training.
@@ -331,8 +330,8 @@ def logger(self, value: imit_logger.HierarchicalLogger) -> None:
 
 
 class PebbleAgentTrainer(AgentTrainer):
-    """
-    Specialization of AgentTrainer for PEBBLE training.
+    """Specialization of AgentTrainer for PEBBLE training.
+
     Includes unsupervised pretraining with an entropy based reward function.
     """
 
@@ -344,9 +343,20 @@ def __init__(
         reward_fn: PebbleStateEntropyReward,
         **kwargs,
     ) -> None:
+        """Builds PebbleAgentTrainer.
+
+        Args:
+            reward_fn: Pebble reward function
+            **kwargs: additional keyword arguments to pass on to
+                the parent class
+
+        Raises:
+            ValueError: Unexpected type of reward_fn given.
+        """
         if not isinstance(reward_fn, PebbleStateEntropyReward):
             raise ValueError(
-                f"{self.__class__.__name__} expects {PebbleStateEntropyReward.__name__} reward function"
+                f"{self.__class__.__name__} expects "
+                f"{PebbleStateEntropyReward.__name__} reward function",
             )
         super().__init__(reward_fn=reward_fn, **kwargs)
 
@@ -1729,10 +1739,10 @@ def train(
         ###################################################
         with self.logger.accumulate_means("agent"):
             self.logger.log(
-                f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
+                f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps",
             )
             self.trajectory_generator.unsupervised_pretrain(
-                unsupervised_pretrain_timesteps
+                unsupervised_pretrain_timesteps,
             )
 
         for i, num_pairs in enumerate(preference_query_schedule):
@@ -1811,7 +1821,7 @@ def _preference_gather_schedule(self, total_comparisons):
 
     def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
         unsupervised_pretrain_timesteps = int(
-            total_timesteps * self.unsupervised_agent_pretrain_frac
+            total_timesteps * self.unsupervised_agent_pretrain_frac,
         )
         timesteps_per_iteration, extra_timesteps = divmod(
             total_timesteps - unsupervised_pretrain_timesteps,
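
The trailing commas aside, the timestep split controlled by unsupervised_agent_pretrain_frac is just this arithmetic; a standalone illustration with made-up numbers (only the int()/divmod() pattern mirrors _compute_timesteps):

total_timesteps = 10_000
unsupervised_agent_pretrain_frac = 0.2
num_iterations = 4  # hypothetical number of preference-gathering iterations

unsupervised_pretrain_timesteps = int(
    total_timesteps * unsupervised_agent_pretrain_frac,
)
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - unsupervised_pretrain_timesteps,
    num_iterations,
)
# -> 2000 timesteps of entropy-driven pre-training, then 4 x 2000 timesteps of
#    agent training interleaved with preference queries, with 0 left over.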

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 9 additions & 11 deletions
@@ -7,7 +7,7 @@
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.type_aliases import ReplayBufferSamples
 
-from imitation.rewards.reward_function import RewardFn, ReplayBufferAwareRewardFn
+from imitation.rewards.reward_function import ReplayBufferAwareRewardFn, RewardFn
 from imitation.util import util
 
 
@@ -24,19 +24,20 @@ def _samples_to_reward_fn_input(
 
 
 class ReplayBufferView:
-    """A read-only view over a valid records in a ReplayBuffer.
-
-    Args:
-        observations_buffer: Array buffer holding observations
-        buffer_slice_provider: Function returning slice of buffer
-            with valid observations
-    """
+    """A read-only view over a valid records in a ReplayBuffer."""
 
     def __init__(
         self,
         observations_buffer: np.ndarray,
         buffer_slice_provider: Callable[[], slice],
     ):
+        """Builds ReplayBufferView.
+
+        Args:
+            observations_buffer: Array buffer holding observations
+            buffer_slice_provider: Function returning slice of buffer
+                with valid observations
+        """
         self._observations_buffer_view = observations_buffer.view()
         self._observations_buffer_view.flags.writeable = False
         self._buffer_slice_provider = buffer_slice_provider
@@ -67,9 +68,6 @@ def __init__(
             action_space: Action space
             replay_buffer_class: Class of the replay buffer.
             reward_fn: Reward function for reward relabeling.
-            on_initialized_callback: Callback called with reference to this object after
-                this instance is fully initialized. This provides a hook to access the
-                buffer after it is created from inside a Stable Baselines algorithm.
             **kwargs: keyword arguments for ReplayBuffer.
         """
         # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
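
For orientation, a minimal sketch of how such a view is constructed and consumed; the backing array and the notion of "valid" records are placeholders, and PebbleStateEntropyReward reads the same .observations attribute for its entropy estimate:

import numpy as np

from imitation.policies.replay_buffer_wrapper import ReplayBufferView

# Hypothetical backing storage: capacity 20, 3 parallel envs, obs dim 1.
buffer = np.zeros((20, 3, 1))
valid_until = 5  # pretend only the first 5 slots have been filled so far

view = ReplayBufferView(buffer, lambda: slice(0, valid_until))

# Read access only; writes still have to go through the replay buffer itself.
all_observations = view.observations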

src/imitation/rewards/reward_function.py

Lines changed: 6 additions & 1 deletion
@@ -35,6 +35,11 @@ def __call__(
 
 
 class ReplayBufferAwareRewardFn(RewardFn, abc.ABC):
+    """Abstract class for a reward function that needs access to a replay buffer."""
+
     @abc.abstractmethod
-    def on_replay_buffer_initialized(self, replay_buffer: "ReplayBufferRewardWrapper"):
+    def on_replay_buffer_initialized(
+        self,
+        replay_buffer: "ReplayBufferRewardWrapper",  # type: ignore[name-defined]
+    ):
         pass
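
As a toy illustration of the contract (the class below is hypothetical, not part of the library): a subclass behaves like any RewardFn but additionally receives the wrapped replay buffer once Stable Baselines 3 has created it.

import numpy as np

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.rewards.reward_function import ReplayBufferAwareRewardFn


class ZeroRewardWithBufferAccess(ReplayBufferAwareRewardFn):
    """Toy reward function that just remembers the buffer view."""

    def __init__(self):
        self.buffer_view = None

    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
        # Hook called by ReplayBufferRewardWrapper after the buffer exists.
        self.buffer_view = replay_buffer.buffer_view

    def __call__(self, state, action, next_state, done) -> np.ndarray:
        return np.zeros(len(state), dtype=np.float32)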

src/imitation/scripts/common/rl.py

Lines changed: 2 additions & 1 deletion
@@ -89,7 +89,8 @@ def _maybe_add_relabel_buffer(
     _buffer_kwargs = dict(
         reward_fn=relabel_reward_fn,
         replay_buffer_class=rl_kwargs.get(
-            "replay_buffer_class", buffers.ReplayBuffer
+            "replay_buffer_class",
+            buffers.ReplayBuffer,
         ),
     )
     rl_kwargs["replay_buffer_class"] = ReplayBufferRewardWrapper
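
The net effect of _maybe_add_relabel_buffer is to point the RL algorithm at the relabeling wrapper. Roughly, and only as a hedged sketch: SAC, the environment id, the stub reward function, and the exact kwarg routing below are assumptions, not shown in this diff.

import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common import buffers

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper


def relabel_reward_fn(state, action, next_state, done):
    # Hypothetical stand-in for a learned reward function.
    return np.zeros(len(state), dtype=np.float32)


model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        reward_fn=relabel_reward_fn,
        replay_buffer_class=buffers.ReplayBuffer,  # the buffer that actually stores data
    ),
)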

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 2 additions & 0 deletions
@@ -60,8 +60,10 @@ def train_defaults():
 
     checkpoint_interval = 0  # Num epochs between saving (<0 disables, =0 final only)
     query_schedule = "hyperbolic"
+
     # Whether to use the PEBBLE algorithm (https://arxiv.org/pdf/2106.05091.pdf)
     pebble_enabled = False
+    unsupervised_agent_pretrain_frac = 0.0
 
 
 @train_preference_comparisons_ex.named_config
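
A ready-made PEBBLE switch could be expressed as a named config in the same style; the config name, the 0.05 fraction, and the import path below are illustrative assumptions rather than part of this commit:

from imitation.scripts.config.train_preference_comparisons import (
    train_preference_comparisons_ex,
)


@train_preference_comparisons_ex.named_config
def pebble():
    """Hypothetical named config enabling PEBBLE-style pre-training."""
    pebble_enabled = True
    unsupervised_agent_pretrain_frac = 0.05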

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 7 additions & 5 deletions
@@ -10,13 +10,13 @@
 import numpy as np
 import torch as th
 from sacred.observers import FileStorageObserver
-from stable_baselines3.common import type_aliases, base_class, vec_env
+from stable_baselines3.common import base_class, type_aliases, vec_env
 
 from imitation.algorithms import preference_comparisons
 from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.data import types
 from imitation.policies import serialize
-from imitation.rewards import reward_nets, reward_function
+from imitation.rewards import reward_function, reward_nets
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
 from imitation.scripts.common import train
@@ -65,15 +65,16 @@ def make_reward_function(
     reward_net: reward_nets.RewardNet,
     *,
     pebble_enabled: bool = False,
-    pebble_nearest_neighbor_k: Optional[int] = None,
+    pebble_nearest_neighbor_k: int = 5,
 ):
     relabel_reward_fn = functools.partial(
         reward_net.predict_processed,
         update_stats=False,
     )
     if pebble_enabled:
         relabel_reward_fn = PebbleStateEntropyReward(
-            relabel_reward_fn, pebble_nearest_neighbor_k
+            relabel_reward_fn,  # type: ignore[assignment]
+            pebble_nearest_neighbor_k,
         )
     return relabel_reward_fn
 
@@ -92,6 +93,7 @@ def make_agent_trajectory_generator(
     trajectory_generator_kwargs: Mapping[str, Any],
 ) -> preference_comparisons.AgentTrainer:
     if pebble_enabled:
+        assert isinstance(relabel_reward_fn, PebbleStateEntropyReward)
         return preference_comparisons.PebbleAgentTrainer(
             algorithm=agent,
             reward_fn=relabel_reward_fn,
@@ -138,7 +140,7 @@ def train_preference_comparisons(
     allow_variable_horizon: bool,
     checkpoint_interval: int,
     query_schedule: Union[str, type_aliases.Schedule],
-    unsupervised_agent_pretrain_frac: Optional[float],
+    unsupervised_agent_pretrain_frac: float,
 ) -> Mapping[str, Any]:
     """Train a reward model using preference comparisons.
 
tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 20 additions & 9 deletions
@@ -1,17 +1,18 @@
+"""Tests for `imitation.algorithms.entropy_reward`."""
+
 import pickle
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch
 
 import numpy as np
 import torch as th
 from gym.spaces import Discrete
-from stable_baselines3.common.preprocessing import get_obs_shape
 
 from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util
 
 SPACE = Discrete(4)
-OBS_SHAPE = get_obs_shape(SPACE)
+OBS_SHAPE = (1,)
 PLACEHOLDER = np.empty(OBS_SHAPE)
 
 BUFFER_SIZE = 20
@@ -25,7 +26,8 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
 
     reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.set_replay_buffer(
-        ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
+        ReplayBufferView(all_observations, lambda: slice(None)),
+        OBS_SHAPE,
     )
 
     # Act
@@ -34,17 +36,20 @@ def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
 
     # Assert
     expected = util.compute_state_entropy(
-        observations, all_observations.reshape(-1, *OBS_SHAPE), K
+        observations,
+        all_observations.reshape(-1, *OBS_SHAPE),
+        K,
     )
     expected_normalized = reward_fn.entropy_stats.normalize(
-        th.as_tensor(expected)
+        th.as_tensor(expected),
     ).numpy()
     np.testing.assert_allclose(reward, expected_normalized)
 
 
 def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
     with patch("imitation.util.util.compute_state_entropy") as m:
-        # mock entropy computation so that we can test only stats collection in this test
+        # mock entropy computation so that we can test
+        # only stats collection in this test
         m.side_effect = lambda obs, all_obs, k: obs
 
         reward_fn = PebbleStateEntropyReward(Mock(), K)
@@ -64,7 +69,10 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
         reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
 
         normalized_reward = reward_fn(
-            np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+            np.zeros(dim),
+            PLACEHOLDER,
+            PLACEHOLDER,
+            PLACEHOLDER,
         )
 
         # Assert
@@ -91,7 +99,10 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
     # Assert
     assert reward == expected_reward
     learned_reward_mock.assert_called_once_with(
-        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        observations,
+        PLACEHOLDER,
+        PLACEHOLDER,
+        PLACEHOLDER,
     )
 
 