Skip to content

Commit

Permalink
[dynamo] Support user defined dicts (pytorch#143548)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Dec 21, 2024
1 parent 9cb743d commit 4627cfd
Show file tree
Hide file tree
Showing 11 changed files with 944 additions and 13 deletions.
759 changes: 759 additions & 0 deletions test/dynamo/test_dicts.py

Large diffs are not rendered by default.

11 changes: 5 additions & 6 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10599,7 +10599,7 @@ def foo():

foo()

def test_dict_subclass_cannot_be_initialized_in_graph(self):
def test_dict_subclass_initialization_in_graph(self):
for super_class in (
collections.OrderedDict,
dict,
Expand All @@ -10615,11 +10615,10 @@ def fn(x):
assert "key" in c
return c["key"] + 1

fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"
):
print(fn_opt(torch.zeros(1)))
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

x = torch.rand(4)
self.assertEqual(fn(x), opt_fn(x))

@wrapDeterministicFlagAPITest
def test_backward_deterministic_mode_mismatch_warning(self):
Expand Down
Empty file.
1 change: 1 addition & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ class ThrowingModule:
not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1,
'"dill" not found or not correct version'
)
@skipIfTorchDynamo("Different behavior between 3.11 and 3.13, causing CI issues")
def test_serialization_dill(self):
x = torch.randn(5, 5)

Expand Down
83 changes: 78 additions & 5 deletions torch/_dynamo/side_effects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import collections
import contextlib
import functools
import inspect
Expand All @@ -20,7 +21,7 @@
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
AttributeMutation,
AttributeMutationExisting,
Expand All @@ -37,6 +38,17 @@ def _manual_update_dict(dict_from, dict_to):
dict_to[k] = v


def _manual_dict_setitem(dict_from, dict_to, mro_index):
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
# to be careful because we don't want to trigger the user defined object
# setitem or clear. The mro_index is used to find the dict/OrderedDict from
# the class mro.
dict_class = type(dict_to).__mro__[mro_index]
dict_class.clear(dict_to)
for k, v in dict_from.items():
dict_class.__setitem__(dict_to, k, v)


class SideEffects:
"""
Track side effects (list mutation, setattr, etc) that need to be
Expand Down Expand Up @@ -181,9 +193,9 @@ def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker)

@staticmethod
def cls_supports_mutation_side_effects(cls):
return (
inspect.getattr_static(cls, "__getattribute__", None)
is object.__getattribute__
return inspect.getattr_static(cls, "__getattribute__", None) in (
object.__getattribute__,
dict.__getattribute__,
)

def is_attribute_mutation(self, item):
Expand Down Expand Up @@ -254,6 +266,8 @@ def track_object_new(
obj = torch.autograd.Function()
elif issubclass(user_cls, torch.nn.Module):
obj = nn_module_new(user_cls)
elif issubclass(user_cls, (dict, collections.OrderedDict)):
obj = dict_new(user_cls)
else:
try:
obj = object_new(user_cls)
Expand Down Expand Up @@ -284,6 +298,8 @@ def track_object_new_from_user_defined_class(
] = variables.UserDefinedObjectVariable
if issubclass(user_cls, torch.nn.Module):
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, (dict, collections.OrderedDict)):
variable_cls = variables.UserDefinedDictVariable
elif issubclass(user_cls, MutableMapping):
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
Expand Down Expand Up @@ -442,7 +458,12 @@ def codegen_save_tempvars(self, cg: PyCodegen):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "object_new")
lambda: cg.load_import_from(
utils.__name__,
"dict_new"
if isinstance(var, variables.UserDefinedDictVariable)
else "object_new",
)
)
cg(var.mutation_type.cls_source)
cg.extend_output(create_call_function(1, False))
Expand Down Expand Up @@ -695,6 +716,58 @@ def codegen_update_mutated(self, cg: PyCodegen):
suffixes.append([cg.create_store_deref(var.local_name)])

elif self.is_attribute_mutation(var):
if isinstance(var, variables.UserDefinedDictVariable):
# Do dict related update manually here. The store_attr
# mutations will be applied later.
varname_map = {}
for name in _manual_dict_setitem.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()

try:
mro_index = type(var.value).__mro__.index(
collections.OrderedDict
)
except ValueError:
mro_index = type(var.value).__mro__.index(dict)

cg.extend_output(
[
create_instruction("LOAD_CONST", argval=mro_index),
create_instruction(
"STORE_FAST", argval=varname_map["mro_index"]
),
]
)

cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_to"]
)
]
)

cg(var._dict_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_from"]
)
]
)

dict_update_insts = bytecode_from_template(
_manual_dict_setitem, varname_map=varname_map
)

suffixes.append(
[
*dict_update_insts,
create_instruction("POP_TOP"),
]
)

# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,14 @@ def check_numpy_ndarray_args(args, kwargs):
range_iterator: Type[Iterator[Any]] = type(iter(range(0)))
tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined]
object_new = object.__new__
dict_new = dict.__new__
dict_methods = {
method
for method in itertools.chain(
dict.__dict__.values(), collections.OrderedDict.__dict__.values()
)
if callable(method)
}


def nn_module_new(cls):
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
MutableMappingVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
)

Expand Down
41 changes: 41 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
MutableMappingVariable,
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedObjectVariable,
)

Expand Down Expand Up @@ -1232,6 +1233,46 @@ def build_key_value(i, k, v):
fake_script_obj,
source=self.source,
)
elif (
isinstance(value, (dict, collections.OrderedDict))
and type(value).__new__ is dict.__new__
):
# Construct a dict_vt that will reside inside the UserDefinedDictVariable
self.install_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

# Guard on the key order
self.tx.output.guard_on_key_order.add(self.source.name())

# We need all the keys to be hashable. We do this within the
# _HashableTracker class in dicts.py
def build_key_value(i, k, v):
source_key = ConstDictKeySource(self.get_source(), i)
key = LazyVariableTracker.create(k, source_key)

source_value = GetItemSource(self.get_source(), source_key)
value = LazyVariableTracker.create(v, source_value)

return key, value

result = dict(
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items())
)

# NB: This is deliberately kept ValueMutationNew because dict_vt is
# an internal representation. dict_vt tracks the mutation on the
# dict side. side_effects infra uses the UserDefinedDictVariable to
# apply side-effects of this dict_vt.
dict_vt = ConstDictVariable(
result,
user_cls=collections.OrderedDict
if isinstance(value, collections.OrderedDict)
else dict,
mutation_type=ValueMutationNew(),
)

result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif issubclass(type(value), MutableMapping):
self.install_guards(GuardBuilder.TYPE_MATCH)
return MutableMappingVariable(value, source=self.source)
Expand Down
9 changes: 8 additions & 1 deletion torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,14 @@ def call_method(

arg_hashable = args and is_hashable(args[0])

if name == "__getitem__":
if name == "__init__":
temp_dict_vt = variables.BuiltinVariable(dict).call_dict(
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
return ConstantVariable.create(None)
elif name == "__getitem__":
assert len(args) == 1
return self.getitem_const_raise_exception_if_absent(tx, args[0])
elif name == "items":
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ def call_method(
self.objvar, attr, variables.DeletedVariable()
)
return variables.ConstantVariable(None)
elif (
isinstance(self.objvar, variables.UserDefinedDictVariable)
and inner_fn in self.objvar._dict_methods
):
return self.objvar._dict_vt.call_method(tx, name, args, kwargs)

unimplemented(f"non-function or method super: {inner_fn}")

Expand Down
39 changes: 38 additions & 1 deletion torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
build_checkpoint_variable,
build_invoke_subgraph_variable,
check_constant_args,
dict_methods,
get_custom_getattr,
has_torch_function,
is_frozen_dataclass,
Expand Down Expand Up @@ -637,7 +638,7 @@ def is_standard_new(self):
new_fn = inspect.getattr_static(self.value, "__new__", None)
if isinstance(new_fn, staticmethod):
new_fn = new_fn.__func__
return new_fn in (object.__new__, Generic.__new__)
return new_fn in (object.__new__, Generic.__new__, dict.__new__)

def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if self.source:
Expand Down Expand Up @@ -1441,6 +1442,42 @@ def python_type(self):
return RemovableHandleClass


class UserDefinedDictVariable(UserDefinedObjectVariable):
"""
Represents user defined objects that are subclasses of dict/OrderedDict.
Internally, it uses a ConstDictVariable to represent the dict part of the
variable tracker. For everything else, it falls back to
UserDefinedObjectVariable.
"""

_nonvar_fields = UserDefinedObjectVariable._nonvar_fields

def __init__(self, value, dict_vt=None, **kwargs):
super().__init__(value, **kwargs)
self._dict_vt = dict_vt
if self._dict_vt is None:
assert (
self.source is None
), "dict_vt must be constructed by builder.py when source is present"
self._dict_vt = variables.ConstDictVariable(
{}, mutation_type=ValueMutationNew()
)
self._dict_methods = dict_methods

def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
method = self._maybe_get_baseclass_method(name)
if method in self._dict_methods:
return self._dict_vt.call_method(tx, name, args, kwargs)
return super().call_method(tx, name, args, kwargs)


class MutableMappingVariable(UserDefinedObjectVariable):
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields

Expand Down

0 comments on commit 4627cfd

Please sign in to comment.