
Commit c9f2049: Initial commit
Committed Nov 17, 2021 (0 parents)

8 files changed: +609 additions, 0 deletions
 

LICENSE.md

Whitespace-only changes.

README.md

Whitespace-only changes.

marl_baselines3/__init__.py

Whitespace-only changes.

marl_baselines3/independent_ppo.py

+352 lines
@@ -0,0 +1,352 @@
import time
from collections import deque
from typing import Any, Dict, List, Optional, Type, Union

import gym
import numpy as np
import torch as th
from gym.spaces import Box, Discrete
from stable_baselines3 import PPO
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import (GymEnv, MaybeCallback,
                                                   Schedule)
from stable_baselines3.common.utils import (configure_logger, obs_as_tensor,
                                            safe_mean)
from stable_baselines3.common.vec_env import DummyVecEnv


class DummyGymEnv(gym.Env):
    # Minimal placeholder env that only carries the spaces; each PPO instance
    # needs an env object to build its policy network.
    def __init__(self, observation_space, action_space):
        self.observation_space = observation_space
        self.action_space = action_space


class IndependentPPO(OnPolicyAlgorithm):
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        num_agents: int,
        env: GymEnv,
        learning_rate: Union[float, Schedule] = 1e-4,
        n_steps: int = 1000,
        batch_size: int = 6000,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 1.0,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 40,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        device: Union[th.device, str] = "auto",
    ):
        self.env = env
        self.num_agents = num_agents
        # env.num_envs is expected to be num_envs * num_agents, with the agents
        # of each environment stored in consecutive slots.
        self.num_envs = env.num_envs // num_agents
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.n_steps = n_steps
        self.tensorboard_log = tensorboard_log
        self.verbose = verbose
        self._logger = None
        env_fn = lambda: DummyGymEnv(self.observation_space, self.action_space)
        dummy_env = DummyVecEnv([env_fn] * self.num_envs)
        # One independent PPO learner per agent.
        self.policies = [
            PPO(
                policy=policy,
                env=dummy_env,
                learning_rate=learning_rate,
                n_steps=n_steps,
                batch_size=batch_size,
                n_epochs=n_epochs,
                gamma=gamma,
                gae_lambda=gae_lambda,
                clip_range=clip_range,
                clip_range_vf=clip_range_vf,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                target_kl=target_kl,
                use_sde=use_sde,
                sde_sample_freq=sde_sample_freq,
                policy_kwargs=policy_kwargs,
                verbose=verbose,
                device=device,
            )
            for _ in range(self.num_agents)
        ]

    def learn(
        self,
        total_timesteps: int,
        callbacks: Optional[List[MaybeCallback]] = None,
        log_interval: int = 1,
        tb_log_name: str = "IndependentPPO",
        reset_num_timesteps: bool = True,
    ):

        num_timesteps = 0
        all_total_timesteps = []
        if not callbacks:
            callbacks = [None] * self.num_agents
        self._logger = configure_logger(
            self.verbose,
            self.tensorboard_log,
            tb_log_name,
            reset_num_timesteps,
        )
        logdir = self.logger.dir

        # Setup for each policy
        for polid, policy in enumerate(self.policies):
            policy.start_time = time.time()
            if policy.ep_info_buffer is None or reset_num_timesteps:
                policy.ep_info_buffer = deque(maxlen=100)
                policy.ep_success_buffer = deque(maxlen=100)

            if policy.action_noise is not None:
                policy.action_noise.reset()

            if reset_num_timesteps:
                policy.num_timesteps = 0
                policy._episode_num = 0
                all_total_timesteps.append(total_timesteps)
                policy._total_timesteps = total_timesteps
            else:
                # make sure training timesteps are ahead of the internal counter
                all_total_timesteps.append(total_timesteps + policy.num_timesteps)
                policy._total_timesteps = total_timesteps + policy.num_timesteps

            policy._logger = configure_logger(
                policy.verbose,
                logdir,
                "policy",
                reset_num_timesteps,
            )

            callbacks[polid] = policy._init_callback(callbacks[polid])

        for callback in callbacks:
            callback.on_training_start(locals(), globals())

        last_obs = self.env.reset()
        for policy in self.policies:
            policy._last_episode_starts = np.ones((self.num_envs,), dtype=bool)

        while num_timesteps < total_timesteps:
            last_obs = self.collect_rollouts(last_obs, callbacks)
            num_timesteps += self.num_envs * self.n_steps
            for polid, policy in enumerate(self.policies):
                policy._update_current_progress_remaining(
                    policy.num_timesteps, total_timesteps
                )
                if log_interval is not None and num_timesteps % log_interval == 0:
                    fps = int(policy.num_timesteps / (time.time() - policy.start_time))
                    policy.logger.record("policy_id", polid, exclude="tensorboard")
                    policy.logger.record(
                        "time/iterations", num_timesteps, exclude="tensorboard"
                    )
                    if (
                        len(policy.ep_info_buffer) > 0
                        and len(policy.ep_info_buffer[0]) > 0
                    ):
                        policy.logger.record(
                            "rollout/ep_rew_mean",
                            safe_mean(
                                [ep_info["r"] for ep_info in policy.ep_info_buffer]
                            ),
                        )
                        policy.logger.record(
                            "rollout/ep_len_mean",
                            safe_mean(
                                [ep_info["l"] for ep_info in policy.ep_info_buffer]
                            ),
                        )
                    policy.logger.record("time/fps", fps)
                    policy.logger.record(
                        "time/time_elapsed",
                        int(time.time() - policy.start_time),
                        exclude="tensorboard",
                    )
                    policy.logger.record(
                        "time/total_timesteps",
                        policy.num_timesteps,
                        exclude="tensorboard",
                    )
                    policy.logger.dump(step=policy.num_timesteps)

                policy.train()

        for callback in callbacks:
            callback.on_training_end()

    def collect_rollouts(self, last_obs, callbacks):

        all_last_episode_starts = [None] * self.num_agents
        all_obs = [None] * self.num_agents
        all_last_obs = [None] * self.num_agents
        all_rewards = [None] * self.num_agents
        all_dones = [None] * self.num_agents
        all_infos = [None] * self.num_agents
        steps = 0

        # Split the flat observation batch into one per-agent batch of shape
        # (num_envs, ...) using the envid * num_agents + polid layout.
        for polid, policy in enumerate(self.policies):
            for envid in range(self.num_envs):
                assert (
                    last_obs[envid * self.num_agents + polid] is not None
                ), f"No previous observation was provided for env_{envid}_policy_{polid}"
            all_last_obs[polid] = np.array(
                [
                    last_obs[envid * self.num_agents + polid]
                    for envid in range(self.num_envs)
                ]
            )
            policy.policy.set_training_mode(False)
            policy.rollout_buffer.reset()
            callbacks[polid].on_rollout_start()
            all_last_episode_starts[polid] = policy._last_episode_starts

        while steps < self.n_steps:
            all_actions = [None] * self.num_agents
            all_values = [None] * self.num_agents
            all_log_probs = [None] * self.num_agents
            all_clipped_actions = [None] * self.num_agents
            with th.no_grad():
                for polid, policy in enumerate(self.policies):
                    obs_tensor = obs_as_tensor(all_last_obs[polid], policy.device)
                    (
                        all_actions[polid],
                        all_values[polid],
                        all_log_probs[polid],
                    ) = policy.policy.forward(obs_tensor)
                    clipped_actions = all_actions[polid].cpu().numpy()
                    if isinstance(self.action_space, Box):
                        clipped_actions = np.clip(
                            clipped_actions,
                            self.action_space.low,
                            self.action_space.high,
                        )
                    elif isinstance(self.action_space, Discrete):
                        # get integers from the numpy array
                        clipped_actions = np.array(
                            [action.item() for action in clipped_actions]
                        )
                    all_clipped_actions[polid] = clipped_actions

            all_clipped_actions = (
                np.vstack(all_clipped_actions).transpose().reshape(-1)
            )  # reshape as (env, action)
            obs, rewards, dones, infos = self.env.step(all_clipped_actions)

            for polid in range(self.num_agents):
                all_obs[polid] = np.array(
                    [
                        obs[envid * self.num_agents + polid]
                        for envid in range(self.num_envs)
                    ]
                )
                all_rewards[polid] = np.array(
                    [
                        rewards[envid * self.num_agents + polid]
                        for envid in range(self.num_envs)
                    ]
                )
                all_dones[polid] = np.array(
                    [
                        dones[envid * self.num_agents + polid]
                        for envid in range(self.num_envs)
                    ]
                )
                all_infos[polid] = np.array(
                    [
                        infos[envid * self.num_agents + polid]
                        for envid in range(self.num_envs)
                    ]
                )

            for policy in self.policies:
                policy.num_timesteps += self.num_envs

            for callback in callbacks:
                callback.update_locals(locals())
            # stop the rollout early if any callback requests termination
            if not all(callback.on_step() for callback in callbacks):
                break

            for polid, policy in enumerate(self.policies):
                policy._update_info_buffer(all_infos[polid])

            steps += 1

            # add data to the rollout buffers
            for polid, policy in enumerate(self.policies):
                if isinstance(self.action_space, Discrete):
                    all_actions[polid] = all_actions[polid].reshape(-1, 1)
                all_actions[polid] = all_actions[polid].cpu().numpy()
                policy.rollout_buffer.add(
                    all_last_obs[polid],
                    all_actions[polid],
                    all_rewards[polid],
                    all_last_episode_starts[polid],
                    all_values[polid],
                    all_log_probs[polid],
                )
            all_last_obs = all_obs
            all_last_episode_starts = all_dones

        with th.no_grad():
            for polid, policy in enumerate(self.policies):
                obs_tensor = obs_as_tensor(all_last_obs[polid], policy.device)
                _, value, _ = policy.policy.forward(obs_tensor)
                policy.rollout_buffer.compute_returns_and_advantage(
                    last_values=value, dones=all_dones[polid]
                )

        for callback in callbacks:
            callback.on_rollout_end()

        for polid, policy in enumerate(self.policies):
            policy._last_episode_starts = all_last_episode_starts[polid]

        return obs

    @classmethod
    def load(
        cls,
        path: str,
        policy: Union[str, Type[ActorCriticPolicy]],
        num_agents: int,
        env: GymEnv,
        n_steps: int,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        **kwargs,
    ) -> "IndependentPPO":
        model = cls(
            policy=policy,
            num_agents=num_agents,
            env=env,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            **kwargs,
        )
        env_fn = lambda: DummyGymEnv(env.observation_space, env.action_space)
        dummy_env = DummyVecEnv([env_fn] * (env.num_envs // num_agents))
        for polid in range(num_agents):
            model.policies[polid] = PPO.load(
                path=path + f"/policy_{polid + 1}/model", env=dummy_env, **kwargs
            )
        return model

    def save(self, path: str) -> None:
        for polid in range(self.num_agents):
            self.policies[polid].save(path=path + f"/policy_{polid + 1}/model")
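A minimal usage sketch, not part of the commit. It assumes a parallel PettingZoo environment flattened into a Stable-Baselines3 VecEnv via SuperSuit, so that env.num_envs equals num_vec_envs * num_agents and the agents of each environment occupy consecutive slots, matching the envid * num_agents + polid indexing above. The environment, version suffixes, and hyperparameters are illustrative only.

import supersuit as ss
from pettingzoo.mpe import simple_spread_v2  # any parallel PettingZoo env; version suffix may differ

from marl_baselines3.independent_ppo import IndependentPPO

# Flatten the multi-agent env into an SB3-compatible VecEnv (4 env copies here).
env = simple_spread_v2.parallel_env()
num_agents = len(env.possible_agents)
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 4, num_cpus=1, base_class="stable_baselines3")

# One PPO learner per agent, each trained on its own slice of the shared rollouts.
model = IndependentPPO("MlpPolicy", num_agents=num_agents, env=env,
                       n_steps=128, batch_size=256, verbose=1)
model.learn(total_timesteps=100_000)
model.save("./independent_ppo_checkpoint")  # writes policy_1/model ... policy_N/model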
