diff --git a/autograd/core.py b/autograd/core.py index 6d7292294..89b2d78f8 100644 --- a/autograd/core.py +++ b/autograd/core.py @@ -74,13 +74,14 @@ def __call__(self, *args, **kwargs): def vjp(self, argnum, outgrad, ans, vs, gvs, args, kwargs): try: - return self.vjps[argnum](outgrad, ans, vs, gvs, *args, **kwargs) + vjp = self.vjps[argnum] except KeyError: if self.vjps == {}: errstr = "Gradient of {0} not yet implemented." else: errstr = "Gradient of {0} w.r.t. arg number {1} not yet implemented." raise NotImplementedError(errstr.format(self.fun.__name__, argnum)) + return vjp(outgrad, ans, vs, gvs, *args, **kwargs) def defvjp(self, vjpmaker, argnum=0): vjpmaker.__name__ = "VJP_{}_of_{}".format(argnum, self.__name__)