Skip to content

Commit 8838301

Browse files
committed
fix spacing issues
1 parent 9083dd4 commit 8838301

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

torch_struct/deptree.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
182182
x = x.unsqueeze(2).expand(-1, -1, N)
183183
mask = torch.transpose(x, 1, 2) * x
184184
mask = mask.float()
185-
mask[mask==0] = float('-inf')
186-
mask[mask==1] = 0
185+
mask[mask == 0] = float("-inf")
186+
mask[mask == 1] = 0
187187
arc_scores = arc_scores + mask
188188
input = arc_scores
189189
eye = torch.eye(input.shape[1], device=input.device)
@@ -194,13 +194,13 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
194194
lap += det_offset
195195

196196
if multi_root:
197-
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
197+
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
198198
lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
199199
else:
200200
lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
201201
return lap.logdet()
202-
203-
202+
203+
204204
def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
205205
"""
206206
Compute the marginals of a non-projective dependency tree using the
@@ -228,10 +228,10 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
228228
x = x.unsqueeze(2).expand(-1, -1, N)
229229
mask = torch.transpose(x, 1, 2) * x
230230
mask = mask.float()
231-
mask[mask==0] = float('-inf')
232-
mask[mask==1] = 0
231+
mask[mask == 0] = float("-inf")
232+
mask[mask == 1] = 0
233233
arc_scores = arc_scores + mask
234-
234+
235235
input = arc_scores
236236
eye = torch.eye(input.shape[1], device=input.device)
237237
laplacian = input.exp() + eps
@@ -241,7 +241,7 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
241241
lap += det_offset
242242

243243
if multi_root:
244-
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
244+
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
245245
lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
246246
inv_laplacian = lap.inverse()
247247
factor = (
@@ -254,7 +254,9 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
254254
term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
255255
output = term1 - term2
256256
roots_output = (
257-
torch.diagonal(input, 0, -2, -1).exp().mul(torch.diagonal(inv_laplacian.transpose(1, 2), 0, -2, -1))
257+
torch.diagonal(input, 0, -2, -1)
258+
.exp()
259+
.mul(torch.diagonal(inv_laplacian.transpose(1, 2), 0, -2, -1))
258260
)
259261
else:
260262
lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
@@ -271,7 +273,9 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
271273
term2[:, 0] = 0
272274
output = term1 - term2
273275
roots_output = (
274-
torch.diagonal(input, 0, -2, -1).exp().mul(inv_laplacian.transpose(1, 2)[:, 0])
276+
torch.diagonal(input, 0, -2, -1)
277+
.exp()
278+
.mul(inv_laplacian.transpose(1, 2)[:, 0])
275279
)
276280
output = output + torch.diag_embed(roots_output, 0, -2, -1)
277281
return output

torch_struct/distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,10 @@ class NonProjectiveDependencyCRF(StructDistribution):
466466
Note: Does not currently implement argmax (Chiu-Liu) or sampling.
467467
468468
"""
469+
469470
def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
470471
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
471472
self.multiroot = multiroot
472-
473473

474474
@lazy_property
475475
def marginals(self):

0 commit comments

Comments
 (0)