Some operators lead to hybrid jaxpr graphs made of both function nodes and `jax.ShapedArray` nodes. Examples are:
- `operators.partial`: this operator takes a `GeneralArray` and several args that are `jax.ShapedArray`s, and creates a partial function.
- `operators.integrate`: this operator takes a `GeneralArray`; when all of its args are integrated out, the return value is a scalar of type `jax.ShapedArray`.
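A minimal sketch of the two signatures, under loud assumptions: plain Python callables stand in for `GeneralArray` (function nodes), JAX arrays stand in for `jax.ShapedArray` values, and `partial_bind` / `integrate_1d` are hypothetical helpers using a midpoint rule on $[0, 1]$, not the real `operators.partial` / `operators.integrate`. The last lines discretize the function on the grid to hint at the reverse direction of `integrate`: its scalar cotangent $\bar s$ turns into a constant cotangent function $\bar f(x) = \bar s$ once the quadrature weight is divided out.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: Python callables play the role of GeneralArray (function
# nodes); jnp arrays play the role of jax.ShapedArray values.

def partial_bind(f, x0):
    """Bind the first arg of f(x, y) to the array x0; the result is a function of y."""
    return lambda y: f(x0, y)

# Assumed discretisation: midpoint rule on [0, 1] (not the real integrate operator).
N = 1000
GRID = (jnp.arange(N) + 0.5) / N
W = 1.0 / N

def integrate_1d(g):
    """Integrate a scalar function of one variable over [0, 1]; returns a scalar array."""
    return jnp.sum(jax.vmap(g)(GRID)) * W

def f(x, y):
    return jnp.sin(x) * y ** 2

g = partial_bind(f, jnp.float32(0.5))  # function node + array -> function node
s = integrate_1d(g)                    # function node -> scalar array

# Reverse direction on the discretised function: a scalar cotangent for s becomes
# a cotangent per grid sample; dividing by W recovers a constant function of x.
samples = jax.vmap(g)(GRID)
_, vjp_fn = jax.vjp(lambda v: jnp.sum(v) * W, samples)
(ct,) = vjp_fn(jnp.float32(2.0))
f_bar_on_grid = ct / W  # constant 2.0 at every grid point
```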
Supporting both kinds of nodes in the same graph is difficult, because the vjp graph may need to convert between `GeneralArray` and `ShapedArray`. In the `partial` case, the vjp needs to integrate the cotangent over all the args that are not bound, and then evaluate the resulting function at the bound args. We may need to introduce special functions / operators, such as the delta function.
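For $g(y) = f(x_0, y)$ and a cotangent function $\bar g(y)$, the rule described above gives $\bar x_0 = \int \bar g(y)\, \partial_x f(x_0, y)\, dy$, while the cotangent for $f$ itself would be $\bar f(x, y) = \delta(x - x_0)\, \bar g(y)$, which is one place a delta function shows up. The sketch below checks only the $\bar x_0$ rule numerically, assuming a 1-D unbound arg on $[0, 1]$, midpoint quadrature, and hypothetical names (`g_bar`, `pairing`, `x0_cotangent`); the reference value is `jax.grad` of the pairing $\langle \bar g, g \rangle = \int \bar g(y)\, f(x_0, y)\, dy$.

```python
import jax
import jax.numpy as jnp

# Assumed discretisation of the unbound arg y: midpoint grid on [0, 1].
N = 1000
GRID = (jnp.arange(N) + 0.5) / N
W = 1.0 / N

def f(x, y):
    # Example function of a bound arg x and an unbound arg y.
    return jnp.sin(x * y) + x ** 2 * y

def g_bar(y):
    # Hypothetical cotangent function for g = partial(f, x0).
    return jnp.cos(3.0 * y)

def pairing(x0):
    # <g_bar, g> = integral of g_bar(y) * f(x0, y) dy; its gradient is the reference.
    return jnp.sum(jax.vmap(lambda y: g_bar(y) * f(x0, y))(GRID)) * W

def x0_cotangent(x0):
    # Rule from the issue: integrate over the unbound arg y to get a function of x,
    # then evaluate that function at the bound arg x0.
    df_dx = jax.grad(f, argnums=0)
    h = lambda x: jnp.sum(jax.vmap(lambda y: g_bar(y) * df_dx(x, y))(GRID)) * W
    return h(x0)

x0 = jnp.float32(0.7)
ref = jax.grad(pairing)(x0)
rule = x0_cotangent(x0)
```

Both values should agree up to float32 rounding.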