    if self.training:
        # We do the expiry of code at that point as buffers are in sync
        # and all the workers will take the same decision.
        self.expire_codes_(x)
        ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
        embed_sum = x.t() @ embed_onehot
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
        cluster_size = (
            laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
            * self.cluster_size.sum()
        )
        embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
        self.embed.data.copy_(embed_normalized)
self.embed_avg can contain 0, and 0 can also appear in the computed cluster_size. When that happens, `self.embed_avg / cluster_size.unsqueeze(1)` divides by zero. To guard against this, clamp cluster_size before the division:
    cluster_size = torch.clamp(cluster_size, min=self.epsilon)
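A minimal sketch of the failure and the proposed clamp, outside the class. The `laplace_smoothing` body below is an assumption (the issue does not show it), and the tensor shapes and the all-zero starting state are chosen only to trigger the problem:

```python
import torch


def laplace_smoothing(x, n_categories, epsilon=1e-5):
    # Assumed definition, not shown in the issue: additive smoothing of counts.
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


codebook_size, dim, epsilon = 4, 3, 1e-5

# Illustrative worst case: no code has been used, so every EMA count is 0.
cluster_size_ema = torch.zeros(codebook_size)
embed_avg = torch.zeros(codebook_size, dim)

# Same expression as in the snippet above: the smoothed counts are multiplied
# by the total count, which is 0 here, so every entry becomes 0.
cluster_size = (
    laplace_smoothing(cluster_size_ema, codebook_size, epsilon)
    * cluster_size_ema.sum()
)

# Without the clamp: 0 / 0 yields NaN in every codebook entry.
unclamped = embed_avg / cluster_size.unsqueeze(1)

# With the clamp: the denominator is at least epsilon, so the result is finite.
clamped_size = torch.clamp(cluster_size, min=epsilon)
clamped = embed_avg / clamped_size.unsqueeze(1)

print(torch.isnan(unclamped).any().item(), torch.isfinite(clamped).all().item())
```

Under this assumed definition, the denominator only reaches 0 when the total count is 0 or epsilon is 0, so the clamp acts as a safety net and should leave normal training steps unchanged.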