Differentiation rule for tridiagonal_solve #25693
Take a look at Lineax, which should support this :)
@segasai — Thanks for bringing this up! I agree that it makes sense for JAX to support this directly, and I don't expect it would be too complicated to add. If you're keen to make a PR, that would be awesome and I'll add some pointers below. I'm also happy to take a stab at adding it, depending on your level of enthusiasm :D To add AD support to that primitive, you'll need to add a JVP rule (it looks like there is already a transpose rule!) here: Lines 2457 to 2464 in dbe9ccd
There's a TODO there for @tomhennigan to add AD using Lines 1310 to 1312 in dbe9ccd
and here: Lines 1230 to 1261 in dbe9ccd
for an idea of what that might look like. Tests would go somewhere close to here: Line 1787 in dbe9ccd
Again, here's where the AD tests for triangular solve live as a reference: Lines 1610 to 1643 in dbe9ccd
Let me know what you think!
Hi,
I don't think this is a bug report so much as a feature request. I recently tried to switch from my own implementation of tridiagonal_solve, which uses the Thomas algorithm, to jax.lax.linalg.tridiagonal_solve, and I discovered that the tridiagonal_solve implementation in JAX does not seem to support differentiation. The code (given below) gives this error:
I also did not see any mention in the docs that differentiation is not supported.
I use python 3.10, jax 0.38 on CPU.
If the error is not an oversight, and it is not too difficult to implement differentiation for tridiagonal_solve, I could maybe take a look at doing that, if I get some pointers on where to look.
Thanks
The reproducer (together with my own implementation of tridiagonal solve that does support differentiation).
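The reproducer code itself is not shown in this copy of the thread. As an illustration only, here is a minimal sketch of a Thomas-algorithm solver built on `jax.lax.scan`, which JAX can differentiate without a custom rule. The function name `thomas_solve` and the convention that `dl[0]` and `du[-1]` are ignored are assumptions for this sketch, not the original poster's code.

```python
import jax
import jax.numpy as jnp


def thomas_solve(dl, d, du, b):
    """Solve a tridiagonal system with 1-D diagonals of equal length n.

    dl[0] and du[-1] are ignored (assumed convention). No pivoting is done,
    so this is only reliable for well-conditioned systems, e.g. diagonally
    dominant ones.
    """
    zero = jnp.zeros((), dtype=b.dtype)

    def forward(carry, row):
        c_prev, r_prev = carry
        a_i, d_i, u_i, b_i = row
        denom = d_i - a_i * c_prev          # eliminate the sub-diagonal entry
        c_i = u_i / denom                   # scaled super-diagonal
        r_i = (b_i - a_i * r_prev) / denom  # scaled right-hand side
        return (c_i, r_i), (c_i, r_i)

    _, (c, r) = jax.lax.scan(forward, (zero, zero), (dl, d, du, b))

    def backward(x_next, row):
        c_i, r_i = row
        x_i = r_i - c_i * x_next            # back substitution
        return x_i, x_i

    _, x = jax.lax.scan(backward, zero, (c, r), reverse=True)
    return x
```

Because the solver is made of ordinary differentiable operations, a call such as `jax.grad(lambda d: thomas_solve(dl, d, du, b).sum())(d)` should work directly.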