@@ -182,8 +182,8 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
182
182
x = x .unsqueeze (2 ).expand (- 1 , - 1 , N )
183
183
mask = torch .transpose (x , 1 , 2 ) * x
184
184
mask = mask .float ()
185
- mask [mask == 0 ] = float (' -inf' )
186
- mask [mask == 1 ] = 0
185
+ mask [mask == 0 ] = float (" -inf" )
186
+ mask [mask == 1 ] = 0
187
187
arc_scores = arc_scores + mask
188
188
input = arc_scores
189
189
eye = torch .eye (input .shape [1 ], device = input .device )
@@ -194,13 +194,13 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
194
194
lap += det_offset
195
195
196
196
if multi_root :
197
- rss = torch .diagonal (input , 0 , - 2 , - 1 ).exp () # root selection scores
197
+ rss = torch .diagonal (input , 0 , - 2 , - 1 ).exp () # root selection scores
198
198
lap = lap + torch .diag_embed (rss , offset = 0 , dim1 = - 2 , dim2 = - 1 )
199
199
else :
200
200
lap [:, 0 ] = torch .diagonal (input , 0 , - 2 , - 1 ).exp ()
201
201
return lap .logdet ()
202
-
203
-
202
+
203
+
204
204
def deptree_nonproj (arc_scores , multi_root , lengths , eps = 1e-5 ):
205
205
"""
206
206
Compute the marginals of a non-projective dependency tree using the
@@ -228,10 +228,10 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
228
228
x = x .unsqueeze (2 ).expand (- 1 , - 1 , N )
229
229
mask = torch .transpose (x , 1 , 2 ) * x
230
230
mask = mask .float ()
231
- mask [mask == 0 ] = float (' -inf' )
232
- mask [mask == 1 ] = 0
231
+ mask [mask == 0 ] = float (" -inf" )
232
+ mask [mask == 1 ] = 0
233
233
arc_scores = arc_scores + mask
234
-
234
+
235
235
input = arc_scores
236
236
eye = torch .eye (input .shape [1 ], device = input .device )
237
237
laplacian = input .exp () + eps
@@ -241,7 +241,7 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
241
241
lap += det_offset
242
242
243
243
if multi_root :
244
- rss = torch .diagonal (input , 0 , - 2 , - 1 ).exp () # root selection scores
244
+ rss = torch .diagonal (input , 0 , - 2 , - 1 ).exp () # root selection scores
245
245
lap = lap + torch .diag_embed (rss , offset = 0 , dim1 = - 2 , dim2 = - 1 )
246
246
inv_laplacian = lap .inverse ()
247
247
factor = (
@@ -254,7 +254,9 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
254
254
term2 = input .exp ().mul (inv_laplacian .transpose (1 , 2 )).clone ()
255
255
output = term1 - term2
256
256
roots_output = (
257
- torch .diagonal (input , 0 , - 2 , - 1 ).exp ().mul (torch .diagonal (inv_laplacian .transpose (1 , 2 ), 0 , - 2 , - 1 ))
257
+ torch .diagonal (input , 0 , - 2 , - 1 )
258
+ .exp ()
259
+ .mul (torch .diagonal (inv_laplacian .transpose (1 , 2 ), 0 , - 2 , - 1 ))
258
260
)
259
261
else :
260
262
lap [:, 0 ] = torch .diagonal (input , 0 , - 2 , - 1 ).exp ()
@@ -271,7 +273,9 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
271
273
term2 [:, 0 ] = 0
272
274
output = term1 - term2
273
275
roots_output = (
274
- torch .diagonal (input , 0 , - 2 , - 1 ).exp ().mul (inv_laplacian .transpose (1 , 2 )[:, 0 ])
276
+ torch .diagonal (input , 0 , - 2 , - 1 )
277
+ .exp ()
278
+ .mul (inv_laplacian .transpose (1 , 2 )[:, 0 ])
275
279
)
276
280
output = output + torch .diag_embed (roots_output , 0 , - 2 , - 1 )
277
281
return output
0 commit comments