Merge pull request #460 from instadeepai/feature/bugfix-maddpg-mad4pg-v2
Bugfix/MADD(4)PG
DriesSmit authored Mar 25, 2022
2 parents 4174121 + 998262e commit 7e3f301
Showing 5 changed files with 235 additions and 66 deletions.
@@ -0,0 +1,127 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running MADDPG on pettinzoo MPE environments."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import mad4pg
from mava.utils import lp_utils
from mava.utils.enums import ArchitectureType
from mava.utils.environments import pettingzoo_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
"env_class",
"sisl",
"Pettingzoo environment class, e.g. atari (str).",
)

flags.DEFINE_string(
"env_name",
"multiwalker_v7",
"Pettingzoo environment name, e.g. pong (str).",
)
flags.DEFINE_string(
"mava_id",
str(datetime.now()),
"Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")


def main(_: Any) -> None:
"""Run example.
Args:
_ (Any): None
"""

# Environment.
environment_factory = functools.partial(
pettingzoo_utils.make_environment,
env_class=FLAGS.env_class,
env_name=FLAGS.env_name,
)

# Networks.
network_factory = lp_utils.partial_kwargs(
mad4pg.make_default_networks,
architecture_type=ArchitectureType.recurrent,
vmin=-150,
vmax=150,
num_atoms=101,
)

# Checkpointer appends "Checkpoints" to checkpoint_dir.
checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

# Log every [log_every] seconds.
log_every = 10
logger_factory = functools.partial(
logger_utils.make_logger,
directory=FLAGS.base_dir,
to_terminal=True,
to_tensorboard=True,
time_stamp=FLAGS.mava_id,
time_delta=log_every,
)

# Distributed program.
program = mad4pg.MAD4PG(
environment_factory=environment_factory,
network_factory=network_factory,
logger_factory=logger_factory,
num_executors=1,
policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
max_gradient_norm=40.0,
trainer_fn=mad4pg.training.MAD4PGDecentralisedRecurrentTrainer,
executor_fn=mad4pg.execution.MAD4PGRecurrentExecutor,
batch_size=32,
sequence_length=20,
period=20,
min_replay_size=1000,
max_replay_size=100000,
prefetch_size=4,
n_step=5,
samples_per_insert=None,
).build()

# Ensure only trainer runs on gpu, while other processes run on cpu.
local_resources = lp_utils.to_device(
program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
)

# Launch.
lp.launch(
program,
lp.LaunchType.LOCAL_MULTI_PROCESSING,
terminal="current_terminal",
local_resources=local_resources,
)


if __name__ == "__main__":
app.run(main)
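Side note, not part of the diff: vmin, vmax and num_atoms above configure the categorical distributional critic that (MA)D4PG uses. A minimal NumPy sketch of how such a fixed support turns the critic's categorical output into a scalar Q-value (placeholder logits, illustrative only):

import numpy as np

vmin, vmax, num_atoms = -150.0, 150.0, 101
support = np.linspace(vmin, vmax, num_atoms)    # 101 atom locations, spaced 3.0 apart
logits = np.zeros(num_atoms)                    # stand-in for the critic head's output
probs = np.exp(logits) / np.exp(logits).sum()   # softmax over the atoms
q_value = float((probs * support).sum())        # scalar Q-value = expectation over the support
print(q_value)                                  # ~0.0 for this uniform placeholder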
3 changes: 1 addition & 2 deletions mava/adders/reverb/base.py
@@ -94,8 +94,7 @@ def get_trajectory_net_agents(
trajectory: Union[Trajectory, mava_types.Transition],
trajectory_net_keys: Dict[str, str],
) -> Tuple[List, Dict[str, List]]:
"""Returns a dictionary that maps network_keys to a list of agents using that
specific network.
"""Maps network_keys to a list of agents using that specific network.
Args:
trajectory: Episode experience recorded by
33 changes: 21 additions & 12 deletions mava/systems/tf/mad4pg/training.py
@@ -197,7 +197,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
critic_loss = losses.categorical(
q_tm1, r_t[agent], discount * d_t[agent], q_t
)
self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)
self.critic_losses[agent] = tf.reduce_mean(critic_loss)
# Actor learning.
o_t_agent_feed = o_t_trans[agent]
dpg_a_t = self._policy_networks[agent_key](o_t_agent_feed)
@@ -214,13 +214,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
clip_norm = True if self._max_gradient_norm is not None else False

policy_loss = losses.dpg(
dpg_q_t,
dpg_a_t,
q_max=dpg_q_t,
a_max=dpg_a_t,
tape=tape,
dqda_clipping=dqda_clipping,
clip_norm=clip_norm,
)
self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)
self.policy_losses[agent] = tf.reduce_mean(policy_loss)
self.tape = tape
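For reference (not part of the diff): the changed reductions above replace a mean over axis 0, which still returns a vector, with a mean over every element, which returns a single scalar loss. A tiny TensorFlow sketch of the difference, with illustrative values:

import tensorflow as tf

per_element_loss = tf.constant([[0.2, 0.4], [0.6, 0.8]])  # illustrative [batch, ...] losses
print(tf.reduce_mean(per_element_loss, axis=0))           # [0.4 0.6]: mean over the batch axis only
print(tf.reduce_mean(per_element_loss))                   # 0.5: mean over all elements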


@@ -594,7 +594,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
data: Trajectory = inputs.data

# Note (dries): The unused variable is start_of_episodes.
observations, actions, rewards, discounts, _, extras = (
observations, actions, rewards, end_of_episode, _, extras = (
data.observations,
data.actions,
data.rewards,
@@ -660,19 +660,24 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:

# Cast the additional discount to match
# the environment discount dtype.
agent_discount = discounts[agent]
agent_discount = end_of_episode[agent]
discount = tf.cast(self._discount, dtype=agent_discount.dtype)
agent_end_of_episode = end_of_episode[agent]
ones_mask = tf.ones(shape=(agent_end_of_episode.shape[0], 1))
step_not_padded = tf.concat(
[ones_mask, agent_end_of_episode[:, :-1]], axis=1
)

# Critic loss.
critic_loss = recurrent_n_step_critic_loss(
q_values,
target_q_values,
rewards[agent],
discount * agent_discount,
q_values=q_values,
target_q_values=target_q_values,
rewards=rewards[agent],
discounts=discount * agent_discount,
bootstrap_n=self._bootstrap_n,
loss_fn=losses.categorical,
)
self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)
self.critic_losses[agent] = tf.reduce_mean(critic_loss)

# Actor learning.
obs_agent_feed = target_obs_trans[agent]
@@ -718,7 +723,11 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
dqda_clipping=dqda_clipping,
clip_norm=clip_norm,
)
self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)
policy_mask = tf.reshape(step_not_padded, policy_loss.shape)
policy_loss = policy_loss * policy_mask
self.policy_losses[agent] = tf.reduce_sum(policy_loss) / tf.reduce_sum(
policy_mask
)
self.tape = tape


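A minimal sketch (illustrative shapes and values, not part of the diff) of the masking idea introduced above: per-step policy losses on steps that merely pad a sequence after episode termination are zeroed out via a shifted "step not padded" mask built from the end-of-episode signal, and the loss is normalised by the number of valid steps:

import tensorflow as tf

end_of_episode = tf.constant([[1.0, 1.0, 0.0, 0.0]])        # [batch, time]; 0 marks termination/padding
per_step_policy_loss = tf.constant([[0.5, 0.2, 0.9, 0.7]])  # [batch, time] loss before masking

ones_mask = tf.ones(shape=(end_of_episode.shape[0], 1))
step_not_padded = tf.concat([ones_mask, end_of_episode[:, :-1]], axis=1)  # shift right: first step is always valid

masked_loss = per_step_policy_loss * step_not_padded
mean_valid_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(step_not_padded)
print(mean_valid_loss)  # mean over the three non-padded steps only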
103 changes: 63 additions & 40 deletions mava/systems/tf/maddpg/execution.py
@@ -21,6 +21,7 @@
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp
import tree
from acme import types
from acme.specs import EnvironmentSpec

@@ -93,7 +94,6 @@ def __init__(
variable_client=variable_client,
)

@tf.function
def _policy(
self, agent: str, observation: types.NestedTensor
) -> types.NestedTensor:
@@ -147,12 +147,19 @@ def select_action(
# Step the recurrent policy/value network forward
# given the current observation and state.
action, policy = self._policy(agent, observation.observation)

# Return a numpy array with squeezed out batch dimension.
action = tf2_utils.to_numpy_squeeze(action)
policy = tf2_utils.to_numpy_squeeze(policy)
return action, policy

@tf.function
def _select_actions(
self, observations: Dict[str, types.NestedArray]
) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]:
"""Select the actions for all agents in the system"""
actions = {}
policies = {}
for agent, observation in observations.items():
actions[agent], policies[agent] = self.select_action(agent, observation)
return actions, policies

def select_actions(
self, observations: Dict[str, types.NestedArray]
) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]:
@@ -166,10 +173,9 @@ def select_actions(
actions and policies for all agents in the system.
"""

actions = {}
policies = {}
for agent, observation in observations.items():
actions[agent], policies[agent] = self.select_action(agent, observation)
actions, policies = self._select_actions(observations)
actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions)
policies = tree.map_structure(tf2_utils.to_numpy_squeeze, policies)
return actions, policies

def observe_first(
@@ -284,7 +290,6 @@ def __init__(
store_recurrent_state=store_recurrent_state,
)

@tf.function
def _policy(
self,
agent: str,
@@ -322,41 +327,36 @@ def _policy(
raise NotImplementedError
return action, policy, new_state

def select_action(
self, agent: str, observation: types.NestedArray
) -> types.NestedArray:
"""select an action for a single agent in the system
def select_actions(
self, observations: Dict[str, types.NestedArray]
) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]:
"""select the actions for all agents in the system
Args:
agent: agent id
observation: observation tensor received from the
observations: agent observations from the
environment.
Returns:
action and policy.
actions and policies for all agents in the system.
"""

# TODO Mask actions here using observation.legal_actions
# Initialize the RNN state if necessary.
if self._states[agent] is None:
# index network either on agent type or on agent id
agent_key = self._agent_net_keys[agent]
self._states[agent] = self._policy_networks[agent_key].initial_state(1)

# Step the recurrent policy forward given the current observation and state.
action, policy, new_state = self._policy(
agent, observation.observation, self._states[agent]
actions, policies, self._states = self._select_actions(
observations, self._states
)
actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions)
policies = tree.map_structure(tf2_utils.to_numpy_squeeze, policies)
return actions, policies

# Bookkeeping of recurrent states for the observe method.
self._update_state(agent, new_state)

# Return a numpy array with squeezed out batch dimension.
action = tf2_utils.to_numpy_squeeze(action)
policy = tf2_utils.to_numpy_squeeze(policy)
return action, policy

def select_actions(
self, observations: Dict[str, types.NestedArray]
) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]:
@tf.function
def _select_actions(
self,
observations: Dict[str, types.NestedArray],
states: Dict[str, types.NestedArray],
) -> Tuple[
Dict[str, types.NestedArray],
Dict[str, types.NestedArray],
Dict[str, types.NestedArray],
]:
"""select the actions for all agents in the system
Args:
observations: agent observations from the
@@ -367,9 +367,32 @@ def select_actions(

actions = {}
policies = {}
new_states = {}
for agent, observation in observations.items():
actions[agent], policies[agent] = self.select_action(agent, observation)
return actions, policies
actions[agent], policies[agent], new_states[agent] = self.select_action(
agent, observation, states[agent]
)
return actions, policies, new_states

def select_action( # type: ignore
self,
agent: str,
observation: types.NestedArray,
agent_state: types.NestedArray,
) -> Tuple[types.NestedArray, types.NestedArray, types.NestedArray]:
"""select an action for a single agent in the system
Args:
agent: agent id
observation: observation tensor received from the
environment.
Returns:
action and policy.
"""
# Step the recurrent policy forward given the current observation and state.
action, policy, new_state = self._policy(
agent, observation.observation, agent_state
)
return action, policy, new_state

def observe_first(
self,
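For reference, a self-contained sketch of the action-selection pattern introduced above (placeholder Dense networks stand in for the Mava policy networks; to_numpy_squeeze is mimicked with a lambda): the per-agent loop is traced once inside a single tf.function, and tree.map_structure converts the nested result to squeezed numpy arrays outside the traced code:

import numpy as np
import tensorflow as tf
import tree  # dm-tree

policy_networks = {
    "agent_0": tf.keras.layers.Dense(2),
    "agent_1": tf.keras.layers.Dense(2),
}

@tf.function
def _select_actions(observations):
    # One traced call runs every agent's policy.
    return {agent: policy_networks[agent](obs) for agent, obs in observations.items()}

def select_actions(observations):
    actions = _select_actions(observations)
    # Move to numpy and squeeze the leading batch dimension, as to_numpy_squeeze does.
    return tree.map_structure(lambda t: np.squeeze(t.numpy(), axis=0), actions)

observations = {agent: tf.ones([1, 4]) for agent in policy_networks}
print(select_actions(observations))  # {'agent_0': array([...]), 'agent_1': array([...])}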
