Skip to content

Commit d2bd688

Browse files
Ervin TChris Elion
andauthored
Cherry-pick NaN fix for Match 3 (#4664)
* match3 settings * Add epsilon to log * Add another epsilon * Revert match3 configs * NaN-free masking method * Add comment for paper * Add comment for paper Co-authored-by: Chris Elion <[email protected]> Co-authored-by: Chris Elion <[email protected]>
1 parent 06cdf39 commit d2bd688

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

ml-agents/mlagents/trainers/torch/distributions.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,13 @@ def pdf(self, value):
112112
).squeeze(-1)
113113

114114
def log_prob(self, value):
115-
return torch.log(self.pdf(value))
115+
return torch.log(self.pdf(value) + EPSILON)
116116

117117
def all_log_prob(self):
118-
return torch.log(self.probs)
118+
return torch.log(self.probs + EPSILON)
119119

120120
def entropy(self):
121-
return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
121+
return -torch.sum(self.probs * torch.log(self.probs + EPSILON), dim=-1)
122122

123123

124124
class GaussianDistribution(nn.Module):
@@ -187,10 +187,13 @@ def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
187187
return nn.ModuleList(branches)
188188

189189
def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
190-
raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
191-
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
192-
normalized_logits = torch.log(normalized_probs + EPSILON)
193-
return normalized_logits
190+
# Zero out masked logits, then subtract a large value. Technique mentionend here:
191+
# https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barrcuda-friendly.
192+
flipped_mask = 1.0 - mask
193+
adj_logits = logits * mask - 1e8 * flipped_mask
194+
probs = torch.nn.functional.softmax(adj_logits, dim=-1)
195+
log_probs = torch.log(probs + EPSILON)
196+
return log_probs
194197

195198
def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
196199
split_masks = []

0 commit comments

Comments
 (0)