
Commit

Update mappolag_trainer.py
chauncygu authored Nov 14, 2022
1 parent 7985d5f · commit ab29976
Showing 1 changed file with 2 additions and 2 deletions.
dexteroushandenvs/algorithms/algorithms/mappolag_trainer.py (2 additions, 2 deletions)
@@ -221,7 +221,7 @@ def ppo_update(self, sample, update_actor=True, precomputed_eval=None,
         share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
         value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
         adv_targ, available_actions_batch, factor_batch, cost_preds_batch, cost_returns_barch, rnn_states_cost_batch, \
-        cost_adv_targ = sample
+        cost_adv_targ, aver_episode_costs = sample

         old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
         adv_targ = check(adv_targ).to(**self.tpdv)
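The first hunk threads a new quantity, aver_episode_costs, through the mini-batch tuple: the unpacking grows from 17 to 18 elements, so whatever generator builds `sample` must now append that value or the assignment raises a "not enough values to unpack" error. A minimal sketch of that contract, using a hypothetical generator name and dummy data rather than the repository's actual buffer code:

```python
# Hypothetical sketch (not the repo's buffer API): the generator that feeds
# ppo_update must now yield one extra trailing element, aver_episode_costs.
def sample_generator_sketch(mini_batches, aver_episode_costs):
    for mini_batch in mini_batches:
        # mini_batch is assumed to hold the original 17 fields in order;
        # append the average episode cost so the 18-way unpacking succeeds.
        yield (*mini_batch, aver_episode_costs)
```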
@@ -278,7 +278,7 @@ def ppo_update(self, sample, update_actor=True, precomputed_eval=None,
         self.policy.actor_optimizer.step()

         # todo: update lamda_lagr
-        delta_lamda_lagr = -((cost_values - self.safety_bound) * (1 - self.gamma) + (imp_weights * cost_adv_targ)).mean().detach()
+        delta_lamda_lagr = -((aver_episode_costs.mean() - self.safety_bound) * (1 - self.gamma) + (imp_weights * cost_adv_targ)).mean().detach()

         R_Relu = torch.nn.ReLU()
         new_lamda_lagr = R_Relu(self.lamda_lagr - (delta_lamda_lagr * self.lagrangian_coef))
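The second hunk changes the signal driving the Lagrange multiplier: the constraint violation is now measured with the batch's average episode cost instead of the critic's cost value predictions, and the multiplier is then clamped at zero through the ReLU. A minimal, self-contained sketch of that projected update, with a hypothetical function wrapper and dummy tensors standing in for the real training state:

```python
import torch

def lagrangian_update(lamda_lagr, aver_episode_costs, imp_weights, cost_adv_targ,
                      safety_bound, gamma, lagrangian_coef):
    # delta is negative when the average episode cost exceeds the safety bound
    # (scaled by 1 - gamma) plus the importance-weighted cost advantage, so
    # subtracting it below raises lambda, i.e. strengthens the cost penalty.
    delta = -((aver_episode_costs.mean() - safety_bound) * (1 - gamma)
              + (imp_weights * cost_adv_targ)).mean().detach()
    # Projected step on lambda, clamped at zero via ReLU.
    return torch.relu(lamda_lagr - delta * lagrangian_coef)

# Illustrative call with dummy values (not taken from the repository).
new_lamda = lagrangian_update(
    lamda_lagr=torch.tensor(1.0),
    aver_episode_costs=torch.tensor([12.0, 9.5]),
    imp_weights=torch.ones(32, 1),
    cost_adv_targ=torch.zeros(32, 1),
    safety_bound=10.0,
    gamma=0.99,
    lagrangian_coef=0.01,
)
```

Clamping through ReLU keeps the multiplier nonnegative, so the cost term can only penalize constraint violations, never reward them.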
