Differentiation rule for tridiagonal_solve #25693

Closed
@segasai

Hi,

I don't think this is a bug report, but rather a feature request. I recently tried to switch from my own implementation of tridiagonal_solve (using the Thomas algorithm) to jax.lax.linalg.tridiagonal_solve, and I discovered that the JAX implementation does not seem to support differentiation. The code given below produces this error:

  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 395, in process_primitive
    raise NotImplementedError(msg)
NotImplementedError: Differentiation rule for 'tridiagonal_solve' not implemented

I also did not see any mention in the docs that differentiation is not supported.

I use python 3.10, jax 0.38 on CPU.

If the error is not an oversight, and it is not too difficult to implement differentiation for tridiagonal_solve, I could take a look at doing that myself, given some pointers on where to start.

Thanks

The reproducer (together with my own implementation of the tridiagonal solve, which does support differentiation):

import jax
import jax.numpy as jnp
import numpy as np
import jax.lax.linalg as jll


def func1(carry, x):
    # Forward sweep of the Thomas algorithm: eliminate the lower diagonal.
    oldc_, oldd_ = carry
    a, b, c, d = x
    cdash = c / (b - a * oldc_)
    ddash = (d - a * oldd_) / (b - a * oldc_)
    return (cdash, ddash), (cdash, ddash)


def func2(carry, xx):
    # Backward sweep of the Thomas algorithm: back substitution.
    xold = carry
    c, d = xx
    xnew = (d - c * xold)
    return xnew, xnew


def solver(a, b, c, d):
    """
    Solve A x = d, where A has (a, b, c) on its lower/main/upper diagonals,
    using the Thomas algorithm.
    """
    # Forward sweep: scan over rows of (a, b, c, d), keeping the stacked outputs.
    cdnew = jax.lax.scan(func1, (0.0, 0.0),
                         jnp.transpose(jnp.vstack((a, b, c, d)), (1, 0)))[1]
    # Back substitution: scan over the modified coefficients in reverse order.
    xnew = jax.lax.scan(
        func2, 0.0,
        jnp.transpose(jnp.vstack((cdnew[0], cdnew[1])), (1, 0))[::-1, :])[-1]
    return xnew[::-1]


if __name__ == '__main__':
    np.random.seed(43)
    a, b, c, d = np.random.normal(size=(4, 3))
    a[0] = 0
    c[-1] = 0
    print('x', solver(a, b, c, d))
    a, b, c, d = [jnp.array(_) for _ in [a, b, c, d]]
    print('y', jll.tridiagonal_solve(a, b, c, d[:, None])[:, 0])

    def func(x):
        ret = jnp.sum(jll.tridiagonal_solve(a, b, c, x[:, None]))
        return ret

    # This call raises the NotImplementedError shown above.
    JG = jax.grad(func, [0])
    JG(d)
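
For anyone hitting the same error, one possible workaround is to wrap the primitive in jax.custom_vjp and supply the backward pass by hand. The sketch below is not taken from the JAX codebase. The names tridiag_solve_diff, _fwd and _bwd are hypothetical, and it assumes a 2-D right-hand side of shape (m, n), with dl[0] and du[-1] unused as in the reproducer. It relies on the standard identities for x = A^{-1} b: the cotangent of b is A^{-T} g, and the cotangent of A is -(A^{-T} g) x^T, restricted here to the three diagonals.

```python
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve


@jax.custom_vjp
def tridiag_solve_diff(dl, d, du, b):
    # Hypothetical wrapper: same forward result as tridiagonal_solve.
    return tridiagonal_solve(dl, d, du, b)


def _fwd(dl, d, du, b):
    x = tridiagonal_solve(dl, d, du, b)
    # Save the diagonals and the solution for the backward pass.
    return x, (dl, d, du, x)


def _bwd(res, g):
    dl, d, du, x = res
    zero = jnp.zeros((1,), dtype=d.dtype)
    # Diagonals of A^T: the lower and upper diagonals swap roles,
    # shifted by one so the unused entries stay at dl[0] and du[-1].
    dl_t = jnp.concatenate([zero, du[:-1]])
    du_t = jnp.concatenate([dl[1:], zero])
    # Cotangent of the right-hand side: solve A^T b_bar = g.
    b_bar = tridiagonal_solve(dl_t, d, du_t, g)
    # Cotangent of A is -b_bar @ x.T; keep only the three diagonals.
    d_bar = -jnp.sum(b_bar * x, axis=1)
    dl_bar = jnp.concatenate([zero, -jnp.sum(b_bar[1:] * x[:-1], axis=1)])
    du_bar = jnp.concatenate([-jnp.sum(b_bar[:-1] * x[1:], axis=1), zero])
    return dl_bar, d_bar, du_bar, b_bar


tridiag_solve_diff.defvjp(_fwd, _bwd)

# Small diagonally dominant example (values chosen for illustration only).
dl = jnp.array([0.0, 1.0, 1.0, 1.0])
d = jnp.array([4.0, 4.0, 4.0, 4.0])
du = jnp.array([1.0, 1.0, 1.0, 0.0])
rhs = jnp.array([1.0, 2.0, 3.0, 4.0])


def loss(dl, d, du, rhs):
    return jnp.sum(tridiag_solve_diff(dl, d, du, rhs[:, None]))


grads = jax.grad(loss, argnums=(0, 1, 2, 3))(dl, d, du, rhs)
```

Because custom_vjp only defines reverse-mode behaviour, this sketch would work with jax.grad but not with forward-mode tools such as jax.jvp. A cheap sanity check is to compare the gradients against those of a dense jnp.linalg.solve on the same matrix.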
