
Commit b942a48

Add max gap for alignment (#41)
1 parent 34517ec commit b942a48

File tree: 3 files changed (+9 −3 lines)

torch_struct/alignment.py

Lines changed: 7 additions & 1 deletion
@@ -14,10 +14,13 @@
 
 
 class Alignment(_Struct):
-    def __init__(self, semiring=LogSemiring, sparse_rounds=3, local=False):
+    def __init__(
+        self, semiring=LogSemiring, sparse_rounds=3, max_gap=None, local=False
+    ):
         self.semiring = semiring
         self.sparse_rounds = sparse_rounds
         self.local = local
+        self.max_gap = max_gap
 
     def _check_potentials(self, edge, lengths=None):
         batch, N_1, M_1, x = edge.shape
@@ -171,6 +174,9 @@ def pad(v):
 
         for n in range(2, log_MN + 1):
            chart = merge(chart)
+            center = (chart.shape[-1] - 1) // 2
+            if self.max_gap is not None and center > self.max_gap:
+                chart = chart[..., center - self.max_gap : center + self.max_gap + 1]
 
        if self.local:
            v = semiring.sum(semiring.sum(chart[..., 0, Close, Close, Mid, :, :]))
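The chart's last dimension indexes the signed offset between the two sequences, with the on-diagonal case at the center, and each merge doubles that band. Slicing a window of half-width max_gap around the center after every merge caps the chart at 2 * max_gap + 1 entries, pruning alignments whose gap exceeds the bound. A minimal sketch of just that pruning step on a dummy tensor (shape and values are illustrative, not the library's real chart layout):

import torch

max_gap = 1
# Dummy chart: last dimension indexes offsets -4..4 (center index 4).
chart = torch.randn(2, 3, 9)

center = (chart.shape[-1] - 1) // 2
if max_gap is not None and center > max_gap:
    # Keep only offsets in [-max_gap, max_gap].
    chart = chart[..., center - max_gap : center + max_gap + 1]

print(chart.shape)  # torch.Size([2, 3, 3])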

torch_struct/semirings/checkpoint.py

Lines changed: 0 additions & 2 deletions
@@ -27,7 +27,6 @@ def forward(ctx, a, b):
 
         @staticmethod
         def backward(ctx, grad_output):
-            print("check", grad_output.shape)
             a, b = ctx.saved_tensors
             with torch.enable_grad():
                 q = cls.matmul(a, b)
@@ -43,7 +42,6 @@ def forward(ctx, a, a_lu, a_ld, b, b_lu, b_ld):
 
         @staticmethod
         def backward(ctx, grad_output):
-            print("check_sparse", grad_output.shape)
             a, b, bands = ctx.saved_tensors
             a_lu, a_ld, b_lu, b_ld = bands.tolist()
             with torch.enable_grad():

torch_struct/test_algorithms.py

Lines changed: 2 additions & 0 deletions
@@ -365,6 +365,8 @@ def test_alignment(data):
     print(alpha, count)
     print(mx[0].nonzero())
     # assert torch.isclose(count, alpha).all()
+    struct = model(semiring, max_gap=1)
+    alpha = struct.sum(vals)
 
 
 def test_hmm():
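Mirroring the test above, a hedged end-to-end usage sketch of the new argument. The potentials shape follows _check_potentials (batch × N × M × ops), but the sizes here are made up and the call is untested against a specific torch_struct version:

import torch
from torch_struct.alignment import Alignment
from torch_struct.semirings import LogSemiring

# Hypothetical potentials: batch=2, N=6, M=5, 3 alignment ops (shape assumed).
vals = torch.randn(2, 6, 5, 3)

struct = Alignment(LogSemiring, max_gap=1)  # band the DP chart as in this commit
alpha = struct.sum(vals)  # log-partition over alignments within the gap bound
print(alpha.shape)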
