
Commit f20b331

First implementation of replay buffer with termination (#43)
* First implementation of replay buffer with termination
* Update data collection
* Debug run works
* Fix tests
1 parent 81695d7 commit f20b331

9 files changed: +196 -187 lines changed

actsafe/actsafe/replay_buffer.py (+137 -59)
@@ -1,9 +1,9 @@
-from typing import Iterator
+from typing import Iterator, Dict
 import jax
 import numpy as np
 
 from actsafe.common.double_buffer import double_buffer
-from actsafe.rl.trajectory import TrajectoryData
+from actsafe.rl.trajectory import Transition, TrajectoryData
 
 
 class ReplayBuffer:
@@ -21,67 +21,136 @@ def __init__(
         self.episode_id = 0
         self.dtype = np.float32
         self.obs_dtype = np.uint8
+        self.max_length = max_length
+        self.observation_shape = observation_shape
+        self.action_shape = action_shape
+        self.num_rewards = num_rewards
+
+        # Main storage arrays
         self.observation = np.zeros(
-            (
-                capacity,
-                max_length + 1,
-            )
-            + observation_shape,
+            (capacity, max_length + 1) + observation_shape,
             dtype=self.obs_dtype,
         )
         self.action = np.zeros(
-            (
-                capacity,
-                max_length,
-            )
-            + action_shape,
+            (capacity, max_length) + action_shape,
             dtype=self.dtype,
         )
         self.reward = np.zeros(
             (capacity, max_length, num_rewards),
             dtype=self.dtype,
         )
         self.cost = np.zeros(
-            (
-                capacity,
-                max_length,
-            ),
+            (capacity, max_length),
             dtype=self.dtype,
         )
+        self.done = np.ones(
+            (capacity, max_length),
+            dtype=bool,
+        )
+        self.episode_lengths = np.zeros(capacity, dtype=np.int32)
+
+        # Tracking ongoing episodes
+        self.ongoing_episodes: Dict[int, Dict] = {}
+
         self._valid_episodes = 0
         self.rs = np.random.RandomState(seed)
         self.batch_size = batch_size
         self.sequence_length = sequence_length
+        self.capacity = capacity
 
-    def add(self, trajectory: TrajectoryData):
-        capacity, *_ = self.reward.shape
-        batch_size = min(trajectory.observation.shape[0], capacity)
-        # Discard data if batch size overflows capacity.
-        end = min(self.episode_id + batch_size, capacity)
-        episode_slice = slice(self.episode_id, end)
-        if trajectory.reward.ndim == 2:
-            trajectory = TrajectoryData(
-                trajectory.observation,
-                trajectory.next_observation,
-                trajectory.action,
-                trajectory.reward[..., None],
-                trajectory.cost,
-            )
-        for data, val in zip(
-            (self.action, self.reward, self.cost),
-            (trajectory.action, trajectory.reward, trajectory.cost),
-        ):
-            data[episode_slice] = val[:batch_size].astype(self.dtype)
-        observation = np.concatenate(
-            [
-                trajectory.observation[:batch_size],
-                trajectory.next_observation[:batch_size, -1:],
-            ],
-            axis=1,
-        )
-        self.observation[episode_slice] = observation.astype(self.obs_dtype)
-        self.episode_id = (self.episode_id + batch_size) % capacity
-        self._valid_episodes = min(self._valid_episodes + batch_size, capacity)
+    def _initialize_ongoing_episode(self, worker_id: int):
+        """Initialize storage for a new ongoing episode."""
+        return {
+            "observation": np.zeros(
+                (self.max_length + 1,) + self.observation_shape, dtype=self.obs_dtype
+            ),
+            "action": np.zeros(
+                (self.max_length,) + self.action_shape, dtype=self.dtype
+            ),
+            "reward": np.zeros((self.max_length, self.num_rewards), dtype=self.dtype),
+            "cost": np.zeros(self.max_length, dtype=self.dtype),
+            "done": np.zeros(self.max_length, dtype=bool),
+            "current_step": 0,
+        }
+
+    def _commit_episode(self, worker_id: int):
+        """Commit a completed episode to the main buffer."""
+        episode_data = self.ongoing_episodes[worker_id]
+        current_step = episode_data["current_step"]
+
+        if current_step == 0:  # Skip empty episodes
+            return
+
+        # Check if we've reached capacity
+        if self.episode_id >= self.capacity:
+            self.episode_id = 0
+
+        # Copy data to main arrays
+        self.observation[self.episode_id, : current_step + 1] = episode_data[
+            "observation"
+        ][: current_step + 1]
+        self.action[self.episode_id, :current_step] = episode_data["action"][
+            :current_step
+        ]
+        self.reward[self.episode_id, :current_step] = episode_data["reward"][
+            :current_step
+        ]
+        self.cost[self.episode_id, :current_step] = episode_data["cost"][:current_step]
+        self.done[self.episode_id, :current_step] = episode_data["done"][:current_step]
+
+        # Set episode length
+        self.episode_lengths[self.episode_id] = current_step
+
+        # Mark remaining timesteps as done
+        self.done[self.episode_id, current_step:] = True
+
+        # Increment counters
+        self.episode_id += 1
+        self._valid_episodes = min(self._valid_episodes + 1, self.capacity)
+
+        # Clear the ongoing episode
+        self.ongoing_episodes[worker_id] = self._initialize_ongoing_episode(worker_id)
+
+    def add(self, step_data: Transition):
+        """Add a single environment step to the buffer."""
+        # step_data fields are batched over parallel workers; index i selects one worker
+        for i in range(step_data.reward.shape[0]):
+            # Get worker ID for this step
+            worker_id = i
+            # Initialize ongoing episode if needed
+            if worker_id not in self.ongoing_episodes:
+                self.ongoing_episodes[worker_id] = self._initialize_ongoing_episode(
+                    worker_id
+                )
+
+            episode_data = self.ongoing_episodes[worker_id]
+            current_step = episode_data["current_step"]
+
+            # Store current observation
+            episode_data["observation"][current_step] = step_data.observation[i]
+
+            # If not the first step, store previous step's action, reward, cost, done
+            if current_step > 0:
+                episode_data["action"][current_step - 1] = step_data.action[i]
+                episode_data["reward"][current_step - 1] = step_data.reward[i]
+                episode_data["cost"][current_step - 1] = step_data.cost[i]
+                episode_data["done"][current_step - 1] = step_data.done[i]
+
+            # If episode terminated
+            if step_data.done[i]:
+                # Store final observation
+                episode_data["observation"][
+                    current_step + 1
+                ] = step_data.next_observation[i]
+                self._commit_episode(worker_id)
+            else:
+                # Continue episode
+                episode_data["current_step"] = current_step + 1
+
+                # Check if we've reached max length
+                if current_step + 1 >= self.max_length:
+                    episode_data["done"][current_step] = True
+                    self._commit_episode(worker_id)
 
     def _sample_batch(
         self,
@@ -93,37 +162,46 @@ def _sample_batch(
             valid_episodes = valid_episodes
         else:
             valid_episodes = self._valid_episodes
-        time_limit = self.observation.shape[1]
-        assert time_limit > sequence_length
+
         while True:
-            low = self.rs.choice(time_limit - sequence_length - 1, batch_size)
+            episode_ids = self.rs.choice(valid_episodes, size=batch_size)
+            low = np.array(
+                [
+                    self.rs.randint(
+                        0, max(1, self.episode_lengths[episode_id] - sequence_length)
+                    )
+                    for episode_id in episode_ids
+                ]
+            )
             timestep_ids = low[:, None] + np.tile(
                 np.arange(sequence_length + 1),
                 (batch_size, 1),
             )
-            episode_ids = self.rs.choice(valid_episodes, size=batch_size)
-            # Sample a sequence of length H for the actions, rewards and costs,
-            # and a length of H + 1 for the observations (which is needed for
-            # bootstrapping)
+            for i, (episode_id, time_steps) in enumerate(
+                zip(episode_ids, timestep_ids)
+            ):
+                episode_length = self.episode_lengths[episode_id]
+                if time_steps[-1] >= episode_length:
+                    # Adjust timesteps to end at episode termination
+                    offset = time_steps[-1] - episode_length + 1
+                    timestep_ids[i] -= offset
+
             a, r, c = [
                 x[episode_ids[:, None], timestep_ids[:, :-1]]
-                for x in (
-                    self.action,
-                    self.reward,
-                    self.cost,
-                )
+                for x in (self.action, self.reward, self.cost)
             ]
             o = self.observation[episode_ids[:, None], timestep_ids]
             o, next_o = o[:, :-1], o[:, 1:]
-            yield o, next_o, a, r, c
+            done = self.done[episode_ids[:, None], timestep_ids[:, :-1]]
+            yield o, next_o, a, r, c, done
 
     def sample(self, n_batches: int) -> Iterator[TrajectoryData]:
         if self.empty:
             return
         iterator = (
             TrajectoryData(
                 *next(self._sample_batch(self.batch_size, self.sequence_length))
-            )  # type: ignore
+            )
             for _ in range(n_batches)
         )
         if jax.default_backend() == "gpu":
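The new add path consumes one batched environment step at a time: each parallel worker accumulates its own in-progress episode, and _commit_episode copies it into the main arrays when done is raised or max_length is hit, recording the true episode length so that _sample_batch never crosses a termination boundary. A minimal sketch of driving this API follows; the ReplayBuffer keyword arguments mirror the attributes assigned in __init__ and the Transition fields mirror the attribute accesses in add, but neither signature is shown in this diff, so treat both as assumptions.

import numpy as np
from actsafe.actsafe.replay_buffer import ReplayBuffer  # module path inferred from the file path above
from actsafe.rl.trajectory import Transition

num_envs, obs_shape, act_shape = 4, (64, 64, 3), (2,)
# Keyword names follow the attributes set in __init__; the real signature may differ.
buffer = ReplayBuffer(
    capacity=100,
    max_length=1000,
    observation_shape=obs_shape,
    action_shape=act_shape,
    num_rewards=1,
    seed=0,
    batch_size=8,
    sequence_length=16,
)

# One batched step from num_envs parallel workers; worker i owns ongoing episode i.
# Keyword construction assumes Transition is a NamedTuple with these fields.
step = Transition(
    observation=np.zeros((num_envs,) + obs_shape, dtype=np.uint8),
    next_observation=np.zeros((num_envs,) + obs_shape, dtype=np.uint8),
    action=np.zeros((num_envs,) + act_shape, dtype=np.float32),
    reward=np.zeros(num_envs, dtype=np.float32),
    cost=np.zeros(num_envs, dtype=np.float32),
    done=np.array([False, False, True, False]),  # worker 2 terminates, so its episode is committed
)
buffer.add(step)  # in practice many such steps are added before sampling

# sample() now also carries the per-step done flags yielded by _sample_batch.
for batch in buffer.sample(n_batches=1):
    o, next_o, a, r, c, done = batch  # assumes TrajectoryData gained a done field in this commit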

actsafe/configs/config.yaml (+1 -1)
@@ -39,7 +39,7 @@ training:
   safety_budget: 25
   seed: 0
   time_limit: 1000
-  episodes_per_epoch: 5
+  steps_per_epoch: 5000
   epochs: 200
   action_repeat: 1
   render_episodes: 0
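With time_limit: 1000 unchanged in the same training block, steps_per_epoch: 5000 keeps the per-epoch data budget of the previous episodes_per_epoch: 5 (5 episodes x 1000 steps = 5000 steps); this equivalence is inferred from the surrounding values rather than stated in the commit.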

actsafe/configs/experiment/debug.yaml (+1 -1)
@@ -8,7 +8,7 @@ writers:
 
 training:
   epochs: 2
-  episodes_per_epoch: 1
+  steps_per_epoch: 500
   time_limit: 100
   action_repeat: 2
   parallel_envs: 5

actsafe/rl/acting.py (+36 -50)
@@ -3,74 +3,60 @@
 
 from actsafe.rl.episodic_async_env import EpisodicAsync
 from actsafe.rl.epoch_summary import EpochSummary
-from actsafe.rl.trajectory import Trajectory, TrajectoryData, Transition
+from actsafe.rl.trajectory import Trajectory, Transition
 from actsafe.rl.types import Agent
 
 
-def _summarize_episodes(
-    trajectory: TrajectoryData,
-) -> tuple[float, float]:
-    reward = float(trajectory.reward.sum(1).mean())
-    cost = float(trajectory.cost.sum(1).mean())
-    return reward, cost
-
-
 def interact(
     agent: Agent,
     environment: EpisodicAsync,
-    num_episodes: int,
+    num_steps: int,
     train: bool,
     step: int,
     render_episodes: int = 0,
 ) -> tuple[list[Trajectory], int]:
     observations = environment.reset()
-    episode_count = 0
     episodes: list[Trajectory] = []
-    trajectory = Trajectory()
-    with tqdm(
-        total=num_episodes,
-        unit=f"Episode (✕ {environment.num_envs} parallel)",
-    ) as pbar:
-        while episode_count < num_episodes:
-            render = render_episodes > 0
-            if render:
-                trajectory.frames.append(environment.render())
-            actions = agent(observations, train)
-            next_observations, rewards, done, infos = environment.step(actions)
-            costs = np.array([info.get("cost", 0) for info in infos])
-            transition = Transition(
-                observations, next_observations, actions, rewards, costs
-            )
-            trajectory.transitions.append(transition)
-            agent.observe_transition(transition)
-            observations = next_observations
-            if done.any():
-                assert (
-                    done.all()
-                ), "No support for environments with different ending conditions"
-                np_trajectory = trajectory.as_numpy()
-                step += (
-                    int(np.prod(np_trajectory.cost.shape))
-                    * environment.action_repeat
-                )
-                if train:
-                    agent.observe(np_trajectory)
-                reward, cost = _summarize_episodes(np_trajectory)
-                pbar.set_postfix({"reward": reward, "cost": cost})
-                if render:
-                    render_episodes = max(render_episodes - 1, 0)
+    trajectories = [Trajectory() for _ in range(environment.num_envs)]
+    track_rewards = np.zeros(environment.num_envs)
+    track_costs = np.zeros(environment.num_envs)
+    pbar = tqdm(
+        range(0, num_steps, environment.action_repeat * environment.num_envs),
+        unit=f"Steps (✕ {environment.num_envs} parallel)",
+    )
+    for _ in pbar:
+        render = render_episodes > 0
+        if render:
+            images = environment.render()
+            for i, trajectory in enumerate(trajectories):
+                trajectory.frames.append(images[i])
+        actions = agent(observations, train)
+        next_observations, rewards, done, infos = environment.step(actions)
+        costs = np.array([info.get("cost", 0) for info in infos])
+        transition = Transition(
+            observations, next_observations, actions, rewards, costs, done
+        )
+        for i, trajectory in enumerate(trajectories):
+            trajectory.transitions.append(Transition(*map(lambda x: x[i], transition)))
+        agent.observe_transition(transition)
+        observations = next_observations
+        step += environment.action_repeat
+        track_rewards += rewards * (~done)
+        track_costs += costs * (~done)
+        pbar.set_postfix({"reward": track_rewards.mean(), "cost": track_costs.mean()})
+        if render:
+            render_episodes = max(render_episodes - done.any(), 0)
+        for i, (ep_done, trajectory) in enumerate(zip(done, trajectories)):
+            if ep_done:
                 episodes.append(trajectory)
-                trajectory = Trajectory()
-                pbar.update(1)
-                episode_count += 1
-                observations = environment.reset()
+                trajectories[i] = Trajectory()
     return episodes, step
 
 
 def epoch(
     agent: Agent,
     env: EpisodicAsync,
-    num_episodes: int,
+    num_steps: int,
     train: bool,
     step: int,
     render_episodes: int = 0,
@@ -79,7 +65,7 @@ def epoch(
     samples, step = interact(
         agent,
         env,
-        num_episodes,
+        num_steps,
         train,
         step,
         render_episodes,
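interact (and epoch, which forwards to it) is now budgeted by environment steps rather than by episode count, and only episodes that terminate within that budget are returned. A rough sketch of the new call pattern follows; make_agent, make_env, epochs, and steps_per_epoch are hypothetical placeholders for however the training script actually builds these objects and reads the config.

from actsafe.rl.acting import interact

agent = make_agent()  # hypothetical: any object satisfying actsafe.rl.types.Agent
env = make_env()      # hypothetical: an EpisodicAsync vectorized environment
global_step = 0
for _ in range(epochs):  # e.g. training.epochs from the config
    episodes, global_step = interact(
        agent,
        env,
        num_steps=steps_per_epoch,  # replaces the old num_episodes argument
        train=True,
        step=global_step,
    )
    # `episodes` holds only the trajectories that reached done within this
    # step budget; unfinished per-worker trajectories are not returned.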
