@@ -64,12 +64,12 @@ def _make_chart(self, N, size, potentials, force_grad=False):
64
64
for _ in range (N )
65
65
]
66
66
67
- def sum (self , edge , lengths = None , _autograd = True , _raw = False ):
67
+ def sum (self , logpotentials , lengths = None , _autograd = True , _raw = False ):
68
68
"""
69
69
Compute the (semiring) sum over all structures model.
70
70
71
71
Parameters:
72
- params : generic params (see class)
72
+ logpotentials : generic params (see class)
73
73
lengths: None or b long tensor mask
74
74
75
75
Returns:
@@ -82,13 +82,13 @@ def sum(self, edge, lengths=None, _autograd=True, _raw=False):
82
82
or not hasattr (self , "_dp_backward" )
83
83
):
84
84
85
- v = self ._dp (edge , lengths )[0 ]
85
+ v = self ._dp (logpotentials , lengths )[0 ]
86
86
if _raw :
87
87
return v
88
88
return self .semiring .unconvert (v )
89
89
90
90
else :
91
- v , _ , alpha = self ._dp (edge , lengths , False )
91
+ v , _ , alpha = self ._dp (logpotentials , lengths , False )
92
92
93
93
class DPManual (Function ):
94
94
@staticmethod
@@ -97,20 +97,23 @@ def forward(ctx, input):
97
97
98
98
@staticmethod
99
99
def backward (ctx , grad_v ):
100
- marginals = self ._dp_backward (edge , lengths , alpha )
100
+ marginals = self ._dp_backward (logpotentials , lengths , alpha )
101
101
return marginals .mul (
102
102
grad_v .view ((grad_v .shape [0 ],) + tuple ([1 ] * marginals .dim ()))
103
103
)
104
104
105
- return DPManual .apply (edge )
105
+ return DPManual .apply (logpotentials )
106
106
107
- def marginals (self , edge , lengths = None , _autograd = True , _raw = False , _combine = False ):
107
+ def marginals (
108
+ self , logpotentials , lengths = None , _autograd = True , _raw = False , _combine = False
109
+ ):
108
110
"""
109
111
Compute the marginals of a structured model.
110
112
111
113
Parameters:
112
- params : generic params (see class)
114
+ logpotentials : generic params (see class)
113
115
lengths: None or b long tensor mask
116
+
114
117
Returns:
115
118
marginals: b x (N-1) x C x C table
116
119
@@ -120,7 +123,7 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=Fal
120
123
or self .semiring is not LogSemiring
121
124
or not hasattr (self , "_dp_backward" )
122
125
):
123
- v , edges , _ = self ._dp (edge , lengths = lengths , force_grad = True )
126
+ v , edges , _ = self ._dp (logpotentials , lengths = lengths , force_grad = True )
124
127
if _raw :
125
128
all_m = []
126
129
for k in range (v .shape [0 ]):
@@ -150,8 +153,8 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=Fal
150
153
a_m = self ._arrange_marginals (marg )
151
154
return self .semiring .unconvert (a_m )
152
155
else :
153
- v , _ , alpha = self ._dp (edge , lengths = lengths , force_grad = True )
154
- return self ._dp_backward (edge , lengths , alpha )
156
+ v , _ , alpha = self ._dp (logpotentials , lengths = lengths , force_grad = True )
157
+ return self ._dp_backward (logpotentials , lengths , alpha )
155
158
156
159
@staticmethod
157
160
def to_parts (spans , extra , lengths = None ):
0 commit comments