import pickle
-from unittest.mock import patch
+from unittest.mock import patch, Mock

import numpy as np
import torch as th
VENVS = 2


-def test_state_entropy_reward_returns_entropy(rng):
+def test_pebble_entropy_reward_returns_entropy(rng):
    obs_shape = get_obs_shape(SPACE)
    all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))

-
-    reward_fn = PebbleStateEntropyReward(K, SPACE)
-    reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)
+    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
+    reward_fn.set_replay_buffer(
+        ReplayBufferView(all_observations, lambda: slice(None)), obs_shape
+    )

    # Act
    observations = rng.random((BATCH_SIZE, *obs_shape))
@@ -41,16 +42,16 @@ def test_state_entropy_reward_returns_entropy(rng):
    np.testing.assert_allclose(reward, expected_normalized)


-def test_state_entropy_reward_returns_normalized_values():
+def test_pebble_entropy_reward_returns_normalized_values():
    with patch("imitation.util.util.compute_state_entropy") as m:
        # mock entropy computation so that we can test only stats collection in this test
        m.side_effect = lambda obs, all_obs, k: obs

-        reward_fn = PebbleStateEntropyReward(K, SPACE)
+        reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
        all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
        reward_fn.set_replay_buffer(
            ReplayBufferView(all_observations, lambda: slice(None)),
-            get_obs_shape(SPACE)
+            get_obs_shape(SPACE),
        )

        dim = 8
@@ -75,12 +76,12 @@ def test_state_entropy_reward_returns_normalized_values():
    )


-def test_state_entropy_reward_can_pickle():
+def test_pebble_entropy_reward_can_pickle():
    all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
    replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

    obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
-    reward_fn = PebbleStateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
    reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
    reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

@@ -94,3 +95,33 @@ def test_state_entropy_reward_can_pickle():
    expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    np.testing.assert_allclose(actual_result, expected_result)
+
+
+def test_pebble_entropy_reward_function_switches_to_inner():
+    obs_shape = get_obs_shape(SPACE)
+
+    expected_reward = np.ones(1)
+    reward_fn_mock = Mock()
+    reward_fn_mock.return_value = expected_reward
+    reward_fn = PebbleStateEntropyReward(reward_fn_mock, SPACE)
+
+    # Act
+    reward_fn.on_unsupervised_exploration_finished()
+    observations = np.ones((BATCH_SIZE, *obs_shape))
+    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Assert
+    assert reward == expected_reward
+    reward_fn_mock.assert_called_once_with(
+        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+    )
+
+
+def reward_fn_stub(
+    self,
+    state: np.ndarray,
+    action: np.ndarray,
+    next_state: np.ndarray,
+    done: np.ndarray,
+) -> np.ndarray:
+    return state
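
For orientation, the usage pattern these tests exercise can be sketched as follows. This is only an illustrative sketch based on the calls visible in the diff; the import paths, the example observation space, and the constant values are assumptions mirroring the test fixtures, not the library's documented interface.

# Hedged usage sketch of PebbleStateEntropyReward, based only on the calls in the
# tests above. Import paths and fixture values are assumptions, not confirmed API.
import numpy as np
from gym.spaces import Box  # assumed: the test's SPACE fixture is some Box space
from stable_baselines3.common.preprocessing import get_obs_shape

# Assumed module locations for the classes under test:
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward  # assumed path

SPACE = Box(-1.0, 1.0, shape=(4,))               # assumed stand-in for the test fixture
K, BUFFER_SIZE, VENVS, BATCH_SIZE = 4, 20, 2, 8  # assumed fixture values

obs_shape = get_obs_shape(SPACE)
all_observations = np.random.rand(BUFFER_SIZE, VENVS, *obs_shape)

# The wrapper takes an inner (learned) reward function, the observation space, and
# the nearest-neighbor count k used by the state-entropy estimate.
reward_fn = PebbleStateEntropyReward(lambda s, a, ns, d: s.sum(axis=-1), SPACE, K)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)), obs_shape
)

batch = np.random.rand(BATCH_SIZE, *obs_shape)
entropy_bonus = reward_fn(batch, batch, batch, batch)  # entropy-based reward during exploration

# After unsupervised exploration ends, calls are delegated to the inner reward function.
reward_fn.on_unsupervised_exploration_finished()
learned_reward = reward_fn(batch, batch, batch, batch)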