[aot] refactor dynamo source and cudagraphs static idx logic (pytorch#141748)

Pull Request resolved: pytorch#141748
Approved by: https://github.com/ezyang
xmfan authored and pytorchmergebot committed Dec 21, 2024
1 parent ae3d385 commit ffd1b53
Showing 3 changed files with 73 additions and 45 deletions.
5 changes: 4 additions & 1 deletion torch/_dynamo/eval_frame.py
@@ -1120,10 +1120,14 @@ class ExportResult(NamedTuple):
     # destructuring so it is BC-breaking


+# NOTE: this function only supports graphs created by Dynamo's OutputGraph module
 def check_signature_rewritable(graph):
     input_errors = []
     for node in graph.graph.find_nodes(op="placeholder"):
+        # set in OutputGraph._call_user_compiler
         assert hasattr(node, "_dynamo_source")
+        assert hasattr(graph, "_source_to_user_stacks")
+
         source = node._dynamo_source
         user_stacks = graph._source_to_user_stacks.get(source)
         if user_stacks is None:
@@ -1655,7 +1659,6 @@ def result_capturing_wrapper(*graph_inputs):
                 graph.print_readable(print_output=False, colored=True),
             )
         else:
-            assert hasattr(graph, "_source_to_user_stacks")
             assert out_guards is not None, "Failed to produce guards during tracing"
             assert fake_mode is not None

2 changes: 2 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -1452,8 +1452,10 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
         for pl in placeholders:
             arg = pl.meta["grapharg"]
             # TODO: Why isn't this stored in meta :think:
+            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
             pl._dynamo_source = arg.source

+        # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
         gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
         gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]

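Note on the two NOTE comments above: _dynamo_source is attached directly to each placeholder node and _param_name_to_source / _source_to_user_stacks directly to the GraphModule, rather than being stored in node.meta / gm.meta (see pytorch#141640), so consumers have to probe for them with hasattr. A minimal illustrative sketch, not part of this commit; the helper name _read_dynamo_sources is made up, only the attribute names come from the diff:

import torch


def _read_dynamo_sources(gm: torch.fx.GraphModule):
    # Graphs that never went through OutputGraph._call_user_compiler
    # (e.g. export or hand-written FX graphs) lack these attributes.
    if not hasattr(gm, "_param_name_to_source"):
        return None
    return [
        node._dynamo_source
        for node in gm.graph.find_nodes(op="placeholder")
        if hasattr(node, "_dynamo_source")
    ]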
111 changes: 67 additions & 44 deletions torch/_functorch/aot_autograd.py
@@ -7,6 +7,7 @@
     Any,
     Callable,
     Dict,
+    KeysView,
     List,
     NewType,
     Optional,
@@ -1004,6 +1005,68 @@ def forward(self, *args, **kwargs):
     return AOTModule()


+def _try_get_metadata_from_dynamo(
+    mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int
+) -> Tuple[Optional[List[torch._guards.Source]], List[int]]:
+    """
+    Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule.
+    We first verify that `mod` does come from Dynamo, then we handle cases where
+    metadata might be missing.
+
+    Returns:
+        aot_autograd_arg_pos_to_source: used to dedup params and their guards
+        static_input_indices: used to identify static inputs for cudagraphs
+    """
+    if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta):
+        # graph was not captured by dynamo
+        return None, []
+
+    if not hasattr(mod, "_param_name_to_source"):
+        # is from export
+        return None, []
+
+    # We now know this came from dynamo, and (1) we care about guards,
+    # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
+    # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
+    # Additionally, we mark static indices for cudagraphs.
+    param_name_to_source = mod._param_name_to_source
+    seen_sources = set()
+
+    aot_autograd_arg_pos_to_source = []
+    # Collect the new inputs lifted by aotdispatch
+    for name in param_keys:
+        assert name in param_name_to_source, f"{name} not found."
+        source = param_name_to_source[name]
+        assert source not in seen_sources, source
+        seen_sources.add(source)
+        aot_autograd_arg_pos_to_source.append(source)
+
+    # Collect the dynamo graph inputs
+    static_input_indices = []
+    for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
+        assert hasattr(node, "_dynamo_source")
+        source = node._dynamo_source
+        assert source not in seen_sources, source
+        seen_sources.add(source)
+        aot_autograd_arg_pos_to_source.append(source)
+        source_name = source.name() if source else str(source)
+
+        if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
+            "_dynamo_static_input_type", None
+        ):
+            static_inputs_log.debug(
+                "Adding static input pos %s for source %s", pos, source_name
+            )
+            static_input_indices.append(pos)
+        else:
+            static_inputs_log.debug(
+                "Non-static input pos %s for source %s", pos, source_name
+            )
+
+    assert full_args_num == len(aot_autograd_arg_pos_to_source)
+    return aot_autograd_arg_pos_to_source, static_input_indices
+
+
 def aot_module_simplified(
     mod: nn.Module,
     args,
@@ -1041,8 +1104,6 @@ def aot_module_simplified(
     if inference_compiler is None:
         inference_compiler = fw_compiler

-    seen_sources = set()
-
     full_args = []
     # First, the params
     full_args.extend(params_flat)
@@ -1054,51 +1115,13 @@
             tracing_context.params_unwrapped_to_flat_index,
         ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat)

-    aot_autograd_arg_pos_to_source = None
-    # Then, the params 1:1 mapped sources, if relevant.
-    if hasattr(mod, "_param_name_to_source"):
-        aot_autograd_arg_pos_to_source = []
-        # We now know this came from dynamo, and (1) we care about guards,
-        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
-        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
-        for name in params.keys():
-            assert name in mod._param_name_to_source, f"{name} not found."
-            source = mod._param_name_to_source[name]
-            assert source not in seen_sources, source
-            seen_sources.add(source)
-            aot_autograd_arg_pos_to_source.append(source)
-
     # Next, the input args
     full_args.extend(args)

-    static_input_indices = []
-    if hasattr(mod, "graph"):
-        # Non dynamo entrypoints can get to here...
-        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
-            if hasattr(node, "_dynamo_source"):
-                # ... but not here!
-                if aot_autograd_arg_pos_to_source is None:
-                    aot_autograd_arg_pos_to_source = []
-                source = node._dynamo_source
-                assert source not in seen_sources, source
-                seen_sources.add(source)
-                aot_autograd_arg_pos_to_source.append(source)
-                source_name = source.name() if source else str(source)
-
-                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
-                    "_dynamo_static_input_type", None
-                ):
-                    static_inputs_log.debug(
-                        "Adding static input pos %s for source %s", pos, source_name
-                    )
-                    static_input_indices.append(pos)
-                else:
-                    static_inputs_log.debug(
-                        "Non-static input pos %s for source %s", pos, source_name
-                    )
-
-    if aot_autograd_arg_pos_to_source is not None:
-        assert len(full_args) == len(aot_autograd_arg_pos_to_source)
+    (
+        aot_autograd_arg_pos_to_source,
+        static_input_indices,
+    ) = _try_get_metadata_from_dynamo(mod, params.keys(), len(full_args))

     dynamic_shapes = False
     for x in full_args:
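For context on the early returns in _try_get_metadata_from_dynamo: graphs that were not captured by Dynamo (no "dynamo_compile_id" in gm.meta) or that come from export simply yield (None, []) instead of tripping the asserts. A rough usage sketch, assuming a PyTorch build that contains this commit; the function is private and is called here only for illustration:

import torch
from torch._functorch.aot_autograd import _try_get_metadata_from_dynamo


class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(Add())   # plain FX trace, not captured by Dynamo
params = dict(gm.named_parameters())  # empty for this toy module
full_args = list(params.values()) + [torch.randn(2)]

arg_pos_to_source, static_input_indices = _try_get_metadata_from_dynamo(
    gm, params.keys(), len(full_args)
)
# Non-Dynamo graphs fall back to "no source metadata" and no static input indices.
assert arg_pos_to_source is None and static_input_indices == []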
