"""Wrapper for reward labeling for transitions sampled from a replay buffer."""

+from typing import Callable
from typing import Mapping, Type

import numpy as np

from imitation.rewards.reward_function import RewardFn
from imitation.util import util
from imitation.util.networks import RunningNorm
-from typing import Callable


def _samples_to_reward_fn_input(
@@ -59,6 +59,7 @@ def __init__(
        *,
        replay_buffer_class: Type[ReplayBuffer],
        reward_fn: RewardFn,
+        on_initialized_callback: Callable[["ReplayBufferRewardWrapper"], None] = None,
        **kwargs,
    ):
        """Builds ReplayBufferRewardWrapper.
@@ -69,6 +70,9 @@ def __init__(
            action_space: Action space
            replay_buffer_class: Class of the replay buffer.
            reward_fn: Reward function for reward relabeling.
+            on_initialized_callback: Callback called with a reference to this object
+                once it is fully initialized. This provides a hook to access the buffer
+                from inside a Stable Baselines algorithm after the buffer is created.
            **kwargs: keyword arguments for ReplayBuffer.
        """
        # Note(yawen-d): we directly inherit ReplayBuffer and leave out the case of
@@ -86,6 +90,8 @@ def __init__(
        self.reward_fn = reward_fn
        _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
        super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)
+        if on_initialized_callback is not None:
+            on_initialized_callback(self)

        # TODO(juan) remove the type ignore once the merged PR
        # https://github.com/python/mypy/pull/13475
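For context, a minimal sketch of how the new hook can be used from user code. It assumes the wrapper is importable as imitation.policies.replay_buffer_wrapper.ReplayBufferRewardWrapper; the SAC/Pendulum setup and the constant_reward_fn and remember_buffer helpers are illustrative stand-ins, not part of this change.

import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper


def constant_reward_fn(obs, acts, next_obs, dones) -> np.ndarray:
    # Illustrative RewardFn: relabel every sampled transition with reward 0.0.
    return np.zeros(len(obs), dtype=np.float32)


captured = {}


def remember_buffer(buffer: ReplayBufferRewardWrapper) -> None:
    # Runs once the wrapper is fully initialized inside the algorithm's
    # constructor, giving outside code a handle on the buffer without
    # monkey-patching Stable Baselines.
    captured["buffer"] = buffer


model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        replay_buffer_class=ReplayBuffer,  # inner buffer that actually stores data
        reward_fn=constant_reward_fn,
        on_initialized_callback=remember_buffer,
    ),
)

# The callback has already fired during SAC's construction.
assert captured["buffer"] is model.replay_buffer

The replay_buffer_kwargs dict is forwarded by Stable Baselines to the buffer class it instantiates, which is how the keyword-only arguments above reach the new parameter introduced in this diff.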