Ray update #2

Draft: wants to merge 5 commits into main
356 changes: 169 additions & 187 deletions maddpg.py
@@ -14,163 +14,32 @@
import logging
from typing import Optional, Type

from ray.rllib.agents.trainer import COMMON_CONFIG, with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from maddpg_tf_policy import MADDPGTFPolicy
from maddpg_torch_policy import MADDPGTorchPolicy
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.simple_q import DEFAULT_CONFIG as SIMPLEQ_DEFAULT_CONFIG
from ray.rllib.agents.trainer import COMMON_CONFIG, Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict

from maddpg_tf_policy import MADDPGTFPolicy
from maddpg_torch_policy import MADDPGTorchPolicy

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Framework to run the algorithm ===
"framework": "tf",

# === Settings for each individual policy ===
# ID of the agent controlled by this policy
"agent_id": None,
# Use a local critic for this policy.
"use_local_critic": False,

# === Evaluation ===
# Evaluation interval
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_num_episodes": 10,

# === Model ===
# Apply a state preprocessor with spec given by the "model" config option
# (like other RL algorithms). This is mostly useful if you have a weird
# observation shape, like an image. Disabled by default.
"use_state_preprocessor": False,
# Postprocess the policy network model output with these hidden layers. If
# use_state_preprocessor is False, then these will be the *only* hidden
# layers in the network.
"actor_hiddens": [64, 64],
# Hidden layers activation of the postprocessing stage of the policy
# network
"actor_hidden_activation": "relu",
# Postprocess the critic network model output with these hidden layers;
# again, if use_state_preprocessor is True, then the state will be
# preprocessed by the model specified with the "model" config option first.
"critic_hiddens": [64, 64],
# Hidden layers activation of the postprocessing stage of the critic.
"critic_hidden_activation": "relu",
# N-step Q learning
"n_step": 1,
# Algorithm for good policies.
"good_policy": "maddpg",
# Algorithm for adversary policies.
"adv_policy": "maddpg",
# list of other agent_ids and policies to approximate (See MADDPG Section 4.2)
"learn_other_policies": None,

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": int(1e6),
# Observation compression. Note that compression makes simulation slow in
# MPE.
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# Force lockstep replay mode for MADDPG.
"multiagent": merge_dicts(COMMON_CONFIG["multiagent"], {
"replay_mode": "lockstep",
}),

# === Exploration ===
"exploration_config": {
"type": "GaussianNoise",
# For how many timesteps should we return completely random actions,
# before we start adding (scaled) noise?
"random_timesteps": 1000,
# The stddev (sigma) to be used for the actions
"stddev": 0.5,
# The initial noise scaling factor.
"initial_scale": 1.0,
# The final noise scaling factor.
"final_scale": 0.02,
# Timesteps over which to anneal scale (from initial to final values).
"scale_timesteps": 10000,
},
# Extra configuration that disables exploration.
"evaluation_config": {
"explore": False
},

# === Optimization ===
# Learning rate for the critic (Q-function) optimizer.
"critic_lr": 1e-2,
# Learning rate for the actor (policy) optimizer.
"actor_lr": 1e-2,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.01,
# Weights for feature regularization for the actor
"actor_feature_reg": 0.001,
# If not None, clip gradients during optimization at this value
"grad_clip": 100,
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 100,
# Size of a batch sampled from the replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 1024,
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,

# torch-specific model configs
"twin_q": False,
# delayed policy update
"policy_delay": 1,
# target policy smoothing
# (this also replaces OU exploration noise with IID Gaussian exploration noise, for now)
"smooth_target_policy": False,
"use_huber": False,
"huber_threshold": 1.0,
"l2_reg": None,

# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,
})
# __sphinx_doc_end__
# yapf: enable


def _make_continuous_space(space):
if isinstance(space, Box):
return space
elif isinstance(space, Discrete):
return Box(low=np.zeros((space.n,)), high=np.ones((space.n,)))
else:
raise UnsupportedSpaceException("Space {} is not supported.".format(space))


def before_learn_on_batch(
multi_agent_batch, policies, train_batch_size, framework="tf"
):
# TODO: This should only operate on agents following maddpg, not ddpg!
def maddpg_learn_on_batch(multi_agent_batch, workers, config):
policies = dict(
workers.local_worker().foreach_trainable_policy(lambda p, i: (i, p))
)
samples = {}
train_batch_size = config["train_batch_size"]
framework = config["framework"]

# Modify keys.
for pid, p in policies.items():
@@ -195,10 +64,11 @@ def sampler(policy, obs):
sampler(policy, obs) for policy, obs in zip(policies.values(), new_obs_n)
]
else:
target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
new_act_n = p.sess.run(target_act_sampler_n, feed_dict)
for i, p in enumerate(policies.values()):
feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
new_act = p.get_session().run(p.target_act_sampler, feed_dict)
samples.update({"new_actions_%d" % i: new_act})

samples.update(
{"new_actions_%d" % i: new_act for i, new_act in enumerate(new_act_n)}
@@ -209,43 +79,155 @@ def sampler(policy, obs):
return MultiAgentBatch(policy_batches, train_batch_size)


def add_maddpg_postprocessing(config):
"""Add the before learn on batch hook.

This hook is called explicitly prior to TrainOneStep() in the execution
setups for DQN and APEX.
"""

def f(batch, workers, config):
policies = dict(
workers.local_worker().foreach_trainable_policy(lambda p, i: (i, p))
)
return before_learn_on_batch(
batch, policies, config["train_batch_size"], config["framework"]
)

config["before_learn_on_batch"] = f
return config


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
return MADDPGTorchPolicy
else:
return MADDPGTFPolicy
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
SIMPLEQ_DEFAULT_CONFIG,
{
# === Framework to run the algorithm ===
"framework": "tf",

# === Settings for each individual policy ===
# ID of the agent controlled by this policy
"agent_id": None,
# Use a local critic for this policy.
"use_local_critic": False,

# === Evaluation ===
# Evaluation interval
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_num_episodes": 10,

# === Model ===
# Apply a state preprocessor with spec given by the "model" config option
# (like other RL algorithms). This is mostly useful if you have a weird
# observation shape, like an image. Disabled by default.
"use_state_preprocessor": False,
# Postprocess the policy network model output with these hidden layers. If
# use_state_preprocessor is False, then these will be the *only* hidden
# layers in the network.
"actor_hiddens": [64, 64],
# Hidden layers activation of the postprocessing stage of the policy
# network
"actor_hidden_activation": "relu",
# Postprocess the critic network model output with these hidden layers;
# again, if use_state_preprocessor is True, then the state will be
# preprocessed by the model specified with the "model" config option first.
"critic_hiddens": [64, 64],
# Hidden layers activation of the postprocessing stage of the critic.
"critic_hidden_activation": "relu",
# N-step Q learning
"n_step": 1,
# Algorithm for good policies.
"good_policy": "maddpg",
# Algorithm for adversary policies.
"adv_policy": "maddpg",

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
"capacity": 1000000,
},
# Observation compression. Note that compression makes simulation slow in
# MPE.
"compress_observations": False,
# If set, this will fix the ratio of replayed from a buffer and learned on
# timesteps to sampled from an environment and stored in the replay buffer
# timesteps. Otherwise, the replay will proceed at the native ratio
# determined by (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# Force lockstep replay mode for MADDPG.
"multiagent": merge_dicts(COMMON_CONFIG["multiagent"], {
"replay_mode": "lockstep",
}),
# Callback to share multi-agent batch for maddpg
"before_learn_on_batch": maddpg_learn_on_batch,

# === Optimization ===
# Learning rate for the critic (Q-function) optimizer.
"critic_lr": 1e-2,
# Learning rate for the actor (policy) optimizer.
"actor_lr": 1e-2,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0,
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 0.01,
# Weights for feature regularization for the actor
"actor_feature_reg": 0.001,
# If not None, clip gradients during optimization at this value
"grad_clip": 100,
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 100,
# Size of a batch sampled from the replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 1024,
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,

# === Exploration ===
"exploration_config": {
"type": "GaussianNoise",
# For how many timesteps should we return completely random actions,
# before we start adding (scaled) noise?
"random_timesteps": 1000,
# The stddev (sigma) to be used for the actions
"stddev": 0.5,
# The initial noise scaling factor.
"initial_scale": 1.0,
# The final noise scaling factor.
"final_scale": 0.02,
# Timesteps over which to anneal scale (from initial to final values).
"scale_timesteps": 10000,
},
# Extra configuration that disables exploration.
"evaluation_config": {
"explore": False
},

# torch-specific model configs
"twin_q": False,
# delayed policy update
"policy_delay": 1,
# target policy smoothing
# (this also replaces OU exploration noise with IID Gaussian exploration noise, for now)
"smooth_target_policy": False,
"use_huber": False,
"huber_threshold": 1.0,
"l2_reg": None,

# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 0,
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable


MADDPGTrainer = GenericOffPolicyTrainer.with_updates(
name="MADDPG",
default_config=DEFAULT_CONFIG,
default_policy=MADDPGTFPolicy,
get_policy_class=get_policy_class,
validate_config=add_maddpg_postprocessing,
)
class MADDPGTrainer(DQNTrainer):
@classmethod
@override(DQNTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG

@override(DQNTrainer)
def get_default_policy_class(
self, config: TrainerConfigDict
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
return MADDPGTorchPolicy
else:
return MADDPGTFPolicy
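
Below is a minimal usage sketch of the updated MADDPGTrainer, not part of the diff above. The environment name, agent IDs, and observation/action spaces are hypothetical placeholders assuming a registered two-agent multi-agent environment; swap them for your own setup.

# Minimal usage sketch (assumptions: a registered two-agent env named
# "mpe_simple_spread" and placeholder spaces; replace with your own).
import ray
from ray import tune
from gym.spaces import Box, Discrete

from maddpg import MADDPGTrainer, DEFAULT_CONFIG

# Hypothetical per-agent spaces; use the spaces your environment actually exposes.
obs_space = Box(low=-1.0, high=1.0, shape=(10,))
act_space = Discrete(5)

config = dict(
    DEFAULT_CONFIG,
    **{
        "env": "mpe_simple_spread",  # hypothetical env, registered via tune.register_env
        "framework": "torch",        # or "tf"
        "multiagent": {
            # One policy per agent; MADDPG expects the agent index in each policy's config.
            "policies": {
                "agent_0": (None, obs_space, act_space, {"agent_id": 0}),
                "agent_1": (None, obs_space, act_space, {"agent_id": 1}),
            },
            # Map each agent to the policy of the same name (assumes env agent IDs
            # are "agent_0" and "agent_1").
            "policy_mapping_fn": lambda agent_id, *args, **kwargs: agent_id,
            # MADDPG relies on lockstep replay so all agents' samples stay aligned.
            "replay_mode": "lockstep",
        },
    },
)

if __name__ == "__main__":
    ray.init()
    tune.run(MADDPGTrainer, config=config, stop={"training_iteration": 10})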