Skip to content

Commit 95b2c19

Browse files
authored
Checks for documentation (#91)
1 parent 8838301 commit 95b2c19

File tree

10 files changed: +47 −31 lines

requirements.dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pytest
22
pytest-runner
33
hypothesis == 4.38
44
flake8
5+
darglint
56
black
67
pep8-naming
78
dgl

tests/test_distributions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@
99
lint = integers(min_value=2, max_value=10)
1010

1111

12-
def enumerate_support(self, expand=True):
12+
def enumerate_support(dist):
1313
"""
1414
Compute the full exponential enumeration set.
1515
16+
Parameters:
17+
dist : Distribution
18+
1619
Returns:
1720
(enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*)
1821
"""
19-
_, _, edges, enum_lengths = test_lookup[self.struct]().enumerate(
20-
self.log_potentials, self.lengths
22+
_, _, edges, enum_lengths = test_lookup[dist.struct]().enumerate(
23+
dist.log_potentials, dist.lengths
2124
)
2225
# if expand:
2326
# edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])

torch_struct/autoregressive.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def log_prob(self, value, sparse=False):
7373
7474
Parameters:
7575
value (tensor): One-hot events (*sample_shape x batch_shape x N*)
76+
sparse (bool): implement sparse
7677
7778
Returns:
7879
log_probs (*sample_shape x batch_shape*)
@@ -207,6 +208,9 @@ def greedy_tempmax(self, alpha):
207208
208209
* Differentiable Scheduled Sampling for Credit Assignment :cite:`goyal2017differentiable`
209210
211+
Parameters:
212+
alpha : alpha param
213+
210214
Returns:
211215
greedy_path (*batch x N x C*)
212216
greedy_max (*batch*)
@@ -219,6 +223,9 @@ def beam_topk(self, K):
219223
"""
220224
Compute "top-k" using beam search
221225
226+
Parameters:
227+
K : top-k
228+
222229
Returns:
223230
paths (*K x batch x N x C*)
224231

torch_struct/cky.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
8686
scores : terms : b x n x T
8787
rules : b x NT x (NT+T) x (NT+T)
8888
root: b x NT
89-
lengths :
89+
lengths : lengths in batch
9090
9191
Returns:
9292
v: b tensor of total sum

torch_struct/deptree.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def to_parts(sequence, extra=None, lengths=None):
135135
136136
Parameters:
137137
sequence : b x N long tensor in [0, N] (indexing is +1)
138+
extra : None
139+
lengths : lengths of sequences
140+
138141
Returns:
139142
arcs : b x N x N arc indicators
140143
"""
@@ -156,6 +159,7 @@ def from_parts(arcs):
156159
157160
Parameters:
158161
arcs : b x N x N arc indicators
162+
159163
Returns:
160164
sequence : b x N long tensor in [0, N] (indexing is +1)
161165
"""
@@ -212,7 +216,9 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
212216
213217
Parameters:
214218
arc_scores : b x N x N arc scores with root scores on diagonal.
215-
semiring
219+
multi_root (bool) : multiple roots
220+
lengths : length of examples
221+
eps (float) : given
216222
217223
Returns:
218224
arc_marginals : b x N x N.

torch_struct/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,8 @@ def argmax(self):
500500
501501
(Currently not implemented)
502502
"""
503-
raise NotImplementedError()
503+
pass
504504

505505
@lazy_property
506506
def entropy(self):
507-
raise NotImplementedError()
507+
pass

torch_struct/helpers.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def _make_chart(self, N, size, potentials, force_grad=False):
6464
for _ in range(N)
6565
]
6666

67-
def sum(self, edge, lengths=None, _autograd=True, _raw=False):
67+
def sum(self, logpotentials, lengths=None, _autograd=True, _raw=False):
6868
"""
6969
Compute the (semiring) sum over all structures model.
7070
7171
Parameters:
72-
params : generic params (see class)
72+
logpotentials : generic params (see class)
7373
lengths: None or b long tensor mask
7474
7575
Returns:
@@ -82,13 +82,13 @@ def sum(self, edge, lengths=None, _autograd=True, _raw=False):
8282
or not hasattr(self, "_dp_backward")
8383
):
8484

85-
v = self._dp(edge, lengths)[0]
85+
v = self._dp(logpotentials, lengths)[0]
8686
if _raw:
8787
return v
8888
return self.semiring.unconvert(v)
8989

9090
else:
91-
v, _, alpha = self._dp(edge, lengths, False)
91+
v, _, alpha = self._dp(logpotentials, lengths, False)
9292

9393
class DPManual(Function):
9494
@staticmethod
@@ -97,20 +97,23 @@ def forward(ctx, input):
9797

9898
@staticmethod
9999
def backward(ctx, grad_v):
100-
marginals = self._dp_backward(edge, lengths, alpha)
100+
marginals = self._dp_backward(logpotentials, lengths, alpha)
101101
return marginals.mul(
102102
grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))
103103
)
104104

105-
return DPManual.apply(edge)
105+
return DPManual.apply(logpotentials)
106106

107-
def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
107+
def marginals(
108+
self, logpotentials, lengths=None, _autograd=True, _raw=False, _combine=False
109+
):
108110
"""
109111
Compute the marginals of a structured model.
110112
111113
Parameters:
112-
params : generic params (see class)
114+
logpotentials : generic params (see class)
113115
lengths: None or b long tensor mask
116+
114117
Returns:
115118
marginals: b x (N-1) x C x C table
116119
@@ -120,7 +123,7 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=Fal
120123
or self.semiring is not LogSemiring
121124
or not hasattr(self, "_dp_backward")
122125
):
123-
v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
126+
v, edges, _ = self._dp(logpotentials, lengths=lengths, force_grad=True)
124127
if _raw:
125128
all_m = []
126129
for k in range(v.shape[0]):
@@ -150,8 +153,8 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=Fal
150153
a_m = self._arrange_marginals(marg)
151154
return self.semiring.unconvert(a_m)
152155
else:
153-
v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
154-
return self._dp_backward(edge, lengths, alpha)
156+
v, _, alpha = self._dp(logpotentials, lengths=lengths, force_grad=True)
157+
return self._dp_backward(logpotentials, lengths, alpha)
155158

156159
@staticmethod
157160
def to_parts(spans, extra, lengths=None):

torch_struct/linearchain.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def to_parts(sequence, extra, lengths=None):
9292
9393
Parameters:
9494
sequence : b x N long tensor in [0, C-1]
95-
C : number of states
95+
extra : number of states
9696
lengths: b long tensor of N values
97+
9798
Returns:
9899
edge : b x (N-1) x C x C markov indicators
99100
(t x z_t x z_{t-1})
@@ -117,6 +118,7 @@ def from_parts(edge):
117118
Parameters:
118119
edge : b x (N-1) x C x C markov indicators
119120
(t x z_t x z_{t-1})
121+
120122
Returns:
121123
sequence : b x N long tensor in [0, C-1]
122124
"""

torch_struct/semimarkov.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,9 @@ def to_parts(sequence, extra, lengths=None):
120120
121121
Parameters:
122122
sequence : b x N long tensors in [-1, 0, C-1]
123-
C : number of states
123+
extra : number of states
124124
lengths: b long tensor of N values
125+
125126
Returns:
126127
edge : b x (N-1) x K x C x C semimarkov potentials
127128
(t x z_t x z_{t-1})
@@ -155,6 +156,7 @@ def from_parts(edge):
155156
Parameters:
156157
edge : b x (N-1) x K x C x C semimarkov potentials
157158
(t x z_t x z_{t-1})
159+
158160
Returns:
159161
sequence : b x N long tensors in [-1, 0, C-1]
160162

torch_struct/semirings/semirings.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,8 @@ def sum(xs, dim=-1):
148148
return torch.sum(xs, dim=dim)
149149

150150
@classmethod
151-
def matmul(cls, a, b, dims=1):
152-
"""
153-
Dot product along last dim.
154-
155-
(Faster than calling sum and times.)
156-
"""
151+
def matmul(cls, a, b):
152+
"Dot product along last dim"
157153

158154
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
159155
return b.multiply(a.transpose())
@@ -201,11 +197,7 @@ def sparse_sum(xs, dim=-1):
201197

202198

203199
def KMaxSemiring(k):
204-
"""
205-
Implements the k-max semiring (kmax, +, [-inf, -inf..], [0, -inf, ...]).
206-
207-
Gradients give k-argmax.
208-
"""
200+
"Implements the k-max semiring (kmax, +, [-inf, -inf..], [0, -inf, ...])."
209201

210202
class KMaxSemiring(_BaseLog):
211203
@staticmethod

0 commit comments

Comments (0)