-from typing import Iterator
+from typing import Iterator, Dict

 import jax
 import numpy as np

 from actsafe.common.double_buffer import double_buffer
-from actsafe.rl.trajectory import TrajectoryData
+from actsafe.rl.trajectory import Transition, TrajectoryData


 class ReplayBuffer:
@@ -21,67 +21,136 @@ def __init__(
         self.episode_id = 0
         self.dtype = np.float32
         self.obs_dtype = np.uint8
+        self.max_length = max_length
+        self.observation_shape = observation_shape
+        self.action_shape = action_shape
+        self.num_rewards = num_rewards
+
+        # Main storage arrays
         self.observation = np.zeros(
-            (
-                capacity,
-                max_length + 1,
-            )
-            + observation_shape,
+            (capacity, max_length + 1) + observation_shape,
             dtype=self.obs_dtype,
         )
         self.action = np.zeros(
-            (
-                capacity,
-                max_length,
-            )
-            + action_shape,
+            (capacity, max_length) + action_shape,
             dtype=self.dtype,
         )
         self.reward = np.zeros(
             (capacity, max_length, num_rewards),
             dtype=self.dtype,
         )
         self.cost = np.zeros(
-            (
-                capacity,
-                max_length,
-            ),
+            (capacity, max_length),
             dtype=self.dtype,
         )
+        self.done = np.ones(
+            (capacity, max_length),
+            dtype=bool,
+        )
+        self.episode_lengths = np.zeros(capacity, dtype=np.int32)
+
+        # Tracking ongoing episodes
+        self.ongoing_episodes: Dict[int, Dict] = {}
+
         self._valid_episodes = 0
         self.rs = np.random.RandomState(seed)
         self.batch_size = batch_size
         self.sequence_length = sequence_length
+        self.capacity = capacity

-    def add(self, trajectory: TrajectoryData):
-        capacity, *_ = self.reward.shape
-        batch_size = min(trajectory.observation.shape[0], capacity)
-        # Discard data if batch size overflows capacity.
-        end = min(self.episode_id + batch_size, capacity)
-        episode_slice = slice(self.episode_id, end)
-        if trajectory.reward.ndim == 2:
-            trajectory = TrajectoryData(
-                trajectory.observation,
-                trajectory.next_observation,
-                trajectory.action,
-                trajectory.reward[..., None],
-                trajectory.cost,
-            )
-        for data, val in zip(
-            (self.action, self.reward, self.cost),
-            (trajectory.action, trajectory.reward, trajectory.cost),
-        ):
-            data[episode_slice] = val[:batch_size].astype(self.dtype)
-        observation = np.concatenate(
-            [
-                trajectory.observation[:batch_size],
-                trajectory.next_observation[:batch_size, -1:],
-            ],
-            axis=1,
-        )
-        self.observation[episode_slice] = observation.astype(self.obs_dtype)
-        self.episode_id = (self.episode_id + batch_size) % capacity
-        self._valid_episodes = min(self._valid_episodes + batch_size, capacity)
+    def _initialize_ongoing_episode(self, worker_id: int):
+        """Initialize storage for a new ongoing episode."""
+        return {
+            "observation": np.zeros(
+                (self.max_length + 1,) + self.observation_shape, dtype=self.obs_dtype
+            ),
+            "action": np.zeros(
+                (self.max_length,) + self.action_shape, dtype=self.dtype
+            ),
+            "reward": np.zeros((self.max_length, self.num_rewards), dtype=self.dtype),
+            "cost": np.zeros(self.max_length, dtype=self.dtype),
+            "done": np.zeros(self.max_length, dtype=bool),
+            "current_step": 0,
+        }
+
+    def _commit_episode(self, worker_id: int):
+        """Commit a completed episode to the main buffer."""
+        episode_data = self.ongoing_episodes[worker_id]
+        current_step = episode_data["current_step"]
+
+        if current_step == 0:  # Skip empty episodes
+            return
+
+        # Check if we've reached capacity
+        if self.episode_id >= self.capacity:
+            self.episode_id = 0
+
+        # Copy data to main arrays
+        self.observation[self.episode_id, : current_step + 1] = episode_data[
+            "observation"
+        ][: current_step + 1]
+        self.action[self.episode_id, :current_step] = episode_data["action"][
+            :current_step
+        ]
+        self.reward[self.episode_id, :current_step] = episode_data["reward"][
+            :current_step
+        ]
+        self.cost[self.episode_id, :current_step] = episode_data["cost"][:current_step]
+        self.done[self.episode_id, :current_step] = episode_data["done"][:current_step]
+
+        # Set episode length
+        self.episode_lengths[self.episode_id] = current_step
+
+        # Mark remaining timesteps as done
+        self.done[self.episode_id, current_step:] = True
+
+        # Increment counters
+        self.episode_id += 1
+        self._valid_episodes = min(self._valid_episodes + 1, self.capacity)
+
+        # Clear the ongoing episode
+        self.ongoing_episodes[worker_id] = self._initialize_ongoing_episode(worker_id)
+
+    def add(self, step_data: Transition):
+        """Add a single environment step to the buffer."""
+        # Iterate over the leading (worker) dimension of the batched step
+        for i in range(step_data.reward.shape[0]):
+            # Get worker ID for this step
+            worker_id = i
+            # Initialize ongoing episode if needed
+            if worker_id not in self.ongoing_episodes:
+                self.ongoing_episodes[worker_id] = self._initialize_ongoing_episode(
+                    worker_id
+                )
+
+            episode_data = self.ongoing_episodes[worker_id]
+            current_step = episode_data["current_step"]
+
+            # Store current observation
+            episode_data["observation"][current_step] = step_data.observation[i]
+
+            # If not the first step, store previous step's action, reward, cost, done
+            if current_step > 0:
+                episode_data["action"][current_step - 1] = step_data.action[i]
+                episode_data["reward"][current_step - 1] = step_data.reward[i]
+                episode_data["cost"][current_step - 1] = step_data.cost[i]
+                episode_data["done"][current_step - 1] = step_data.done[i]
+
+            # If episode terminated
+            if step_data.done[i]:
+                # Store final observation
+                episode_data["observation"][
+                    current_step + 1
+                ] = step_data.next_observation[i]
+                self._commit_episode(worker_id)
+            else:
+                # Continue episode
+                episode_data["current_step"] = current_step + 1
+
+                # Check if we've reached max length
+                if current_step + 1 >= self.max_length:
+                    episode_data["done"][current_step] = True
+                    self._commit_episode(worker_id)

     def _sample_batch(
         self,
@@ -93,37 +162,46 @@ def _sample_batch(
             valid_episodes = valid_episodes
         else:
             valid_episodes = self._valid_episodes
-        time_limit = self.observation.shape[1]
-        assert time_limit > sequence_length
+
         while True:
-            low = self.rs.choice(time_limit - sequence_length - 1, batch_size)
+            episode_ids = self.rs.choice(valid_episodes, size=batch_size)
+            low = np.array(
+                [
+                    self.rs.randint(
+                        0, max(1, self.episode_lengths[episode_id] - sequence_length)
+                    )
+                    for episode_id in episode_ids
+                ]
+            )
             timestep_ids = low[:, None] + np.tile(
                 np.arange(sequence_length + 1),
                 (batch_size, 1),
             )
-            episode_ids = self.rs.choice(valid_episodes, size=batch_size)
-            # Sample a sequence of length H for the actions, rewards and costs,
-            # and a length of H + 1 for the observations (which is needed for
-            # bootstrapping)
+            for i, (episode_id, time_steps) in enumerate(
+                zip(episode_ids, timestep_ids)
+            ):
+                episode_length = self.episode_lengths[episode_id]
+                if time_steps[-1] >= episode_length:
+                    # Adjust timesteps to end at episode termination
+                    offset = time_steps[-1] - episode_length + 1
+                    timestep_ids[i] -= offset
+
             a, r, c = [
                 x[episode_ids[:, None], timestep_ids[:, :-1]]
-                for x in (
-                    self.action,
-                    self.reward,
-                    self.cost,
-                )
+                for x in (self.action, self.reward, self.cost)
             ]
             o = self.observation[episode_ids[:, None], timestep_ids]
             o, next_o = o[:, :-1], o[:, 1:]
-            yield o, next_o, a, r, c
+            done = self.done[episode_ids[:, None], timestep_ids[:, :-1]]
+            yield o, next_o, a, r, c, done

     def sample(self, n_batches: int) -> Iterator[TrajectoryData]:
         if self.empty:
             return
         iterator = (
             TrajectoryData(
                 *next(self._sample_batch(self.batch_size, self.sequence_length))
-            )  # type: ignore
+            )
             for _ in range(n_batches)
         )
         if jax.default_backend() == "gpu":
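Below is a minimal usage sketch of the new per-step API, for context. It is illustrative only: the ReplayBuffer keyword arguments and import path are assumed from the attributes assigned in __init__ above, and Transition is assumed to be a plain container exposing the fields the new add() reads (observation, next_observation, action, reward, cost, done).

import numpy as np

from actsafe.rl.trajectory import Transition
from actsafe.rl.replay_buffer import ReplayBuffer  # import path assumed

num_workers = 4
obs_shape, act_shape, num_rewards = (64, 64, 3), (2,), 1

# Keyword names taken from the attributes assigned in __init__ above; whether
# __init__ accepts exactly these (and only these) parameters is an assumption.
buffer = ReplayBuffer(
    observation_shape=obs_shape,
    action_shape=act_shape,
    max_length=64,
    seed=0,
    capacity=16,
    batch_size=8,
    sequence_length=16,
    num_rewards=num_rewards,
)

# Roll out one short episode per worker; done=True only on the final step, so
# add() buffers steps per worker and commits each episode on termination.
for t in range(32):
    step = Transition(
        observation=np.zeros((num_workers,) + obs_shape, dtype=np.uint8),
        next_observation=np.zeros((num_workers,) + obs_shape, dtype=np.uint8),
        action=np.zeros((num_workers,) + act_shape, dtype=np.float32),
        reward=np.zeros((num_workers, num_rewards), dtype=np.float32),
        cost=np.zeros(num_workers, dtype=np.float32),
        done=np.full(num_workers, t == 31),
    )
    buffer.add(step)

# Sampled batches now also carry the done flags yielded by _sample_batch
# (assuming TrajectoryData was extended elsewhere in this PR to hold them).
for batch in buffer.sample(n_batches=1):
    pass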