@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
494494 def reset (self ) -> None :
495495 pass
496496
497- def get_entropy_bonus (self , dist : d .Distribution ) -> torch .Tensor :
497+ def _get_entropy (self , dist : d .Distribution ) -> torch .Tensor | TensorDict :
498498 try :
499499 entropy = dist .entropy ()
500500 except NotImplementedError :
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
513513 log_prob = log_prob .select (* self .tensor_keys .sample_log_prob )
514514
515515 entropy = - log_prob .mean (0 )
516- if is_tensor_collection (entropy ):
517- entropy = _sum_td_features (entropy )
518516 return entropy .unsqueeze (- 1 )
519517
520518 def _log_weight (
521519 self , tensordict : TensorDictBase
522- ) -> Tuple [torch .Tensor , d .Distribution ]:
520+ ) -> Tuple [torch .Tensor , d .Distribution , torch . Tensor ]:
523521
524522 with self .actor_network_params .to_module (
525523 self .actor_network
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
681679 log_weight = log_weight .view (advantage .shape )
682680 neg_loss = log_weight .exp () * advantage
683681 td_out = TensorDict ({"loss_objective" : - neg_loss }, batch_size = [])
682+ td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
684683 if self .entropy_bonus :
685- entropy = self .get_entropy_bonus (dist )
684+ entropy = self ._get_entropy (dist )
685+ if is_tensor_collection (entropy ):
686+ # Reports the entropy of each action head.
687+ td_out .set ("composite_entropy" , entropy .detach ())
688+ entropy = _sum_td_features (entropy )
686689 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
687- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
688690 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
689691 if self .critic_coef is not None :
690692 loss_critic , value_clip_fraction = self .loss_critic (tensordict )
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
956958 # ESS for logging
957959 with torch .no_grad ():
958960 # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
959- # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
960- # of the weights .
961+ # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
962+ # dispersion .
961963 lw = log_weight .squeeze ()
962964 if not isinstance (lw , torch .Tensor ):
963965 lw = _sum_td_features (lw )
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
976978 gain = _sum_td_features (gain )
977979 td_out = TensorDict ({"loss_objective" : - gain }, batch_size = [])
978980 td_out .set ("clip_fraction" , clip_fraction )
981+ td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
979982
980983 if self .entropy_bonus :
981- entropy = self .get_entropy_bonus (dist )
984+ entropy = self ._get_entropy (dist )
985+ if is_tensor_collection (entropy ):
986+ # Reports the entropy of each action head.
987+ td_out .set ("composite_entropy" , entropy .detach ())
988+ entropy = _sum_td_features (entropy )
982989 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
983- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
984990 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
985991 if self .critic_coef is not None :
986992 loss_critic , value_clip_fraction = self .loss_critic (tensordict )
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
12821288 {
12831289 "loss_objective" : - neg_loss ,
12841290 "kl" : kl .detach (),
1291+ "kl_approx" : kl_approx .detach ().mean (),
12851292 },
12861293 batch_size = [],
12871294 )
12881295
12891296 if self .entropy_bonus :
1290- entropy = self .get_entropy_bonus (dist )
1297+ entropy = self ._get_entropy (dist )
1298+ if is_tensor_collection (entropy ):
1299+ # Reports the entropy of each action head.
1300+ td_out .set ("composite_entropy" , entropy .detach ())
1301+ entropy = _sum_td_features (entropy )
12911302 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1292- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
12931303 td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
12941304 if self .critic_coef is not None :
12951305 loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
0 commit comments