Skip to content

Commit

Permalink
[dynamo] Remove transformers ModelOutput hack (pytorch#143567)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#143567
Approved by: https://github.com/williamwen42, https://github.com/jansel
ghstack dependencies: pytorch#143548
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Dec 21, 2024
1 parent 4627cfd commit 0da004f
Show file tree
Hide file tree
Showing 9 changed files with 3 additions and 307 deletions.
2 changes: 1 addition & 1 deletion test/dynamo/test_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class MyDataClass(ModelOutput):
x: torch.FloatTensor = None

def fn(x):
obj = MyDataClass(x=x)
obj = MyDataClass(x=x * 3)
return obj

inp = torch.randn(3, 3)
Expand Down
33 changes: 0 additions & 33 deletions torch/_dynamo/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,39 +613,6 @@ def codegen_update_mutated(self, cg: PyCodegen):
]
)

elif isinstance(var, variables.CustomizedDictVariable):
# need to update the dict manually since update method may be invalid
varname_map = {}
for name in _manual_update_dict.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()

cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[create_instruction("STORE_FAST", argval=varname_map["dict_to"])]
)

cg(var, allow_cache=False) # Don't codegen via source
cg.extend_output(
[create_instruction("STORE_FAST", argval=varname_map["dict_from"])]
)

cg(var.source) # type: ignore[attr-defined]
cg.load_method("clear")

# unfortunately can't just use DICT_MERGE due to possible custom behaviors
dict_update_insts = bytecode_from_template(
_manual_update_dict, varname_map=varname_map
)

suffixes.append(
[
*create_call_method(0), # clear
create_instruction("POP_TOP"),
*dict_update_insts,
create_instruction("POP_TOP"),
]
)

elif isinstance(var, variables.ConstDictVariable):
# Reconstruct works as follow:
# (1) Skip codegen if there are no new items
Expand Down
6 changes: 1 addition & 5 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,11 +3172,7 @@ def build_inline_tracer(

code: types.CodeType = func.get_code()
if code.co_name in ("__setitem__", "__setattr__") and not (
args
and isinstance(
args[0],
(variables.CustomizedDictVariable, variables.UserDefinedObjectVariable),
)
args and isinstance(args[0], variables.UserDefinedObjectVariable)
):
unimplemented(f"inline {code.co_name}")

Expand Down
2 changes: 0 additions & 2 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from .dicts import (
ConstDictVariable,
CustomizedDictVariable,
DefaultDictVariable,
DictKeySetVariable,
FrozensetVariable,
Expand Down Expand Up @@ -131,7 +130,6 @@
"CountIteratorVariable",
"CreateTMADescriptorVariable",
"CUDADeviceVariable",
"CustomizedDictVariable",
"CycleIteratorVariable",
"DataPtrVariable",
"DefaultDictVariable",
Expand Down
7 changes: 0 additions & 7 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@
)
from .dicts import (
ConstDictVariable,
CustomizedDictVariable,
DefaultDictVariable,
DictKeySetVariable,
FrozensetVariable,
Expand Down Expand Up @@ -606,10 +605,6 @@ def create_2d_tma_descriptor():
elif value is sys.modules:
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonSysModulesVariable(source=self.source)
elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = CustomizedDictVariable.wrap(self, value)
return self.tx.output.side_effects.track_object_existing(value, result)
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)

Expand Down Expand Up @@ -2172,8 +2167,6 @@ def wrap_unspecialized_primitive(self, value):
def _dataclasses_fields_lambda(obj):
if isinstance(obj, UserDefinedObjectVariable):
value = obj.value
elif isinstance(obj, CustomizedDictVariable):
value = obj.user_cls
else:
unimplemented(f"Dataclass fields handling fails for type {obj}")
items = []
Expand Down
1 change: 0 additions & 1 deletion torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,6 @@ def call_setattr(
if isinstance(
obj,
(
variables.CustomizedDictVariable,
variables.PlacementVariable,
variables.NamedTupleVariable,
variables.UserDefinedObjectVariable,
Expand Down
240 changes: 1 addition & 239 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
# mypy: ignore-errors

import collections
import dataclasses
import functools
import inspect
import sys
from typing import Dict, List, Optional, TYPE_CHECKING

from torch._subclasses.fake_tensor import is_fake

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code
from ..exc import raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, is_from_local_source
from ..utils import dict_keys, dict_values, istype, specialize_symnode
from ..utils import dict_keys, dict_values, specialize_symnode
from .base import ValueMutationNew, VariableTracker
from .constant import ConstantVariable

Expand Down Expand Up @@ -735,241 +732,6 @@ def python_type(self):
return dict_values


def _is_matching_transformers_cls(cls) -> bool:
mod = sys.modules.get("transformers.file_utils")
if mod is None:
mod = sys.modules.get("transformers.utils.generic")
return mod is not None and issubclass(cls, mod.ModelOutput)


def _is_matching_diffusers_cls(cls) -> bool:
mod = sys.modules.get("diffusers.utils")
return mod is not None and issubclass(cls, mod.BaseOutput)


def _call_hasattr_customobj(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
"""Shared method between DataClassVariable and CustomizedDictVariable where items are attrs"""
if tx.output.side_effects.is_attribute_mutation(self):
try:
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
)
except KeyError:
pass
if name in self.items or hasattr(self.user_cls, name):
return ConstantVariable(True)
elif istype(self.mutation_type, ValueMutationNew) and self.source is None:
# Something created locally can't have any extra fields on it
return ConstantVariable(False)
elif self.source:
# Maybe add a guard
try:
example = tx.output.root_tx.get_example_value(self.source)
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
return ConstantVariable(hasattr(example, name))
except KeyError:
pass
unimplemented(
f"hasattr({self.__class__.__name__}, {name}) {self.mutation_type} {self.source}"
)


class CustomizedDictVariable(ConstDictVariable):
@staticmethod
def is_matching_cls_hf(cls):
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

@staticmethod
def is_matching_cls(cls):
# True if using default OrderedDict.__init__ and did not implement __post_init__
if (
issubclass(cls, collections.OrderedDict)
and cls is not collections.OrderedDict
and cls.__init__ is collections.OrderedDict.__init__
and not hasattr(cls, "__post_init__")
):
return True
# hack for HF usecase:
# assume dataclass annotation for ModelOutput subclass
# assume self.create is AA to ModelOutput.__post_init__
return CustomizedDictVariable.is_matching_cls_hf(cls)

@classmethod
def is_matching_object(cls, obj):
return cls.is_matching_cls(type(obj))

# called from user_defined.py
# when is_matching_cls(cls) is true
@classmethod
def create(cls, user_cls, args, kwargs, options):
# avoid tracing when returning ModelOutput from forward func
for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
if hasattr(user_cls, attr_name):
fn = getattr(user_cls, attr_name)
assert callable(fn), f"expect callable attr {attr_name}"
if hasattr(fn, "__code__"):
skip_code(fn.__code__)

if dataclasses.is_dataclass(user_cls):
# @dataclass CustomDict(a=1, b=2)
bound = inspect.signature(user_cls).bind(*args, **kwargs)
bound.apply_defaults()

def make_var(x):
if isinstance(x, VariableTracker):
return x
elif ConstantVariable.is_literal(x):
return ConstantVariable.create(x)
else:
unimplemented(
"expect VariableTracker or ConstantVariable.is_literal"
)

bound_args = {}
if cls.is_matching_cls_hf(user_cls):
# Skip none
for k, v in bound.arguments.items():
if isinstance(v, ConstantVariable) and v.value is None or v is None:
continue
bound_args[k] = v
else:
bound_args = bound.arguments

items = {
ConstantVariable.create(k): make_var(v) for k, v in bound_args.items()
}
elif not args:
# CustomDict(a=1, b=2) in the general (non-dataclass) case.
items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
# CustomDict({'a': 1, 'b': 2})
items = args[0].items
else:
unimplemented("custom dict init with args/kwargs unimplemented")

return cls(items, user_cls, **options)

# called from builder.py
@classmethod
def wrap(cls, builder, obj):
user_cls = type(obj)

if not cls.is_matching_cls_hf(user_cls):
unimplemented("custom non-hf dict subclass wrap unimplemented")

items = builder.__class__(tx=builder.tx, source=builder.source)(
collections.OrderedDict(obj)
).items

keys = [f.name for f in dataclasses.fields(user_cls)]
for key in keys:
# __init__ function of a dataclass might not have yet defined the key
if hasattr(obj, key):
val = getattr(obj, key)
var = builder.__class__(
tx=builder.tx, source=AttrSource(builder.source, key)
)(val)
if val is not None:
key = ConstantVariable.create(key)
items[key] = var
return cls(items, user_cls, source=builder.source)

def __init__(self, items, user_cls, **options) -> None:
super().__init__(items, user_cls, **options)
assert self.is_matching_cls(user_cls)

def as_proxy(self):
raise NotImplementedError

# 'RETURN_VALUE triggered compile'
# called from torch/_dynamo/codegen.py
def reconstruct(self, codegen):
is_hf_model_output = self.is_matching_cls_hf(self.user_cls)

def gen_fn1():
# If the user class is a ModelOutput, then wrap the instance creation in
# torch._dynamo.disable(). Even though we mark the __post_init__ as skip
# in `create` function, this is not enough. TorchDynamo can still get
# triggered on the child functions of __post_init__. This upsets export.
# Since, we know that ModelOutput __post_init__ is not worth optimizing,
# we just wrap the instance creation in torch._dynamo.disable(),
# regardless whether its export or not.
if is_hf_model_output:
# load torch._dynamo.disable
def gen_fn2():
codegen.append_output(codegen.create_load_global("torch", add=True))
codegen.append_output(codegen.create_load_attr("_dynamo"))
codegen.append_output(codegen.create_load_attr("disable"))

codegen.add_push_null(gen_fn2)

codegen.extend_output([codegen.create_load_const_unchecked(self.user_cls)])

if is_hf_model_output:
# Wrap user_cls with disable
codegen.extend_output(create_call_function(1, False))

codegen.add_push_null(gen_fn1)

# All the keys are just wrapped strings
d = self.keys_as_python_constant()
codegen.foreach(d.values())
keys = tuple(d.keys())
codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, False))

def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
fn = getattr(self.user_cls, name)
source = None if self.source is None else AttrSource(self.source, name)

if hasattr(fn, "__objclass__") and fn.__objclass__ in (
dict,
collections.OrderedDict,
):
# for python dict method without overridden
return super().call_method(tx, name, args, kwargs)
elif name in (
"__getitem__",
"to_tuple",
"__setitem__",
"__setattr__",
"__post_init__",
):
# for user overridden method
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=source),
[self] + list(args),
kwargs,
)
elif fn is getattr(collections.OrderedDict, name, None):
return super().call_method(tx, name, args, kwargs)

unimplemented(f"custom dict: call_method unimplemented name={name}")

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
name_vt = ConstantVariable.create(name)
if name_vt in self:
return self.call_method(tx, "__getitem__", [name_vt], {})
if dataclasses.is_dataclass(self.user_cls):
defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
if name in defaults:
assert variables.ConstantVariable.is_literal(defaults[name])
return variables.ConstantVariable.create(defaults[name])
return super().var_getattr(tx, name)

call_hasattr = _call_hasattr_customobj


@functools.lru_cache(None)
def _install_PretrainedConfig_patch():
mod = sys.modules.get("transformers.configuration_utils")
Expand Down
14 changes: 0 additions & 14 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,6 @@ def call_method(
value = collections.OrderedDict.__getitem__(self.objvar.value, key)
source = ODictGetItemSource(self.objvar.source, key)
return VariableTracker.build(tx, value, source)
elif inner_fn in (
collections.OrderedDict.__setitem__,
object.__setattr__,
) and isinstance(self.objvar, variables.CustomizedDictVariable):
assert not kwargs and len(args) == 2
return super(variables.CustomizedDictVariable, self.objvar).call_method(
tx, "__setitem__", args, kwargs
)
elif inner_fn is collections.OrderedDict.__getitem__ and isinstance(
self.objvar, variables.CustomizedDictVariable
):
return super(variables.CustomizedDictVariable, self.objvar).call_method(
tx, "__getitem__", args, kwargs
)
elif is_standard_setattr(inner_fn) and isinstance(
self.objvar, UserDefinedObjectVariable
):
Expand Down
Loading

0 comments on commit 0da004f

Please sign in to comment.