Description
Hi,
I don't think this is a bug report so much as a feature request. I recently tried to switch from my own implementation of tridiagonal_solve (using the Thomas algorithm) to jax.lax.linalg.tridiagonal_solve, and discovered that the JAX implementation does not seem to support differentiation. The code given below fails with this error:
```
  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
  File "/home/koposov/pyenv310/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 395, in process_primitive
    raise NotImplementedError(msg)
NotImplementedError: Differentiation rule for 'tridiagonal_solve' not implemented
```
I also did not see any mention in the docs that differentiation is not supported.
I am using Python 3.10 and jax 0.38 on CPU.
If this is a known gap rather than an oversight, and it is not too difficult to implement differentiation for tridiagonal_solve, I could take a look at doing that myself, given some pointers on where to look.
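For reference, here is a sketch of what such a rule would have to compute, based on the standard result for differentiating a linear solve (my own derivation, not taken from the JAX source). Write the system as $A x = d$ with $(Ax)_i = a_i x_{i-1} + b_i x_i + c_i x_{i+1}$, where $a_1 = c_n = 0$ and $x_0 = x_{n+1} = 0$:

```latex
\begin{aligned}
\text{JVP:}\quad & A\,\dot{x} = \dot{d} - \dot{A}\,x,
  \qquad (\dot{A}x)_i = \dot{a}_i x_{i-1} + \dot{b}_i x_i + \dot{c}_i x_{i+1} \\
\text{VJP:}\quad & A^{\top}\lambda = \bar{x},
  \qquad \bar{d} = \lambda,\quad
  \bar{a}_i = -\lambda_i x_{i-1},\quad
  \bar{b}_i = -\lambda_i x_i,\quad
  \bar{c}_i = -\lambda_i x_{i+1}
\end{aligned}
```

Since $A^{\top}$ is again tridiagonal (sub-diagonal $(0, c_1, \dots, c_{n-1})$, diagonal $b$, super-diagonal $(a_2, \dots, a_n, 0)$), either direction should need only one extra call to the existing tridiagonal solver.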
Thanks
Below is the reproducer, together with my own implementation of the tridiagonal solve, which does support differentiation:
```python
import jax
import jax.numpy as jnp
import numpy as np
import jax.lax.linalg as jll


def func1(carry, x):
    # forward loop of the Thomas algorithm: compute the modified
    # super-diagonal (cdash) and right-hand side (ddash)
    oldc_, oldd_ = carry
    a, b, c, d = x
    cdash = c / (b - a * oldc_)
    ddash = (d - a * oldd_) / (b - a * oldc_)
    return (cdash, ddash), (cdash, ddash)


def func2(carry, xx):
    # backward loop of the Thomas algorithm (back substitution)
    xold = carry
    c, d = xx
    xnew = d - c * xold
    return xnew, xnew


def solver(a, b, c, d):
    """
    Solve A x = d where A has (a, b, c) on its lower/main/upper diagonal.
    Uses the Thomas algorithm.
    """
    # scan over rows; each element is one (a_i, b_i, c_i, d_i) row
    cdnew = jax.lax.scan(func1, (0.0, 0.0),
                         jnp.transpose(jnp.vstack((a, b, c, d)), (1, 0)))[1]
    # scan over the (cdash_i, ddash_i) rows in reverse order
    xnew = jax.lax.scan(
        func2, 0.0,
        jnp.transpose(jnp.vstack((cdnew[0], cdnew[1])), (1, 0))[::-1, :])[-1]
    return xnew[::-1]


if __name__ == '__main__':
    np.random.seed(43)
    a, b, c, d = np.random.normal(size=(4, 3))
    a[0] = 0   # a[0] lies outside the matrix
    c[-1] = 0  # c[-1] lies outside the matrix
    print('x', solver(a, b, c, d))
    a, b, c, d = [jnp.array(_) for _ in [a, b, c, d]]
    print('y', jll.tridiagonal_solve(a, b, c, d[:, None])[:, 0])

    def func(x):
        ret = jnp.sum(jll.tridiagonal_solve(a, b, c, x[:, None]))
        return ret

    JG = jax.grad(func, [0])
    JG(d)  # raises NotImplementedError
```
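As a possible stopgap until a differentiation rule exists, the existing primitive could be wrapped with `jax.custom_vjp`. This is a minimal sketch, not part of JAX's API: `tridiagonal_solve_diff` is a hypothetical name, and it assumes unbatched inputs (`dl`, `d`, `du` of shape `(n,)`, `b` of shape `(n, k)`) with `dl[0] = du[-1] = 0`. The backward pass uses the standard adjoint of a linear solve: solve the transposed tridiagonal system for the cotangent `lam`, which is the gradient with respect to `b`; the gradients of the three diagonals are `-lam` times the matching neighbouring entries of the solution `x`.

```python
import jax
import jax.numpy as jnp
import jax.lax.linalg as jll


@jax.custom_vjp
def tridiagonal_solve_diff(dl, d, du, b):
    """Hypothetical differentiable wrapper around jll.tridiagonal_solve.

    Assumes dl, d, du have shape (n,) with dl[0] == du[-1] == 0,
    and b has shape (n, k). No batch dimensions are handled.
    """
    return jll.tridiagonal_solve(dl, d, du, b)


def _fwd(dl, d, du, b):
    x = jll.tridiagonal_solve(dl, d, du, b)
    # keep the diagonals and the solution for the backward pass
    return x, (dl, d, du, x)


def _bwd(res, x_bar):
    dl, d, du, x = res
    # Transposed matrix: A^T[i, i-1] = du[i-1] and A^T[i, i+1] = dl[i+1],
    # with the out-of-range entries set to zero.
    dl_t = jnp.concatenate([jnp.zeros_like(du[:1]), du[:-1]])
    du_t = jnp.concatenate([dl[1:], jnp.zeros_like(dl[:1])])
    lam = jll.tridiagonal_solve(dl_t, d, du_t, x_bar)
    # Rows x[i-1] and x[i+1], zero-padded at the boundaries.
    zero_row = jnp.zeros_like(x[:1])
    x_prev = jnp.concatenate([zero_row, x[:-1]], axis=0)
    x_next = jnp.concatenate([x[1:], zero_row], axis=0)
    # Each diagonal entry is shared by all k right-hand sides, so sum over them.
    dl_bar = -jnp.sum(lam * x_prev, axis=1)
    d_bar = -jnp.sum(lam * x, axis=1)
    du_bar = -jnp.sum(lam * x_next, axis=1)
    return dl_bar, d_bar, du_bar, lam


tridiagonal_solve_diff.defvjp(_fwd, _bwd)
```

With this wrapper, the failing call in the reproducer would become `jax.grad(lambda x: jnp.sum(tridiagonal_solve_diff(a, b, c, x[:, None])))(d)`. Because it is a `custom_vjp`, only reverse mode (`jax.grad`, `jax.vjp`) is covered; `jax.jvp` and `jax.jacfwd` would still fail on it.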