Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from autograd.builtins import SequenceBox
from autograd.extend import Box, primitive
from autograd.tracer import trace_primitives_map

from . import numpy_wrapper as anp

Expand All @@ -18,16 +19,40 @@ class ArrayBox(Box):
def __getitem__(A, idx):
return A[idx]

# Constants w.r.t float data just pass though
shape = property(lambda self: self._value.shape)
ndim = property(lambda self: self._value.ndim)
size = property(lambda self: self._value.size)
dtype = property(lambda self: self._value.dtype)
# Basic array attributes just pass through
# Single wrapped scalars are presented as 0-dim, 1-size arrays.
shape = property(lambda self: anp.shape(self._value))
ndim = property(lambda self: anp.ndim(self._value))
size = property(lambda self: anp.size(self._value))
dtype = property(lambda self: anp.result_type(self._value))

T = property(lambda self: anp.transpose(self))

def __array_namespace__(self, *, api_version: Union[str, None] = None):
return anp

# Calls to wrapped ufuncs first forward further handling to the ufunc
# dispatching mechanism, which allows any other operands to also try
# handling the ufunc call. See also tracer.primitive.
#
# In addition, implementing __array_ufunc__ allows ufunc calls to propagate
# through non-differentiable array-like objects (e.g. xarray.DataArray) into
# ArrayBoxes which might be contained within, upon which __array_ufunc__
# below would call autograd's wrapper for the ufunc. For example, given a
# DataArray `a` containing an ArrayBox, this lets us write `np.abs(a)`
# instead of requiring the xarray-specific `xr.apply_func(np.abs, a)`.
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != "__call__":
return NotImplemented
if "out" in kwargs:
return NotImplemented
if ufunc_wrapper := trace_primitives_map.get(ufunc):
try:
return ufunc_wrapper(*inputs, called_by_autograd_dispatcher=True, **kwargs)
except NotImplementedError:
return NotImplemented
return NotImplemented

def __len__(self):
return len(self._value)

Expand Down
30 changes: 22 additions & 8 deletions autograd/numpy/numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,29 @@ class IntdtypeSubclass(cls):
def wrap_namespace(old, new):
unchanged_types = {float, int, type(None), type}
int_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer}
obj_to_wrapped = []
for name, obj in old.items():
if obj in notrace_functions:
new[name] = notrace_primitive(obj)
elif callable(obj) and type(obj) is not type:
new[name] = primitive(obj)
elif type(obj) is type and obj in int_types:
new[name] = wrap_intdtype(obj)
elif type(obj) in unchanged_types:
new[name] = obj
# Map multiple names of the same object (e.g. conj/conjugate)
# to the same wrapped object
for mapped_obj, wrapped in obj_to_wrapped:
if mapped_obj is obj:
new[name] = wrapped
break
else:
if obj in notrace_functions:
wrapped = notrace_primitive(obj)
new[name] = wrapped
obj_to_wrapped.append((obj, wrapped))
elif callable(obj) and type(obj) is not type:
wrapped = primitive(obj)
new[name] = wrapped
obj_to_wrapped.append((obj, wrapped))
elif type(obj) is type and obj in int_types:
wrapped = wrap_intdtype(obj)
new[name] = wrapped
obj_to_wrapped.append((obj, wrapped))
elif type(obj) in unchanged_types:
new[name] = obj


wrap_namespace(_np.__dict__, globals())
Expand Down
45 changes: 41 additions & 4 deletions autograd/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from collections import defaultdict
from contextlib import contextmanager

import numpy as np

import autograd

from .util import subvals, toposort
from .wrap_util import wraps

Expand Down Expand Up @@ -33,28 +37,61 @@ def new_root(cls, *args, **kwargs):
return root


trace_primitives_map = {}


def primitive(f_raw):
"""
Wraps a function so that its gradient can be specified and its invocation
can be recorded. For examples, see the docs."""

@wraps(f_raw)
def f_wrapped(*args, **kwargs):
def f_wrapped(*args, called_by_autograd_dispatcher=False, **kwargs):
boxed_args, trace, node_constructor = find_top_boxed_args(args)
if boxed_args:
# If we are a wrapper around a ufunc, first forward further handling to
# the ufunc dispatching mechanism (if we aren't already running inside it)
# by calling the ufunc. This allows other operands to also try to handle
# the call (it's still possible our handling attempt below will get the
# first shot; the handlers order is determined by the dispatch mechanism).
#
# For example, consider multiplying an ndarray wrapped inside an ArrayBox
# by an xarray.DataArray. The handling below will fail: The ndarray will
# be unboxed and multiplied by the DataArray resulting in a DataArray,
# for which `new_box` will raise an exception. In contrast, the DataArray's
# handling of the call might succeed: it might contain an ndarray, either
# plain or boxed in an ArrayBox, in which case it will be multiplied by
# the other ArrayBox yielding a new ArrayBox, which will be stored in a new
# DataArray.
if (
isinstance(f_raw, np.ufunc)
and not called_by_autograd_dispatcher
and any(isinstance(arg, autograd.numpy.numpy_boxes.ArrayBox) for arg in args)
):
return f_raw(*args, **kwargs)

argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args])
if f_wrapped in notrace_primitives[node_constructor]:
return f_wrapped(*argvals, **kwargs)
return f_wrapped(
*argvals, called_by_autograd_dispatcher=called_by_autograd_dispatcher, **kwargs
)
parents = tuple(box._node for _, box in boxed_args)
argnums = tuple(argnum for argnum, _ in boxed_args)
ans = f_wrapped(*argvals, **kwargs)
ans = f_wrapped(*argvals, called_by_autograd_dispatcher=called_by_autograd_dispatcher, **kwargs)
node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
return new_box(ans, trace, node)
try:
box = new_box(ans, trace, node)
return box
except:
if called_by_autograd_dispatcher:
raise NotImplementedError
raise
else:
return f_raw(*args, **kwargs)

f_wrapped.fun = f_raw
f_wrapped._is_autograd_primitive = True
trace_primitives_map[f_raw] = f_wrapped
return f_wrapped


Expand Down
Loading