From 373c52c8e26f6224316069be1d0838fc11216acc Mon Sep 17 00:00:00 2001
From: Ruofan Kong
Date: Mon, 22 Jun 2020 19:26:57 -0700
Subject: [PATCH 1/2] No Case: Add Multivariate Diag Gaussian distribution.

---
 .../ray/rllib/agents/ppo/ppo_policy_graph.py |  9 +++--
 python/ray/rllib/models/action_dist.py       | 40 +++++++++++++++++++
 python/ray/rllib/models/catalog.py           | 15 ++++++-
 3 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index b352f5608479..4bff7ba13119 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -39,7 +39,8 @@ def __init__(self,
                  clip_param=0.1,
                  vf_clip_param=0.1,
                  vf_loss_coeff=1.0,
-                 use_gae=True):
+                 use_gae=True,
+                 model_config={}):
         """Constructs the loss for Proximal Policy Objective.
 
         Arguments:
@@ -70,7 +71,7 @@ def __init__(self,
         def reduce_mean_valid(t):
             return tf.reduce_mean(tf.boolean_mask(t, valid_mask))
 
-        dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
+        dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
         prev_dist = dist_cls(logits)
         # Make loss functions.
         logp_ratio = tf.exp(
@@ -284,7 +285,9 @@ def __init__(self,
             clip_param=self.config["clip_param"],
             vf_clip_param=self.config["vf_clip_param"],
             vf_loss_coeff=self.config["vf_loss_coeff"],
-            use_gae=self.config["use_gae"])
+            use_gae=self.config["use_gae"],
+            model_config=self.config["model"]
+        )
 
         LearningRateSchedule.__init__(self, self.config["lr"],
                                       self.config["lr_schedule"])
diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py
index 026a6c493e5c..5ba0857d4c20 100644
--- a/python/ray/rllib/models/action_dist.py
+++ b/python/ray/rllib/models/action_dist.py
@@ -6,6 +6,7 @@ import distutils.version
 
 import tensorflow as tf
 import numpy as np
+from tensorflow.contrib.distributions import MultivariateNormalDiag
 
 from ray.rllib.utils.annotations import override, DeveloperAPI
 
@@ -139,6 +140,45 @@ def _build_sample_op(self):
         return tf.stack([cat.sample() for cat in self.cats], axis=1)
 
 
+class MultiVariateDiagGaussian(ActionDistribution):
+    """
+    Action distribution where each vector element is a gaussian with
+    its own independent mean and standard deviation (diagonal covariance).
+    """
+    def __init__(self, inputs):
+        mean, log_std = tf.split(inputs, 2, axis=1)
+        self.mean = mean
+        self.log_std = log_std
+        self.std = tf.exp(log_std)
+        self.distribution = MultivariateNormalDiag(loc=self.mean, scale_diag=self.std)
+        ActionDistribution.__init__(self, inputs)
+
+    @override(ActionDistribution)
+    def logp(self, x):
+        return self.distribution.log_prob(x)
+
+    @override(ActionDistribution)
+    def kl(self, other):
+        if not isinstance(other, MultiVariateDiagGaussian):
+            raise TypeError(
+                "Argument other expected type MultiVariateDiagGaussian. "
+                "Received type {}.".format(type(other))
+            )
+        return tf.reduce_sum(
+            other.log_std - self.log_std +
+            (tf.square(self.std) + tf.square(self.mean - other.mean)) /
+            (2.0 * tf.square(other.std)) - 0.5,
+            reduction_indices=[1])
+
+    @override(ActionDistribution)
+    def entropy(self):
+        return self.distribution.entropy()
+
+    @override(ActionDistribution)
+    def _build_sample_op(self):
+        return self.distribution.sample()
+
+
 class DiagGaussian(ActionDistribution):
     """Action distribution where each vector element is a gaussian.
 
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 776773552df1..667d336db486 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -14,7 +14,8 @@ from ray.rllib.models.extra_spaces import Simplex
 from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
                                           Deterministic, DiagGaussian,
-                                          MultiActionDistribution, Dirichlet)
+                                          MultiActionDistribution, Dirichlet,
+                                          MultiVariateDiagGaussian)
 from ray.rllib.models.torch_action_dist import (TorchCategorical,
                                                 TorchDiagGaussian)
 from ray.rllib.models.preprocessors import get_preprocessor
@@ -114,7 +115,17 @@ def get_action_dist(action_space, config, dist_type=None, torch=False):
                     "Consider reshaping this into a single dimension, "
                     "using a Tuple action space, or the multi-agent API.")
             if dist_type is None:
-                dist = TorchDiagGaussian if torch else DiagGaussian
+                if torch:
+                    dist = TorchDiagGaussian
+                else:
+                    custom_options = config.get("custom_options")
+                    if custom_options is None:
+                        dist = DiagGaussian
+                    else:
+                        if custom_options.get("use_multi_variate_normal_diag") is None:
+                            dist = DiagGaussian
+                        else:
+                            dist = MultiVariateDiagGaussian
                 if config.get("squash_to_range"):
                     raise ValueError(
                         "The squash_to_range option is deprecated. See the "

From 07a56025c433ad868df5da5935b7c61fc6f5a204 Mon Sep 17 00:00:00 2001
From: Ruofan Kong
Date: Mon, 22 Jun 2020 22:13:26 -0700
Subject: [PATCH 2/2] use tfp.

---
 python/ray/rllib/models/action_dist.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py
index 5ba0857d4c20..790d96ad58c4 100644
--- a/python/ray/rllib/models/action_dist.py
+++ b/python/ray/rllib/models/action_dist.py
@@ -5,8 +5,8 @@ from collections import namedtuple
 import distutils.version
 
 import tensorflow as tf
+import tensorflow_probability as tfp
 import numpy as np
-from tensorflow.contrib.distributions import MultivariateNormalDiag
 
 from ray.rllib.utils.annotations import override, DeveloperAPI
 
@@ -147,10 +147,9 @@ class MultiVariateDiagGaussian(ActionDistribution):
     """
     def __init__(self, inputs):
         mean, log_std = tf.split(inputs, 2, axis=1)
-        self.mean = mean
-        self.log_std = log_std
-        self.std = tf.exp(log_std)
-        self.distribution = MultivariateNormalDiag(loc=self.mean, scale_diag=self.std)
+        std = tf.exp(log_std)
+        self.distribution = tfp.distributions.MultivariateNormalDiag(
+            loc=mean, scale_diag=std)
         ActionDistribution.__init__(self, inputs)
 
     @override(ActionDistribution)
@@ -164,11 +163,7 @@ def kl(self, other):
                 "Argument other expected type MultiVariateDiagGaussian. "
                 "Received type {}.".format(type(other))
             )
-        return tf.reduce_sum(
-            other.log_std - self.log_std +
-            (tf.square(self.std) + tf.square(self.mean - other.mean)) /
-            (2.0 * tf.square(other.std)) - 0.5,
-            reduction_indices=[1])
+        return self.distribution.kl_divergence(other.distribution)
 
     @override(ActionDistribution)
     def entropy(self):
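
Usage sketch (not part of the patches): with the catalog.py change in
PATCH 1/2, a user opts into the new distribution by putting any non-None
value under the model config's "custom_options" key. Below is a minimal
illustration against the Ray 0.6-era ModelCatalog API; the Box shape and
the option value True are arbitrary choices for the example:

    import gym.spaces
    import numpy as np

    from ray.rllib.models.catalog import ModelCatalog

    # A 1-D continuous action space; get_action_dist() only maps Box
    # spaces with a single dimension to a diag-Gaussian distribution.
    action_space = gym.spaces.Box(
        low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    model_config = {
        "custom_options": {
            # Any value other than None selects MultiVariateDiagGaussian.
            "use_multi_variate_normal_diag": True,
        },
    }

    dist_cls, dist_dim = ModelCatalog.get_action_dist(
        action_space, model_config)
    # dist_cls is MultiVariateDiagGaussian and dist_dim is 2 * 3: the
    # network must emit concatenated means and log-stds, exactly as for
    # DiagGaussian (see tf.split(inputs, 2, axis=1) in __init__).

The same model config reaches the loss class in ppo_policy_graph.py
through the new model_config argument wired up in PATCH 1/2.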
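
As a sanity check on PATCH 2/2, tfp's registered closed-form KL for
MultivariateNormalDiag should reproduce the hand-rolled reduce_sum
expression the patch deletes. A standalone sketch, assuming TF1-style
sessions and tensorflow_probability installed; all tensor values are
made up:

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    mean_a = np.zeros((1, 3), dtype=np.float32)
    log_std_a = np.zeros((1, 3), dtype=np.float32)
    mean_b = np.ones((1, 3), dtype=np.float32)
    log_std_b = np.full((1, 3), 0.5, dtype=np.float32)
    std_a, std_b = np.exp(log_std_a), np.exp(log_std_b)

    dist_a = tfp.distributions.MultivariateNormalDiag(
        loc=mean_a, scale_diag=std_a)
    dist_b = tfp.distributions.MultivariateNormalDiag(
        loc=mean_b, scale_diag=std_b)
    kl_tfp = dist_a.kl_divergence(dist_b)

    # The deleted expression, term for term: log(s_b / s_a) +
    # (s_a^2 + (m_a - m_b)^2) / (2 * s_b^2) - 1/2, summed over the
    # action dimension.
    kl_manual = tf.reduce_sum(
        log_std_b - log_std_a +
        (np.square(std_a) + np.square(mean_a - mean_b)) /
        (2.0 * np.square(std_b)) - 0.5,
        axis=1)

    with tf.Session() as sess:
        print(sess.run([kl_tfp, kl_manual]))  # both ~[1.1036]

Delegating to kl_divergence also drops the deprecated reduction_indices
argument along the way.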