# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.fx
from executorch.exir.emit._emitter import (
    _DelegateDebugIdentifierMap,
    _EmitterState,
    _ProgramState,
    _TopLevelEmitter,
)
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.schema import Buffer, Program, SubsegmentOffsets
from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.utils import _pytree as pytree


@dataclass
class EmitterOutput:
    """
    The outputs of program emission. Contains the ExecuTorch program object as
    well as a mapping of instruction ids to debug handles.
    """

    # The ExecuTorch program
    program: Program

    # This dictionary maps the instruction ids to their corresponding
    # debug handles, or to a list of debug handles in the case of delegate calls.
    debug_handle_map: Dict[int, Union[int, List[int]]]

    # This dictionary maps each method name to a dict which maps the delegate
    # instruction id to its corresponding delegate name and delegate debug
    # identifier mapping.
    method_to_delegate_debug_id_map: Dict[
        str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
    ]

    mutable_data: Optional[List[Buffer]]

    # Constants are optionally stored in external files.
    # Aggregate unique external constants into one buffer.
    external_constant_buffer: List[bytes]

    # Each constant_tag groups a set of constants together.
    # {constant_tag: {fqn: index into external_constant_buffer}}
    external_constant_map: Optional[Dict[str, Dict[str, int]]]
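

# Illustrative sketch only (not part of this module's API): how a caller might
# inspect an EmitterOutput, assuming `exported_program` is a torch.export
# ExportedProgram and emit_program() is the function defined further down in
# this file. The variable names are hypothetical.
#
#     emitter_output = emit_program(exported_program)
#     program = emitter_output.program            # flatbuffer-schema Program
#     handles = emitter_output.debug_handle_map   # debug-handle mapping
#     plans = program.execution_plan              # one plan per method/getter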


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
    """Removes buffer-mutation outputs from the graph's output node so that
    only user-visible outputs are returned."""
    gm = exported_program.graph_module
    output_node = None
    for node in gm.graph.nodes:
        if node.op == "output":
            output_node = node
    assert output_node is not None

    mutated_outputs: List[Optional[str]] = [
        out_spec.target if out_spec.kind in (OutputKind.BUFFER_MUTATION,) else None
        for out_spec in exported_program.graph_signature.output_specs
    ]
    outputs = pytree.tree_flatten(output_node.args)[0]

    user_output_nodes = []
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(tuple(user_output_nodes))
        new_output.meta = output_node.meta.copy()
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)

    return gm
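

# Illustrative example (assumed shapes, not executed): for a graph whose output
# node returns (updated_buffer, user_out) with output_specs of
# [BUFFER_MUTATION, USER_OUTPUT], _remove_non_user_outputs rewrites the graph
# so that the output node returns only (user_out,).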


# For each entry point in the model, determine whether it is a joint graph,
# and if so, return a map giving the index in the model output at which the
# gradient outputs start and the index at which the parameter outputs start.
def _get_training_metadata(
    methods: Dict[str, ExportedProgram]
) -> Dict[str, Union[int, List[str]]]:
    gradients_method_prefix = "__et_training_gradients_index_"
    parameters_method_prefix = "__et_training_parameters_index_"
    fqn_method_prefix = "__et_training_fqn_"
    training_metadata: Dict[str, Union[int, List[str]]] = {}
    for name, method in methods.items():
        found_grad = False
        found_param = False
        fqns = []
        i = 0
        for output_spec in method.graph_signature.output_specs:
            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                if not found_grad:
                    training_metadata[gradients_method_prefix + name] = i
                    found_grad = True
                fqns.append(output_spec.target)
            elif output_spec.kind == OutputKind.TOKEN and not found_param:
                assert found_grad  # Params must come after gradients
                training_metadata[parameters_method_prefix + name] = i
                found_param = True
            i += 1
        if len(fqns) > 0:
            training_metadata[fqn_method_prefix + name] = fqns
    return training_metadata
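

# Illustrative sketch (hypothetical method name, indices, and FQNs): for a
# joint-graph method named "forward" whose outputs are laid out as
# [loss, grad_w, grad_b, <parameters...>], the resulting metadata might look
# like:
#
#     {
#         "__et_training_gradients_index_forward": 1,
#         "__et_training_parameters_index_forward": 3,
#         "__et_training_fqn_forward": ["linear.weight", "linear.bias"],
#     }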


def emit_program(
    methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
    emit_stacktrace: bool = False,
    prim_getters: Optional[Dict[str, Any]] = None,
    emit_mutable_buffer_names: bool = False,
) -> EmitterOutput:
    """
    Given an exported program, returns the program in the format of the Python
    version of the flatbuffer Program schema.

    Args:
        methods: Either the ExportedProgram that we want to emit into the
            flatbuffer, or a dictionary of method names to ExportedPrograms.
        emit_stacktrace: Flag to enable emission of a stacktrace for each
            instruction for debugging purposes.
        prim_getters: Optional map of getter names to primitive values; each
            entry is emitted as an execution plan that returns its value.
        emit_mutable_buffer_names: Flag to also emit the names of mutable
            buffers.

    Return:
        The program in a Python class which mimics the flatbuffer schema.
    """
    if isinstance(methods, ExportedProgram):
        methods = {"forward": methods}

    # validation
    bad_methods = []
    for name, exported_program in methods.items():
        if not isinstance(exported_program, ExportedProgram):
            bad_methods.append(name)
    if len(bad_methods) != 0:
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Did not receive ExportedProgram for the following methods {str(bad_methods)}",
        )
    plans = []
    debug_handle_map = {}
    method_to_delegate_debug_id_map = {}
    program_state = _ProgramState()

    # emit each entry point in order according to name.
    for name, exported_program in sorted(methods.items()):
        # create empty state
        emitter_state = _EmitterState(
            values=[],
            operators=[],
            delegates=[],
            operator_cache={},
            delegate_cache={},
            emit_stacktrace=emit_stacktrace,
            emit_mutable_buffer_names=emit_mutable_buffer_names,
        )

        gm = _remove_non_user_outputs(exported_program)

        emitter = _TopLevelEmitter(
            name, exported_program, gm, program_state, emitter_state
        )

        emitter.run()
        plans.append(emitter.plan())

        debug_handle_map[name] = emitter.debug_handle_map
        method_to_delegate_debug_id_map[name] = (
            emitter.instr_id_to_delegate_debug_id_map
        )

    training_metadata = _get_training_metadata(methods)
    if len(training_metadata) > 0:
        plans.extend(emitter._emit_prim_getters(training_metadata))

    # emit any primitive getters
    if prim_getters is not None:
        plans.extend(emitter._emit_prim_getters(prim_getters))
    return EmitterOutput(
        debug_handle_map=debug_handle_map,
        method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
        program=Program(
            version=EXECUTORCH_SCHEMA_VERSION,
            execution_plan=plans,
            constant_buffer=program_state.constant_buffer,
            backend_delegate_data=program_state.backend_delegate_data,
            # Segments may be added at serialization time.
            segments=[],
            # Subsegment offsets may be added at serialization time.
            constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
            mutable_data_segments=None,  # Will be filled in during serialization
        ),
        mutable_data=(
            # mutable_buffer is seeded with a placeholder entry at index 0, so
            # a length greater than 1 means real mutable data was emitted.
            program_state.mutable_buffer
            if len(program_state.mutable_buffer) > 1
            else None
        ),
        external_constant_buffer=program_state.external_constant_buffer,
        external_constant_map=program_state.external_constant_map,
    )
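

# Illustrative end-to-end usage (assumed model and inputs; a sketch, not
# executed as part of this module):
#
#     import torch
#     from torch.export import export
#
#     class Mul(torch.nn.Module):
#         def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#             return x * y
#
#     ep = export(Mul(), (torch.randn(4), torch.randn(4)))
#     emitter_output = emit_program(ep)  # single method, named "forward"
#     program = emitter_output.program   # ready for flatbuffer serialization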