
Differentiation rule for tridiagonal_solve #25693

Open
segasai opened this issue Dec 29, 2024 · 2 comments

segasai commented Dec 29, 2024

Hi,

I don't think this is a bug report so much as a feature request. I recently tried to switch from my own implementation of a tridiagonal solve using the Thomas algorithm to jax.lax.linalg.tridiagonal_solve, and I discovered that the tridiagonal_solve implementation in JAX does not seem to support differentiation. The code given below raises this error:

  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 395, in process_primitive
    raise NotImplementedError(msg)
NotImplementedError: Differentiation rule for 'tridiagonal_solve' not implemented

I also did not see any mention in the docs that differentiation is not supported.

I am using Python 3.10 and jax 0.38 on CPU.

If this is not an oversight, and it is not too difficult to implement differentiation for tridiagonal_solve, I could take a look at doing that myself if I get some pointers on where to look.

Thanks

The reproducer (together with my own implementation of a tridiagonal solve that does support differentiation):

import jax
import jax.numpy as jnp
import numpy as np
import jax.lax.linalg as jll


def func1(carry, x):
    # Forward sweep of the Thomas algorithm: eliminate the subdiagonal.
    oldc_, oldd_ = carry
    a, b, c, d = x
    cdash = c / (b - a * oldc_)
    ddash = (d - a * oldd_) / (b - a * oldc_)
    return (cdash, ddash), (cdash, ddash)


def func2(carry, xx):
    # Backward-substitution sweep of the Thomas algorithm.
    xold = carry
    c, d = xx
    xnew = d - c * xold
    return xnew, xnew


def solver(a, b, c, d):
    """
    Solve A x = d, where A has (a, b, c) on its lower/main/upper diagonals,
    using the Thomas algorithm.
    """
    cdnew = jax.lax.scan(func1, (0, 0),
                         jnp.transpose(jnp.vstack((a, b, c, d)), (1, 0)))[1]
    xnew = jax.lax.scan(
        func2, (0),
        jnp.transpose(jnp.vstack((cdnew[0], cdnew[1])), (1, 0))[::-1, :])[-1]
    return xnew[::-1]


if __name__ == '__main__':
    np.random.seed(43)
    a, b, c, d = np.random.normal(size=(4, 3))
    a[0] = 0
    c[-1] = 0
    print('x', solver(a, b, c, d))
    a, b, c, d = [jnp.array(_) for _ in [a, b, c, d]]
    print('y', jll.tridiagonal_solve(a, b, c, d[:, None])[:, 0])

    def func(x):
        ret = jnp.sum(jll.tridiagonal_solve(a, b, c, x[:, None]))
        return ret

    JG = jax.grad(func, [0])
    JG(d)
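
For comparison, differentiating through the scan-based solver above works directly. An illustrative check that could be appended to the script (func_own is just a throwaway name):

    # Gradients flow through the scan-based solver because jax.lax.scan
    # is differentiable (illustrative check, appended to the script above).
    def func_own(x):
        return jnp.sum(solver(a, b, c, x))

    print('grad through own solver', jax.grad(func_own)(d))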
@segasai segasai added the enhancement New feature or request label Dec 29, 2024
@patrick-kidger
Collaborator

Take a look at Lineax, which should support this :)
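
For example, something like this (a sketch from memory, so treat the exact names and argument order as assumptions to check against the Lineax docs):

import jax.numpy as jnp
import lineax as lx

# Sketch from memory: TridiagonalLinearOperator's argument order and the
# Tridiagonal solver name should be verified against the Lineax documentation.
n = 5
diag = 4.0 * jnp.ones(n)
lower = jnp.ones(n - 1)  # Lineax off-diagonals have length n - 1
upper = jnp.ones(n - 1)
b = jnp.arange(1.0, n + 1)

operator = lx.TridiagonalLinearOperator(diag, lower, upper)
# linear_solve is differentiable, so jax.grad works through the solve.
x = lx.linear_solve(operator, b, solver=lx.Tridiagonal()).value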

@dfm dfm self-assigned this Jan 2, 2025
@dfm
Collaborator

dfm commented Jan 2, 2025

@segasai — Thanks for bringing this up! I agree that it makes sense for JAX to support this directly, and I don't expect it would be too complicated to add. If you're keen to make a PR, that would be awesome and I'll add some pointers below. I'm also happy to take a stab at adding it, depending on your level of enthusiasm :D

To add AD support to that primitive, you'll need to add a JVP rule (it looks like there is already a transpose rule!) here:

jax/jax/_src/lax/linalg.py, lines 2457 to 2464 in dbe9ccd:

tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
tridiagonal_solve_p.def_impl(
    functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
ad.primitive_transposes[tridiagonal_solve_p] = _tridiagonal_solve_transpose_rule
batching.primitive_batchers[tridiagonal_solve_p] = _tridiagonal_solve_batching_rule
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?

There's a TODO there for @tomhennigan to add AD using jax.lax.custom_linear_solve (I don't expect it's high on his to-do list :D), but I'd probably just implement the JVP rule directly. Take a look at the JVP rule for triangular_solve here:

jax/jax/_src/lax/linalg.py, lines 1310 to 1312 in dbe9ccd:

ad.defjvp2(triangular_solve_p,
           _triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))

and here:

jax/jax/_src/lax/linalg.py, lines 1230 to 1261 in dbe9ccd:

def _triangular_solve_jvp_rule_a(
    g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  m, n = b.shape[-2:]
  k = 1 if unit_diagonal else 0
  g_a = _tril(g_a, k=-k) if lower else _triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = _T(g_a) if transpose_a else g_a
  g_a = g_a.conj() if conjugate_a else g_a
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    return triangular_solve(a, rhs, left_side=left_side, lower=lower,
                            transpose_a=transpose_a, conjugate_a=conjugate_a,
                            unit_diagonal=unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs
  # for matrix/vector inputs). Order these operations in whichever order is
  # cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})

for an idea of what that might look like.
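
For tridiagonal_solve itself the math is the same shape: differentiating A x = b gives ∂A x + A ∂x = ∂b, so ∂x = A⁻¹(∂b − ∂A x), and ∂A x only involves three shifted elementwise products. Here's a rough, untested sketch of that idea, written against the public wrapper rather than the primitive, and ignoring symbolic-zero tangents (which the real rule would need to handle):

import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

# Rough sketch only: the registered rule would use tridiagonal_solve_p's
# calling convention (m, n, ldb, t) and handle ad.Zero tangents.
def tridiagonal_solve_jvp(primals, tangents):
    dl, d, du, b = primals          # dl[0] == 0 and du[-1] == 0 by convention
    ddl, dd, ddu, db = tangents     # b and db have shape (n, k)
    x = tridiagonal_solve(dl, d, du, b)
    # (dA @ x)[i] = ddl[i] * x[i-1] + dd[i] * x[i] + ddu[i] * x[i+1]
    x_prev = jnp.pad(x[:-1], ((1, 0), (0, 0)))  # x shifted down, zero row first
    x_next = jnp.pad(x[1:], ((0, 1), (0, 0)))   # x shifted up, zero row last
    dA_x = ddl[:, None] * x_prev + dd[:, None] * x + ddu[:, None] * x_next
    # Solve the tangent system with the same (cheap) tridiagonal solve.
    dx = tridiagonal_solve(dl, d, du, db - dA_x)
    return x, dx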

Tests would go somewhere close to here:

def testTridiagonal(self, shape, dtype, lower):

Again, here's where the AD tests for triangular solve live as a reference:

jax/tests/linalg_test.py, lines 1610 to 1643 in dbe9ccd:

@jtu.sample_product(
    [dict(left_side=left_side, a_shape=a_shape, b_shape=b_shape)
     for left_side, a_shape, b_shape in [
         (False, (4, 4), (4,)),
         (False, (4, 4), (1, 4,)),
         (False, (3, 3), (4, 3)),
         (True, (4, 4), (4,)),
         (True, (4, 4), (4, 1)),
         (True, (4, 4), (4, 3)),
         (True, (2, 8, 8), (2, 8, 10)),
     ]
    ],
    [dict(dtype=dtype, conjugate_a=conjugate_a)
     for dtype in float_types + complex_types
     for conjugate_a in (
         [False] if jnp.issubdtype(dtype, jnp.floating) else [False, True])
    ],
    lower=[False, True],
    unit_diagonal=[False, True],
    transpose_a=[False, True],
)
def testTriangularSolveGrad(
    self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
    b_shape, dtype):
  rng = jtu.rand_default(self.rng())
  # Test lax.linalg.triangular_solve instead of scipy.linalg.solve_triangular
  # because it exposes more options.
  A = jnp.tril(rng(a_shape, dtype) + 5 * np.eye(a_shape[-1], dtype=dtype))
  A = A if lower else T(A)
  B = rng(b_shape, dtype)
  f = partial(lax.linalg.triangular_solve, lower=lower, transpose_a=transpose_a,
              conjugate_a=conjugate_a, unit_diagonal=unit_diagonal,
              left_side=left_side)
  jtu.check_grads(f, (A, B), order=1, rtol=4e-2, eps=1e-3)
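
A first gradient test for tridiagonal_solve could then look roughly like this (a sketch: the test name, shapes, and tolerances below are placeholders, not final values):

def testTridiagonalSolveGrad(self):
  rng = jtu.rand_default(self.rng())
  n = 5
  # tridiagonal_solve expects dl[0] == 0 and du[-1] == 0; shifting the main
  # diagonal keeps the system diagonally dominant and well conditioned.
  dl = jnp.asarray(rng((n,), np.float32)).at[0].set(0.0)
  du = jnp.asarray(rng((n,), np.float32)).at[-1].set(0.0)
  d = jnp.asarray(rng((n,), np.float32)) + 5.0
  b = rng((n, 2), np.float32)
  jtu.check_grads(lax.linalg.tridiagonal_solve, (dl, d, du, b),
                  order=1, rtol=4e-2, eps=1e-3)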

Let me know what you think!
