diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py
index 1f94992e..1bfa3a88 100644
--- a/genrl/agents/deep/a2c/a2c.py
+++ b/genrl/agents/deep/a2c/a2c.py
@@ -129,7 +129,7 @@ def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
             self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
         )

diff --git a/genrl/agents/deep/vpg/vpg.py b/genrl/agents/deep/vpg/vpg.py
index 2d6a6ef9..3b34d9d3 100644
--- a/genrl/agents/deep/vpg/vpg.py
+++ b/genrl/agents/deep/vpg/vpg.py
@@ -116,7 +116,7 @@ def get_traj_loss(self, values, dones):
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
             self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
         )

diff --git a/genrl/utils/discount.py b/genrl/utils/discount.py
index 54d8754d..aeec56ba 100644
--- a/genrl/utils/discount.py
+++ b/genrl/utils/discount.py
@@ -29,23 +29,26 @@ def compute_returns_and_advantage(
     if use_gae:
         gae_lambda = rollout_buffer.gae_lambda
     else:
-        gae_lambda = 1
+        gae_lambda = 1.0

-    next_values = last_value
-    next_non_terminal = 1 - dones
+    next_value = last_value
+    next_non_terminal = 1.0 - dones
     running_advantage = 0.0

     for step in reversed(range(rollout_buffer.buffer_size)):
         delta = (
             rollout_buffer.rewards[step]
-            + rollout_buffer.gamma * next_non_terminal * next_values
+            + rollout_buffer.gamma * next_value * next_non_terminal
             - rollout_buffer.values[step]
         )
         running_advantage = (
-            delta + rollout_buffer.gamma * gae_lambda * running_advantage
+            delta
+            + rollout_buffer.gamma * gae_lambda * next_non_terminal * running_advantage
         )
         next_non_terminal = 1 - rollout_buffer.dones[step]
-        next_values = rollout_buffer.values[step]
+        next_value = rollout_buffer.values[step]
         rollout_buffer.advantages[step] = running_advantage

     rollout_buffer.returns = rollout_buffer.advantages + rollout_buffer.values
+
+    return rollout_buffer.returns, rollout_buffer.advantages
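
For reference, the recurrence the patched compute_returns_and_advantage implements is standard GAE, with the next_non_terminal mask now applied to the advantage accumulation as well as the bootstrap term, so the running advantage resets at episode boundaries. Below is a minimal stand-alone sketch of the same loop on plain NumPy arrays; the name gae_sketch and the toy data are illustrative only, not part of genrl's API, which operates on a RolloutBuffer.

import numpy as np


def gae_sketch(rewards, values, dones, last_value, last_done,
               gamma=0.99, gae_lambda=0.95):
    """Stand-alone sketch of the patched recurrence:

    delta_t = r_t + gamma * V(s_{t+1}) * mask - V(s_t)
    A_t     = delta_t + gamma * lambda * mask * A_{t+1}
    R_t     = A_t + V(s_t)

    where mask = 1 - done zeroes both the bootstrap term and the
    accumulated advantage tail at episode boundaries.
    """
    advantages = np.zeros_like(values)
    next_value = last_value
    next_non_terminal = 1.0 - last_done
    running_advantage = 0.0
    # Walk the rollout backwards, exactly as the patched loop does.
    for step in reversed(range(len(rewards))):
        delta = (
            rewards[step]
            + gamma * next_value * next_non_terminal
            - values[step]
        )
        running_advantage = (
            delta
            + gamma * gae_lambda * next_non_terminal * running_advantage
        )
        next_non_terminal = 1.0 - dones[step]
        next_value = values[step]
        advantages[step] = running_advantage
    returns = advantages + values
    return returns, advantages


# Toy single-environment rollout: four steps, episode ends at step 2.
rewards = np.array([1.0, 1.0, 1.0, 1.0])
values = np.array([0.5, 0.6, 0.7, 0.4])
dones = np.array([0.0, 0.0, 1.0, 0.0])
returns, advantages = gae_sketch(
    rewards, values, dones, last_value=0.3, last_done=0.0
)
print(returns)
print(advantages)

With dones[2] == 1.0, the next_non_terminal factor zeroes the accumulated tail at that boundary (following the buffer's own dones convention), so advantages from the later episode no longer bleed into the earlier one, which is the behavioral change the added factor in the patch enforces.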