Skip to content

Commit

Permalink
Merge pull request #5 from braun-steven/fix/categorical-suff-stats-dim-3
Browse files Browse the repository at this point in the history
Fix missing num_var parameter in suff_stat reshape
  • Loading branch information
smatmo authored Aug 22, 2023
2 parents dfeb10a + bbf6a6d commit 6d963e7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/EinsumNetwork/ExponentialFamilyArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def sufficient_statistics(self, x):
if len(x.shape) == 2:
stats = one_hot(x.long(), self.K)
elif len(x.shape) == 3:
stats = one_hot(x.long(), self.K).reshape(-1, self.num_dims * self.K)
stats = one_hot(x.long(), self.K).reshape(-1, self.num_var, self.num_dims * self.K)
else:
raise AssertionError("Input must be 2 or 3 dimensional tensor.")
return stats
Expand Down

0 comments on commit 6d963e7

Please sign in to comment.