
Mixing tensor & function tracing #14

Open
mavenlin opened this issue Jan 17, 2025 · 0 comments

@mavenlin
Member

Some operators produce hybrid jaxpr graphs that contain both function nodes and jax.ShapedArray nodes. Examples:

  1. operators.partial: takes a GeneralArray and several args that are jax.ShapedArray, and creates a partial function by binding those args.
  2. operators.integrate: takes a GeneralArray; when all of its args are integrated out, the return value is a scalar of type jax.ShapedArray.
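To make the type boundary concrete, here is a minimal sketch in plain JAX. It does not use the project's API: it assumes a GeneralArray can be modelled as an ordinary Python callable, and it replaces symbolic integration with trapezoidal quadrature on a fixed grid over [0, 1]. The names partial, integrate, GRID, f and h below are hypothetical stand-ins, not operators.partial or operators.integrate themselves.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in: integration is approximated by the trapezoidal rule
# on a fixed grid over [0, 1]; the real operators act on functions symbolically.
GRID = jnp.linspace(0.0, 1.0, 1001)


def partial(f, x0):
    """Bind the first argument of f(x, y) to the array x0, returning a function of y."""
    return lambda y: f(x0, y)


def integrate(g):
    """Integrate g over [0, 1]; a function goes in, a plain scalar array comes out."""
    vals = jax.vmap(g)(GRID)
    dx = GRID[1] - GRID[0]
    return dx * (jnp.sum(vals) - 0.5 * (vals[0] + vals[-1]))


def f(x, y):
    return x * y**2


def h(x0):
    # array -> function (partial) -> array (integrate): a hybrid graph.
    return integrate(partial(f, x0))


print(h(2.0))            # approximately 2/3
print(jax.grad(h)(2.0))  # approximately 1/3, the integral of df/dx(x0, y) over y
```

Because everything is discretized here, the gradient flows through without trouble. The difficulty described in this issue appears when partial and integrate are differentiated symbolically, so that the vjp itself has to turn arrays into functions and functions into arrays.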

Supporting both in the same graph is difficult because the vjp graph may need to convert between GeneralArray and ShapedArray. In the partial case, the vjp needs to integrate the cotangent over all the args that are not bound, and evaluate the resulting function at the bound args. We may need to introduce special functions/operators such as the delta function.
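A sketch of the vjp rules involved, assuming f is a function of two variables, x_0 is the bound array argument of partial, and a bar denotes a cotangent (this notation is ours, not the project's):

```latex
\begin{aligned}
g(y) &= f(x_0, y) && \text{(partial binds } x_0\text{)} \\
\bar{x}_0 &= \int \bar{g}(y)\, \partial_x f(x_0, y)\, dy && \text{(integrate over unbound } y\text{, evaluate at } x_0\text{)} \\
\bar{f}(x, y) &= \delta(x - x_0)\, \bar{g}(y) && \text{(function-valued cotangent needs a delta)} \\[4pt]
I &= \int f(x)\, dx \;\Rightarrow\; \bar{f}(x) = \bar{I} && \text{(scalar cotangent becomes a constant function)}
\end{aligned}
```

The second line is the rule described in the paragraph above. The third follows because, for any perturbation phi of f, the cotangent must satisfy the pairing (double integral of f-bar times phi) = (integral over y of g-bar(y) times phi(x_0, y)), which only a delta function at x_0 can produce. The integrate rule shows the opposite conversion: a ShapedArray cotangent becomes a constant GeneralArray.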
