Skip to content

Commit ad29c34

Browse files
author
Jan Michelfeit
committed
#625 add initialized callback to ReplayBufferRewardWrapper
1 parent c681ca3 commit ad29c34

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Wrapper for reward labeling for transitions sampled from a replay buffer."""
22

3+
from typing import Callable
34
from typing import Mapping, Type
45

56
import numpy as np
@@ -10,7 +11,6 @@
1011
from imitation.rewards.reward_function import RewardFn
1112
from imitation.util import util
1213
from imitation.util.networks import RunningNorm
13-
from typing import Callable
1414

1515

1616
def _samples_to_reward_fn_input(
@@ -59,6 +59,7 @@ def __init__(
5959
*,
6060
replay_buffer_class: Type[ReplayBuffer],
6161
reward_fn: RewardFn,
62+
on_initialized_callback: Callable[["ReplayBufferRewardWrapper"], None] = None,
6263
**kwargs,
6364
):
6465
"""Builds ReplayBufferRewardWrapper.
@@ -69,6 +70,9 @@ def __init__(
6970
action_space: Action space
7071
replay_buffer_class: Class of the replay buffer.
7172
reward_fn: Reward function for reward relabeling.
73+
on_initialized_callback: Callback called with reference to this object after
74+
this instance is fully initialized. This provides a hook to access the
75+
buffer after it is created from inside a Stable Baselines algorithm.
7276
**kwargs: keyword arguments for ReplayBuffer.
7377
"""
7478
# Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
@@ -86,6 +90,8 @@ def __init__(
8690
self.reward_fn = reward_fn
8791
_base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
8892
super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
93+
if on_initialized_callback is not None:
94+
on_initialized_callback(self)
8995

9096
# TODO(juan) remove the type ignore once the merged PR
9197
# https://github.com/python/mypy/pull/13475

tests/policies/test_replay_buffer_wrapper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,18 @@ def test_replay_buffer_view_provides_buffered_observations():
264264
# ReplayBuffer internally uses a circular buffer
265265
expected = np.roll(observations, 1, axis=0)
266266
np.testing.assert_allclose(view.observations, expected)
267+
268+
269+
def test_replay_buffer_reward_wrapper_calls_initialization_callback_with_itself():
270+
callback = Mock()
271+
buffer = ReplayBufferRewardWrapper(
272+
10,
273+
spaces.Discrete(2),
274+
spaces.Discrete(2),
275+
replay_buffer_class=ReplayBuffer,
276+
reward_fn=Mock(),
277+
n_envs=2,
278+
handle_timeout_termination=False,
279+
on_initialized_callback=callback,
280+
)
281+
assert callback.call_args.args[0] is buffer

0 commit comments

Comments
 (0)