Define ufunc jvps and vjps simultaneously #312
Conversation
I think this obviates the changes in HIPS#292.
This should minimize memory overhead.
autograd/numpy/numpy_jvps.py
```python
            unbroadcast_f(args[argnum], lambda g: -g)),
    'mul': (lambda argnum, deriv: lambda g, ans, *args: g * deriv(ans, *args),
            lambda argnum, deriv: lambda ans, *args:
                unbroadcast_f(args[argnum], lambda g, d=deriv(ans, *args): g * d)),
```
For the vjps I've used this slightly weird d=deriv(ans, *args) default argument syntax to ensure that deriv is evaluated during the forward pass, allowing *args and ans to potentially be garbage collected.
Any objections? I could also have done this using a kind of helper closure to evaluate deriv, which would have been a bit more explicit.
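To spell out the trick: Python evaluates default-argument expressions once, when the lambda is created, not when it's called. A toy illustration (my own sketch, not the PR's code) contrasting the two binding styles:

```python
import numpy as np

def make_vjps(deriv, ans, *args):
    # Late binding: the closure keeps `ans` and `args` alive until the
    # backward pass actually calls it.
    vjp_late = lambda g: g * deriv(ans, *args)

    # Early binding via a default argument: deriv(ans, *args) is evaluated
    # right now, on the forward pass, so the lambda retains only `d` and
    # the (potentially large) `ans`/`args` can be garbage collected.
    vjp_early = lambda g, d=deriv(ans, *args): g * d
    return vjp_late, vjp_early

x = np.linspace(0.0, 1.0, 5)
vjp_late, vjp_early = make_vjps(lambda ans, x: np.cos(x), np.sin(x), x)
assert np.allclose(vjp_late(np.ones(5)), vjp_early(np.ones(5)))
```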
Also move broadcast_to_adjoint into numpy_wrapper.
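For reference, the adjoint of broadcast_to just sums the output cotangent back down to the input shape. A minimal standalone sketch of that operation (hypothetical name and signature, not necessarily the PR's exact code):

```python
import numpy as np

def broadcast_to_adjoint(g, shape):
    # Sum away the leading axes that broadcasting prepended.
    while g.ndim > len(shape):
        g = g.sum(axis=0)
    # Sum (keeping dims) over axes that were stretched from size 1.
    for axis, size in enumerate(shape):
        if size == 1:
            g = g.sum(axis=axis, keepdims=True)
    return g

g = np.ones((2, 3, 4))  # cotangent of broadcast_to(x, (2, 3, 4)) with x.shape == (3, 1)
assert broadcast_to_adjoint(g, (3, 1)).shape == (3, 1)
```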
I've added a couple of benchmarks and run a bench compare. Almost everything is the same, but there are some differences from computation that I've shifted from the backward pass to the forward pass.
I'm wondering whether, for consistency, all of the extra numpy-ish primitives that we define (things like the dot and tensordot adjoints) should be in numpy_wrapper, alongside things like make_diagonal and broadcast_to_adjoint. They can be viewed as extra primitives that we want to add to numpy (primitives which happen to be useful for calculating derivatives), so perhaps it makes more sense for them to be there.
To possibly do:
…

Summary of the changes in this PR

- `def_ufunc_jps` for defining the jvp and vjp of a ufunc in one shot. There are lines above `def_ufunc_jps` explaining how to use it.
- Make `broadcast_to` into a primitive, define its adjoint in numpy_wrapper.py and set up derivatives. This is roughly the same as "make internal broadcast and unbroadcast both primitives" (#292).
- `def_ufunc_jps_inv_pair` for defining the jps of an inverse pair of ufuncs in one shot. So, for example, the four defs …
- Move `match_complex`, `unbroadcast` and `unbroadcast_f` into the newly created `autograd.numpy.util`, alongside the new helper functions (I think this rearrangement makes sense).
- … `scipy.special.rel_entr`.

Notes
We could reduce the number of lines of code for other primitive defs in this way. In particular I've got my eye on the reductions (sum, mean, var, std) to potentially do next. I also think, at least in the case of ufuncs, that this style is clearer. I guess we'd better check that there's little or no harm to performance. I think any overhead introduced could potentially be optimized away by carefully handling different special cases in def_ufunc_jps.
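For a concrete picture of the style, here's roughly what a one-shot definition might look like. The call signature, import location and the 'mul' tag are assumptions inferred from the 'mul' entry in the snippet above, not necessarily the PR's exact API:

```python
import autograd.numpy as anp
from autograd.numpy.util import def_ufunc_jps  # assumed location, per the reorganization above

# Each argument gets a (deriv, op) pair; 'mul' means the jvp is
# g * deriv(ans, *args) and the vjp is unbroadcast(g * deriv(ans, *args)).
def_ufunc_jps(anp.sin, (lambda ans, x: anp.cos(x), 'mul'))
def_ufunc_jps(anp.exp, (lambda ans, x: ans, 'mul'))
```

Each such call stands in for a separate defjvp and defvjp definition.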
During higher order derivatives, the computation of what I've called deriv in the snippet above could be cached and reused. This computation is currently being re-done, because the same primitive's jvp/vjp is being called at each layer of unboxing, with numerically the same ans and *args inputs (although g will be different for each layer). I might write a little blog post about this, as I suspect this is an optimization that may not yet be implemented in e.g. Theano or TensorFlow. Implementing this in Autograd might require something similar to what was discussed in #188.
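For what it's worth, one possible shape for that cache, as a purely hypothetical sketch keyed on input identity (nothing like this is in the PR):

```python
def cached_deriv(deriv):
    # Memoize deriv(ans, *args) on the identity of its inputs: each layer
    # of unboxing calls the jvp/vjp with the same `ans`/`args` objects, so
    # identity keys would let all the layers share a single evaluation.
    cache = {}
    def wrapped(ans, *args):
        key = (id(ans),) + tuple(id(a) for a in args)
        if key not in cache:
            cache[key] = deriv(ans, *args)
        return cache[key]
    return wrapped
```

A real version would need weak references or explicit cache clearing so that the cache itself doesn't pin ans and *args in memory, which would undo the garbage-collection benefit discussed above.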