genrl/agents/deep/a2c/a2c.py (1 addition, 1 deletion)

@@ -129,7 +129,7 @@ def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
             self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
         )
genrl/agents/deep/vpg/vpg.py (1 addition, 1 deletion)

@@ -116,7 +116,7 @@ def get_traj_loss(self, values, dones):
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
             self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
         )
genrl/utils/discount.py (9 additions, 6 deletions)

@@ -29,23 +29,26 @@ def compute_returns_and_advantage(
     if use_gae:
         gae_lambda = rollout_buffer.gae_lambda
     else:
-        gae_lambda = 1
+        gae_lambda = 1.0
 
-    next_values = last_value
-    next_non_terminal = 1 - dones
+    next_value = last_value
+    next_non_terminal = 1.0 - dones
 
     running_advantage = 0.0
     for step in reversed(range(rollout_buffer.buffer_size)):
         delta = (
             rollout_buffer.rewards[step]
-            + rollout_buffer.gamma * next_non_terminal * next_values
+            + rollout_buffer.gamma * next_value * next_non_terminal
             - rollout_buffer.values[step]
         )
         running_advantage = (
-            delta + rollout_buffer.gamma * gae_lambda * running_advantage
+            delta
+            + rollout_buffer.gamma * gae_lambda * next_non_terminal * running_advantage
         )
         next_non_terminal = 1 - rollout_buffer.dones[step]
-        next_values = rollout_buffer.values[step]
+        next_value = rollout_buffer.values[step]
         rollout_buffer.advantages[step] = running_advantage
 
     rollout_buffer.returns = rollout_buffer.advantages + rollout_buffer.values
+
+    return rollout_buffer.returns, rollout_buffer.advantages
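For reference, below is a minimal, self-contained sketch of the same GAE recursion on plain NumPy arrays. The function name gae_returns and its arguments are illustrative, not part of the genrl API; it only mirrors what the patched loop in discount.py does: the next_non_terminal mask now multiplies both the bootstrap term inside delta and the carried running_advantage, so neither bootstrapping nor advantage accumulation crosses an episode boundary.

```python
import numpy as np


def gae_returns(rewards, values, dones, last_value, last_done,
                gamma=0.99, gae_lambda=0.95):
    """Illustrative GAE for a single-environment rollout of length T.

    `rewards`, `values`, `dones` are length-T sequences; `last_value` and
    `last_done` describe the step that follows the end of the rollout.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    dones = np.asarray(dones, dtype=np.float64)

    advantages = np.zeros_like(rewards)
    next_value = last_value
    next_non_terminal = 1.0 - last_done
    running_advantage = 0.0

    for step in reversed(range(len(rewards))):
        # One-step TD error; the bootstrap term is masked out at episode ends.
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        # The same mask keeps the accumulated advantage from leaking across
        # an episode boundary.
        running_advantage = (
            delta + gamma * gae_lambda * next_non_terminal * running_advantage
        )
        advantages[step] = running_advantage
        next_non_terminal = 1.0 - dones[step]
        next_value = values[step]

    returns = advantages + values
    return returns, advantages
```

Setting gae_lambda to 1.0, as the else branch in the patch does, collapses the recursion to discounted returns minus the value baseline; smaller values of gae_lambda trade variance for bias in the advantage estimates.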