-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathexport_serialize.py
2916 lines (2605 loc) · 113 KB
/
export_serialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors
# Copied over from caffe2/torch/_export/serde/serialize.py until dialects
# are supported in torch export serializer.
import base64
import copy
import copyreg
import dataclasses
import heapq
import inspect
import io
import json
import logging
import math
import operator
import re
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
Callable,
cast,
Dict,
final,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
import sympy
import torch
import torch.export.exported_program
import torch.export.exported_program as ep
from torch._export.serde.schema import SchemaVersion
from torch._export.verifier import load_verifier
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental import symbolic_shapes
from torch.utils import _pytree as pytree
from torch.utils._pytree import treespec_dumps, treespec_loads
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
# pyre-ignore
from .schema import ( # type: ignore[attr-defined]
Argument,
BufferMutationSpec,
ConstantInputSpec,
ConstantValue,
CustomObjArgument,
Device,
ExportedProgram,
GradientToParameterSpec,
GradientToUserInputSpec,
Graph,
GraphArgument,
GraphModule,
GraphSignature,
InputSpec,
InputToBufferSpec,
InputToCustomObjSpec,
InputTokenSpec,
InputToParameterSpec,
InputToTensorConstantSpec,
Layout,
LossOutputSpec,
MemoryFormat,
ModuleCallEntry,
ModuleCallSignature,
NamedArgument,
Node,
OptionalTensorArgument,
OutputSpec,
OutputTokenSpec,
RangeConstraint,
ScalarType,
SCHEMA_VERSION,
SymBool,
SymBoolArgument,
SymExpr,
SymExprHint,
SymInt,
SymIntArgument,
TensorArgument,
TensorMeta,
TokenArgument,
TREESPEC_VERSION,
UserInputMutationSpec,
UserInputSpec,
UserOutputSpec,
)
from .union import _Union
# Names exported by `from <this module> import *`.
# NOTE(review): "serialize" is not defined in the portion of the file shown
# here; presumably it is defined further down — confirm.
__all__ = [
    "serialize",
    "GraphModuleSerializer",
    "ExportedProgramSerializer",
    "GraphModuleDeserializer",
    "ExportedProgramDeserializer",
]
from .upgrade import GraphModuleOpUpgrader
# Module logger; used e.g. to warn about non-integer range-constraint bounds
# and about empty lists whose element type is unknown.
log = logging.getLogger(__name__)
class SerializeError(RuntimeError):
    """Raised when a value, node, or graph construct cannot be serialized."""
def _reverse_map(d: Dict[Any, Enum]):
return {v.value: k for k, v in d.items()}
# Union of value types this module expects to find in a node's meta["val"]
# (presumably; MetaType itself is not referenced in the code shown here —
# TODO confirm).
MetaType = Union[
    FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument
]

# Separator used to join list-valued node metadata (nn_module_stack,
# source_fn_stack, torch_fn) into a single string in serialize_metadata.
ST_DELIMITER = ";"

# torch dtype <-> serialized ScalarType. The reverse table is keyed by the
# enum's *value* (see _reverse_map).
_TORCH_TO_SERIALIZE_DTYPE = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.float16: ScalarType.HALF,
    torch.float32: ScalarType.FLOAT,
    torch.float64: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEXHALF,
    torch.complex64: ScalarType.COMPLEXFLOAT,
    torch.complex128: ScalarType.COMPLEXDOUBLE,
    torch.bool: ScalarType.BOOL,
    torch.bfloat16: ScalarType.BFLOAT16,
    torch.uint16: ScalarType.UINT16
}

_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE)  # type: ignore[arg-type]

# torch layout <-> serialized Layout.
_TORCH_TO_SERIALIZE_LAYOUT = {
    torch.sparse_coo: Layout.SparseCoo,
    torch.sparse_csr: Layout.SparseCsr,
    torch.sparse_csc: Layout.SparseCsc,
    torch.sparse_bsr: Layout.SparseBsr,
    torch.sparse_bsc: Layout.SparseBsc,
    torch._mkldnn: Layout._mkldnn,  # type: ignore[attr-defined]
    torch.strided: Layout.Strided,
}

_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT)  # type: ignore[arg-type]

# torch memory_format <-> serialized MemoryFormat.
_TORCH_TO_SERIALIZE_MEMORY_FORMAT = {
    torch.contiguous_format: MemoryFormat.ContiguousFormat,
    torch.channels_last: MemoryFormat.ChannelsLast,
    torch.channels_last_3d: MemoryFormat.ChannelsLast3d,
    torch.preserve_format: MemoryFormat.PreserveFormat,
}

_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT)  # type: ignore[arg-type]

# call_function targets that handle_call_function serializes with a single
# SymInt output (positional inputs only).
# NOTE(review): torch.sym_float and torch.sym_sqrt look like they produce
# floats, yet they are routed through the SymInt path — confirm intended.
_SYM_INT_OPS = {
    operator.mul,
    operator.add,
    operator.sub,
    operator.floordiv,
    operator.mod,
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_sqrt,
}

# call_function targets that handle_call_function serializes with a single
# SymBool output (positional inputs only).
_SYM_BOOL_OPS = {
    operator.eq,
    operator.ne,
    operator.le,
    operator.ge,
    operator.lt,
    operator.gt,
    torch.sym_not,
}
@dataclass
class SerializedArtifact:
    """Byte payloads that together make up a serialized exported program.

    The encoding of each payload is decided by the code that fills this in
    (not shown here).
    """

    # The exported program itself, as bytes.
    exported_program: bytes
    # Serialized state dict (parameters/buffers), as bytes.
    state_dict: bytes
    # Serialized constants, as bytes.
    constants: bytes
    # Serialized example inputs, as bytes.
    example_inputs: bytes
@dataclass
class _SerializedProgram:
    """Like SerializedArtifact, but with the program still as a schema object.

    ``exported_program`` is the in-memory schema ``ExportedProgram`` dataclass;
    the remaining fields are already bytes.
    """

    # Schema-level (not yet byte-encoded) exported program.
    exported_program: ExportedProgram
    # Serialized state dict, as bytes.
    state_dict: bytes
    # Serialized constants, as bytes.
    constants: bytes
    # Serialized example inputs, as bytes.
    example_inputs: bytes
def deserialize_device(d: Device) -> torch.device:
    """Rebuild a ``torch.device`` from a serialized ``Device`` record.

    The index is only forwarded when one was recorded; otherwise the device
    is built from its type alone.
    """
    if d.index is not None:
        return torch.device(type=d.type, index=d.index)
    return torch.device(type=d.type)  # type: ignore[call-overload]
def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
    """Serialize a Python int or ``torch.SymInt`` into a schema ``SymInt``.

    Concrete values become ``as_int``. Symbolic values become a ``SymExpr``
    holding their string form, plus an integer hint when the symbolic node has one.

    Raises:
        SerializeError: if ``s`` is neither an int nor a ``torch.SymInt``.
    """
    if not isinstance(s, (torch.SymInt, int)):
        raise SerializeError(
            f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`"
        )
    if symbolic_shapes.is_concrete_int(s):
        return SymInt.create(as_int=int(s))

    # Only a genuinely symbolic SymInt can reach this point.
    assert isinstance(s, torch.SymInt)
    hint = s.node.hint
    expr_str = str(s)
    if hint is None:
        return SymInt.create(as_expr=SymExpr(expr_str))
    return SymInt.create(
        as_expr=SymExpr(expr_str, hint=SymExprHint.create(as_int=hint))
    )
def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool:
    """Serialize a Python bool or ``torch.SymBool`` into a schema ``SymBool``.

    Concrete values become ``as_bool``. Symbolic ones become a ``SymExpr`` of
    their string form (no hint is recorded).

    Raises:
        SerializeError: if ``s`` is neither a bool nor a ``torch.SymBool``.
    """
    if not isinstance(s, (torch.SymBool, bool)):
        raise SerializeError(
            f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`"
        )
    if not symbolic_shapes.is_concrete_bool(s):
        return SymBool.create(as_expr=SymExpr(expr_str=str(s)))
    return SymBool.create(as_bool=bool(s))
def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta:
    """
    Build a ``TensorMeta`` describing ``t``'s dtype, shape, strides, device,
    layout and ``requires_grad``.

    Sizes and strides may be symbolic and go through ``serialize_sym_int``.
    The storage offset is currently always recorded as 0 (see TODO).

    Raises:
        KeyError: if ``t``'s dtype or layout has no serialized counterpart.
    """
    dtype = _TORCH_TO_SERIALIZE_DTYPE[t.dtype]
    sizes = [serialize_sym_int(dim) for dim in t.shape]
    device = Device(type=t.device.type, index=t.device.index)
    strides = [serialize_sym_int(step) for step in t.stride()]
    return TensorMeta(
        dtype=dtype,
        sizes=sizes,
        requires_grad=t.requires_grad,
        device=device,
        strides=strides,
        storage_offset=serialize_sym_int(0),  # TODO needs to be fixed.
        layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout],
    )
# Deserializers currently in use. `_reconstruct_fake_tensor` treats the last
# entry as the active one when rebuilding FakeTensors during unpickling.
# Presumably the deserializer appends/pops itself around loading — the code
# that does so is not shown here; confirm.
_CURRENT_DESERIALIZER: List["GraphModuleDeserializer"] = []
def _reduce_fake_tensor(fake_tensor: FakeTensor):
    """copyreg reducer for ``FakeTensor``.

    ``serialize_torch_artifact`` registers this reducer temporarily. Instead of
    pickling the fake tensor itself, it returns the ``(callable, args)`` pair
    pickle expects. On load, ``_reconstruct_fake_tensor`` receives the
    UTF-8/JSON-encoded ``TensorMeta`` and a flag telling whether the original
    was an ``nn.Parameter``.
    """
    meta_dict = _dataclass_to_dict(serialize_tensor_meta(fake_tensor))
    meta_bytes = json.dumps(meta_dict, cls=EnumEncoder).encode("utf-8")
    was_parameter = isinstance(fake_tensor, torch.nn.Parameter)
    return _reconstruct_fake_tensor, (meta_bytes, was_parameter)
def _reconstruct_fake_tensor(
    serialized_tensor_meta: bytes, is_parameter: bool
) -> FakeTensor:
    """Unpickling counterpart of ``_reduce_fake_tensor``.

    Decodes the JSON ``TensorMeta`` and asks the last entry of
    ``_CURRENT_DESERIALIZER`` to materialize it as a fake tensor. The result is
    wrapped in ``nn.Parameter`` when the original was one.

    Raises:
        AssertionError: if no deserializer is currently active.
    """
    meta_dict = json.loads(serialized_tensor_meta.decode("utf-8"))
    tensor_meta = _dict_to_dataclass(TensorMeta, meta_dict)

    # The fake mode lives on the active deserializer, so one must be present.
    assert _CURRENT_DESERIALIZER, "Need access to current deserializer state"
    active = _CURRENT_DESERIALIZER[-1]
    fake_tensor = active.deserialize_tensor_meta(tensor_meta)

    if not is_parameter:
        return fake_tensor
    return torch.nn.Parameter(fake_tensor)  # type: ignore[return-value]
def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes:
    """Pickle ``artifact`` with ``torch.save`` and return the raw bytes.

    While saving, a temporary copyreg reducer is installed for ``FakeTensor``.
    Fake tensors are therefore stored as their JSON-encoded ``TensorMeta`` (see
    ``_reduce_fake_tensor``) rather than pickled directly. The reducer is always
    removed again before this function returns or raises.

    Raises:
        AssertionError: if some other FakeTensor reducer is already registered.
    """
    assert (
        FakeTensor not in copyreg.dispatch_table
    ), "Refusing to stomp on existing FakeTensor reducer"
    # Register before entering the ``try``. The ``finally`` then only runs once
    # the reducer is actually installed. Otherwise a failed registration would
    # be masked by a KeyError from the ``del`` below.
    copyreg.pickle(FakeTensor, _reduce_fake_tensor)
    try:
        buffer = io.BytesIO()
        # NOTE: tensors are saved on whatever device they currently live on.
        # The backend's unpickleTensor() recreates a tensor on the device it was
        # saved from, which is awkward for multi-GPU loading. Moving tensors to
        # CPU before saving would work around that, but is NOT done here.
        # TODO: this should be fixed by deserialization instead.
        torch.save(artifact, buffer)
        return buffer.getvalue()
    finally:
        del copyreg.dispatch_table[FakeTensor]
def deserialize_torch_artifact(
    serialized: Union[Dict[str, Any], Tuple[Any, ...], bytes]
):
    """Load an artifact produced by ``serialize_torch_artifact``.

    Already-deserialized dicts/tuples are returned unchanged (same object), and
    an empty byte string yields ``{}``.

    SECURITY: bytes are unpickled with ``weights_only=False``, which can execute
    arbitrary code. Only pass artifacts from a trusted source.

    Raises:
        AssertionError: if the unpickled object is not a dict or tuple.
    """
    if isinstance(serialized, (dict, tuple)):
        return serialized
    if len(serialized) == 0:
        return {}
    buffer = io.BytesIO(serialized)
    # weights_only=False is passed explicitly. Artifacts may contain custom
    # objects (e.g. ScriptObjects) and the FakeTensor reducer
    # (_reconstruct_fake_tensor), and newer torch releases default
    # torch.load to weights_only=True, which rejects them.
    artifact = torch.load(buffer, weights_only=False)
    assert isinstance(artifact, (tuple, dict))
    return artifact
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
    """Convert a sympy range bound into a plain Python number.

    - ``sympy.oo`` / ``int_oo`` become ``math.inf``, and their negations ``-math.inf``.
    - A ``sympy.Integer`` becomes ``int``.
    - Anything else is logged as a warning and rounded with ``math.floor``
      (``adjust == "floor"``) or ``math.ceil`` (``adjust == "ceil"``).

    Raises:
        RuntimeError: if rounding is needed and ``adjust`` is neither
            ``"floor"`` nor ``"ceil"``.
    """
    if val in (sympy.oo, int_oo):
        return math.inf
    if val in (-sympy.oo, -int_oo):
        return -math.inf
    if isinstance(val, sympy.Integer):
        return int(val)

    # TODO: Remove this adjustment when Ed gets rid of fractional ranges
    log.warning(
        "Export constraints cannot be non-integer expressions. Found "
        "type %s, and value %s. We will attempt to %s "
        "this value.",
        type(val),
        val,
        adjust,
    )
    if adjust not in ("floor", "ceil"):
        raise RuntimeError(f"Got invalid adjustment {adjust}")
    return math.floor(val) if adjust == "floor" else math.ceil(val)
def _int_to_sympy_int(val) -> sympy.Expr:
    """Inverse of ``_sympy_int_to_int`` for range bounds.

    ``math.inf`` / ``-math.inf`` map to ``int_oo`` / ``-int_oo``. Any other
    value is wrapped as ``sympy.Integer(val)``.
    """
    if val == math.inf:
        return int_oo
    return -int_oo if val == -math.inf else sympy.Integer(val)
def serialize_range_constraints(
    range_constraints: Dict[sympy.Symbol, ValueRanges]
) -> Dict[str, RangeConstraint]:
    """Serialize ``{symbol: ValueRanges}`` into ``{str(symbol): RangeConstraint}``.

    The lower bound is rounded up and the upper bound rounded down (via
    ``_sympy_int_to_int``), so the integer range never widens.
    """
    serialized: Dict[str, RangeConstraint] = {}
    for symbol, bounds in range_constraints.items():
        lower = _sympy_int_to_int(bounds.lower, "ceil")
        upper = _sympy_int_to_int(bounds.upper, "floor")
        serialized[str(symbol)] = RangeConstraint(lower, upper)  # type: ignore[arg-type]
    return serialized
def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
returns = target._schema.returns
return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
returns = target._schema.returns
if len(returns) != 1:
return False
return_type = returns[0].real_type
return isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
)
def _output_node_at_index(node, index):
for user in node.users:
assert user.target is operator.getitem, f"{user} is not a getitem node"
if index == user.args[1]:
return user
return None
@dataclass
class GraphState:
    """Accumulates the serialized pieces of one graph.

    GraphModuleSerializer fills this in while walking a graph.
    GraphModuleSerializer.save_graph_state temporarily swaps in a fresh
    instance, so a nested subgraph is collected separately.
    """

    # Serialized graph inputs, in the order handle_placeholder appends them.
    inputs: List[Argument] = field(default_factory=list)
    # Serialized graph outputs (assigned by handle_output).
    outputs: List[Argument] = field(default_factory=list)
    # Serialized call_function nodes, in the order they are handled.
    nodes: List[Node] = field(default_factory=list)
    # TensorMeta for each tensor value, keyed by value name.
    tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
    # SymInt for each symbolic-int value, keyed by name. Membership here is
    # also how is_sym_int_arg recognizes SymInt-valued fx Nodes.
    sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
    # SymBool for each symbolic-bool value, keyed by name. Membership here is
    # also how is_sym_bool_arg recognizes SymBool-valued fx Nodes.
    sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
    # True when the output node returns a bare fx Node rather than a tuple/list.
    is_single_tensor_return: bool = False
    # Custom-object metadata for custom-object placeholders, keyed by node name.
    custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
class Final(type):
    """Metaclass whose classes cannot be subclassed.

    Defining a class with a ``Final``-created base raises ``TypeError``,
    naming the first such base.
    """

    def __new__(metacls, name, bases, classdict):
        offender = next((base for base in bases if isinstance(base, Final)), None)
        if offender is not None:
            raise TypeError(
                f"type '{offender.__name__}' is not an acceptable base type"
            )
        return type.__new__(metacls, name, bases, dict(classdict))
class GraphModuleSerializer:
def __init__(
    self,
    graph_signature: ep.ExportGraphSignature,
    module_call_graph: List[ep.ModuleCallEntry],
):
    """Create a serializer for a single graph module.

    Args:
        graph_signature: export graph signature of the program (stored on the
            instance for later use).
        module_call_graph: module call entries of the exported program (stored
            on the instance for later use).
    """
    self.graph_signature = graph_signature
    self.module_call_graph = module_call_graph
    # Per-graph accumulator; save_graph_state() swaps it out for subgraphs.
    self.graph_state = GraphState()
    # Custom ScriptObjects seen while serializing arguments, keyed by the
    # generated "_custom_obj_<n>" name (see serialize_input).
    self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
@contextmanager
def save_graph_state(self):
    """Run the ``with`` body against a fresh, empty ``GraphState``.

    The previous state is restored on exit, even if the body raises. This lets
    a nested subgraph be serialized without disturbing the enclosing graph.
    """
    outer_state, self.graph_state = self.graph_state, GraphState()
    try:
        yield
    finally:
        self.graph_state = outer_state
def handle_placeholder(self, node: torch.fx.Node):
    """Serialize a ``placeholder`` node and append it to ``graph_state.inputs``.

    The kind of input depends on ``node.meta["val"]``:

    * ``torch.Tensor``: tensor argument; its TensorMeta is recorded.
    * ``torch.SymInt``: SymInt argument; its SymInt is recorded.
    * ``int``/``bool``/``str``/``float``/``None``: serialized inline via
      ``serialize_input``.
    * ``ep.CustomObjArgument``: custom-object argument; its metadata is recorded.

    Raises:
        AssertionError: if the node is not a placeholder or the value type is
            none of the above.
    """
    assert node.op == "placeholder"
    val = node.meta["val"]
    name = node.name
    state = self.graph_state

    if isinstance(val, torch.Tensor):
        graph_input = Argument.create(as_tensor=TensorArgument(name=name))
        state.tensor_values[name] = serialize_tensor_meta(val)
    elif isinstance(val, torch.SymInt):
        graph_input = Argument.create(
            as_sym_int=SymIntArgument.create(as_name=name)
        )
        state.sym_int_values[name] = serialize_sym_int(val)
    elif isinstance(val, (int, bool, str, float, type(None))):
        graph_input = self.serialize_input(val)
    elif isinstance(val, ep.CustomObjArgument):
        graph_input = Argument.create(
            as_custom_obj=CustomObjArgument(name=name, class_fqn=val.class_fqn)
        )
        state.custom_obj_values[name] = self.serialize_script_obj_meta(val)
    else:
        raise AssertionError(f"Unimplemented graph input type: {val}")

    state.inputs.append(graph_input)
def handle_output(self, node: torch.fx.Node):
    """Serialize the graph's ``output`` node into ``graph_state.outputs``.

    A bare fx Node result marks the graph as a single-tensor return. Otherwise
    the result must be a tuple/list, and each element is serialized in order.
    """
    assert node.op == "output"
    assert len(node.args) == 1, "FX.Node's args should have one arg"
    result = node.args[0]

    if not isinstance(result, torch.fx.Node):
        assert isinstance(result, (tuple, list))
        self.graph_state.outputs = [self.serialize_input(item) for item in result]
        return

    # Singleton tensor return.
    self.graph_state.is_single_tensor_return = True
    self.graph_state.outputs = [self.serialize_input(result)]
def serialize_operator(self, target) -> str:
    """Return the string identifier used for ``target`` in the serialized graph.

    Strings are returned unchanged. Any other target is named
    ``"<module>.<name>"``. A module path starting with ``torch._ops`` is
    rewritten to the public ``torch.ops`` spelling.
    """
    if isinstance(target, str):
        return target
    module = target.__module__
    if module.startswith("torch._ops"):
        # TODO(zhxchen17) Maybe provide a function name helper in FX.
        # From torch.fx.node._get_qualified_name
        module = module.replace("torch._ops", "torch.ops")
    # TODO(zhxchen17) Don't catch all here.
    return f"{module}.{target.__name__}"
def handle_call_function(self, node: torch.fx.Node):
    """Serialize a ``call_function`` node and append it to ``graph_state.nodes``.

    Supported targets:
      * ops in ``_SYM_INT_OPS``: positional inputs, one SymInt output;
      * ops in ``_SYM_BOOL_OPS``: positional inputs, one SymBool output;
      * ``torch._ops.OpOverload``: inputs matched against the op schema;
      * ``torch._ops.HigherOrderOperator``: schema-less inputs/outputs.
    ``operator.getitem`` nodes are skipped and emit nothing here (per the inline
    note, they are handled when their producer node is serialized).

    Raises:
        SerializeError: for any other target.
    """
    assert node.op == "call_function"

    # getitem has been handled in the producer node, skip it here
    if node.target is operator.getitem:
        return

    # NOTE: the Node(...) keyword arguments below are evaluated in source order
    # (target, inputs, outputs, metadata). The serialize_* helpers record values
    # in self.graph_state, so keep that order.
    if node.target in _SYM_INT_OPS:
        # Symbolic int ops are expected to carry no kwargs.
        assert len(node.kwargs) == 0
        meta_val = node.meta["val"]
        ex_node = Node(
            target=self.serialize_operator(node.target),
            inputs=self.serialize_sym_op_inputs(node.target, node.args),
            outputs=[
                Argument.create(
                    as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
                )
            ],
            metadata=self.serialize_metadata(node),
        )
    elif node.target in _SYM_BOOL_OPS:
        # Symbolic bool ops are expected to carry no kwargs.
        assert len(node.kwargs) == 0
        meta_val = node.meta["val"]
        ex_node = Node(
            target=self.serialize_operator(node.target),
            inputs=self.serialize_sym_op_inputs(node.target, node.args),
            outputs=[
                Argument.create(
                    as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
                )
            ],
            metadata=self.serialize_metadata(node),
        )
    elif isinstance(node.target, torch._ops.OpOverload):
        ex_node = Node(
            target=self.serialize_operator(node.target),
            inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
            outputs=self.serialize_outputs(node),
            # TODO: create a new tensor_values here, meta might have faketensor info
            metadata=self.serialize_metadata(node),
        )
    elif isinstance(node.target, torch._ops.HigherOrderOperator):
        # HOOs have no schema, so inputs are serialized positionally/by kwarg name.
        ex_node = Node(
            target=self.serialize_operator(node.target),
            inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
            outputs=self.serialize_hoo_outputs(node),
            metadata=self.serialize_metadata(node),
        )
    else:
        raise SerializeError(f"Serializing {node.target} is not supported")

    self.graph_state.nodes.append(ex_node)
def handle_get_attr(self, node):
    """Do nothing: ``get_attr`` nodes are not emitted as serialized graph nodes.

    Submodule attributes are instead serialized where they are consumed (see
    the ``get_attr`` branch of ``serialize_input``).
    """
    return None
def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
    """Flatten selected ``node.meta`` entries into a ``Dict[str, str]``.

    Only truthy entries are emitted:
      * ``stack_trace``: copied verbatim.
      * ``nn_module_stack``: each ``key -> (path, type)`` entry becomes
        ``"key,path,type_str"``; entries are joined with ``ST_DELIMITER``.
      * ``source_fn_stack``: each ``(name, fn)`` pair becomes
        ``"name,<serialized fn>"``; pairs are joined with ``ST_DELIMITER``.
      * ``torch_fn``: its items joined with ``ST_DELIMITER``.
    """

    def _encode_module_entry(entry):
        # node.meta["nn_module_stack"] values come in two forms:
        # 1. (path: str, module_type: 'type'), e.g.
        #    ('', <class 'sigmoid.inference.MySimpleModel'>)
        # 2. (path: str, module_type: str), e.g.
        #    ('', 'sigmoid.inference.MySimpleModel')
        # ExportedProgram directly produced by torch.export() has form 1;
        # ExportedProgram deserialized from disk has form 2.
        # TODO: This is not ideal, we should fix this.
        assert isinstance(entry, tuple) and len(entry) == 2
        path, ty = entry
        assert isinstance(path, str)
        if isinstance(ty, str):
            type_str = ty
        else:
            type_str = ty.__module__ + "." + ty.__qualname__
        return path + "," + type_str

    meta = node.meta
    ret: Dict[str, str] = {}

    stack_trace = meta.get("stack_trace")
    if stack_trace:
        ret["stack_trace"] = stack_trace

    nn_module_stack = meta.get("nn_module_stack")
    if nn_module_stack:
        # Serialize to "key,orig_path,type_str"
        ret["nn_module_stack"] = ST_DELIMITER.join(
            f"{key},{_encode_module_entry(entry)}"
            for key, entry in nn_module_stack.items()
        )

    source_fn_stack = meta.get("source_fn_stack")
    if source_fn_stack:
        ret["source_fn_stack"] = ST_DELIMITER.join(
            f"{fn_entry[0]},{self.serialize_operator(fn_entry[1])}"
            for fn_entry in source_fn_stack
        )

    torch_fn = meta.get("torch_fn")
    if torch_fn:
        ret["torch_fn"] = ST_DELIMITER.join(torch_fn)

    return ret
def serialize_script_obj_meta(
    self, script_obj_meta: ep.CustomObjArgument
) -> CustomObjArgument:
    """Convert an ``ep.CustomObjArgument`` into the schema ``CustomObjArgument``.

    Only the name and the fully-qualified class name are carried over.
    """
    obj_name = script_obj_meta.name
    obj_class_fqn = script_obj_meta.class_fqn
    return CustomObjArgument(name=obj_name, class_fqn=obj_class_fqn)
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
    """Name the positional ``args`` of a symbolic op after its Python signature.

    Parameter names come from ``inspect.signature(op)``. Pairing stops at
    whichever of the names or ``args`` runs out first.
    """
    param_names = inspect.signature(op).parameters.keys()
    return [
        NamedArgument(name=param_name, arg=self.serialize_input(value))
        for param_name, value in zip(param_names, args)
    ]
def serialize_inputs(
    self, target: torch._ops.OpOverload, args, kwargs=None
) -> List[NamedArgument]:
    """Serialize the arguments of an ``OpOverload`` call, in schema order.

    Each schema argument is taken from ``kwargs`` if present there. Otherwise,
    for non-kwarg-only arguments, it is taken from the matching positional slot
    in ``args``. Arguments supplied neither way are intentionally omitted.
    """
    assert isinstance(target, torch._ops.OpOverload)
    kwargs = kwargs or {}
    serialized_args = []
    for position, schema_arg in enumerate(target._schema.arguments):
        arg_name = schema_arg.name
        if arg_name in kwargs:
            value = kwargs[arg_name]
        elif not schema_arg.kwarg_only and position < len(args):
            value = args[position]
        else:
            # We intentionally don't serialize the missing arguments
            # with default values
            continue
        serialized_args.append(
            NamedArgument(
                name=arg_name,
                arg=self.serialize_input(value, schema_arg.type),
            )
        )
    return serialized_args
def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]:
    """
    Serialize the inputs of a higher-order operator (HOO).

    HOOs have no schema. Positional args get an empty name, and kwargs keep
    their keyword names. Positional args come first, then kwargs in dict order.
    """
    positional = [NamedArgument(name="", arg=self.serialize_input(a)) for a in args]
    keyword = [
        NamedArgument(name=kw, arg=self.serialize_input(a))
        for kw, a in kwargs.items()
    ]
    return positional + keyword
def is_sym_int_arg(self, arg) -> bool:
    """Return True for Python ints (bools included) and for fx Nodes recorded as SymInt values."""
    if isinstance(arg, int):
        return True
    if not isinstance(arg, torch.fx.Node):
        return False
    return arg.name in self.graph_state.sym_int_values
def is_sym_bool_arg(self, arg) -> bool:
    """Return True for Python bools and for fx Nodes recorded as SymBool values."""
    if isinstance(arg, bool):
        return True
    if not isinstance(arg, torch.fx.Node):
        return False
    return arg.name in self.graph_state.sym_bool_values
def serialize_input(
    self, arg, arg_type: Optional[torch._C.Argument] = None
) -> Argument:
    """Serialize a single call argument into a schema ``Argument``.

    ``arg`` may be:
      * an fx Node (get_attr submodule, SymInt/SymBool value, custom object,
        or tensor);
      * an inductor buffer or ``torch.SymInt`` (inductor ExternalFallbackNode path);
      * a Python bool/str/int/float/None;
      * a list/tuple of any of these;
      * a torch dtype/device/memory_format/layout;
      * a custom ``ScriptObject``;
      * an ``OpOverload``.

    ``arg_type`` is the schema type of the argument, when known. It is only
    consulted to pick the element type of an empty list.
    NOTE(review): callers pass ``schema_arg.type`` (a JIT type), so the
    ``torch._C.Argument`` annotation looks inaccurate — confirm.

    Raises:
        SerializeError: for unsupported argument types, for get_attr nodes
            holding tensors or unsupported attributes, for empty lists of an
            unhandled element type, and for custom classes without
            ``__getstate__``/``__setstate__``.
    """
    import torch._inductor.ir as inductor_ir

    # Inductor IR objects that stand in for tensor arguments.
    inductor_tensor_buffers = (
        inductor_ir.Buffer,
        inductor_ir.ReinterpretView,
    )

    if isinstance(arg, torch.fx.Node):
        if arg.op == "get_attr":
            assert isinstance(arg.target, str)
            attr = getattr(arg.graph.owning_module, arg.target)

            if isinstance(attr, torch.Tensor):
                raise SerializeError(
                    "getattr nodes containing tensors should not appear in the graph"
                )
            elif isinstance(attr, torch.fx.GraphModule):
                # Serialize the subgraph into its own, fresh GraphState so it
                # does not mix with the enclosing graph's state.
                with self.save_graph_state():
                    graph = self.serialize_graph(attr)
                return Argument.create(
                    as_graph=GraphArgument(name=arg.target, graph=graph)
                )
            else:
                raise SerializeError(
                    f"Unsupported getattr attribute {arg.target} with type: {type(attr)}"
                )
        elif self.is_sym_int_arg(arg):
            return Argument.create(
                as_sym_int=SymIntArgument.create(as_name=arg.name)
            )
        elif self.is_sym_bool_arg(arg):
            return Argument.create(
                as_sym_bool=SymBoolArgument.create(as_name=arg.name)
            )
        else:
            if isinstance(arg.meta["val"], ep.CustomObjArgument):
                return Argument.create(
                    as_custom_obj=CustomObjArgument(
                        name=arg.name, class_fqn=arg.meta["val"].class_fqn
                    )
                )
            # Any other node is treated as a tensor.
            return Argument.create(as_tensor=TensorArgument(name=arg.name))
    elif isinstance(arg, inductor_tensor_buffers):
        # Other branches are for arguments in fx node.
        # This is a special branch for handling buffers (representing tensor arguments)
        # for inductor's ExternalFallbackNode
        # export_extern_kernel_node() is using this function to serialize arguments
        arg_name = arg.get_name()
        assert arg_name is not None, "Buffer must have valid name"
        return Argument.create(as_tensor=TensorArgument(name=arg_name))
    elif isinstance(arg, torch.SymInt):
        # This is a special branch for handling SymInt args in inductor's
        # ExternalFallbackNode.
        # For regular FX graph, SymInt arg should be a fx.Node with
        # self.is_sym_int_arg(arg) being true
        return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg)))
    # bool is checked before int because bool is a subclass of int.
    elif isinstance(arg, bool):
        return Argument.create(as_bool=arg)
    elif isinstance(arg, str):
        return Argument.create(as_string=arg)
    elif isinstance(arg, int):
        return Argument.create(as_int=arg)
    elif isinstance(arg, float):
        return Argument.create(as_float=arg)
    elif arg is None:
        return Argument.create(as_none=())
    elif isinstance(arg, (list, tuple)):
        if len(arg) == 0:
            # An empty list carries no element type, so derive it from the
            # schema type when one was given.
            if arg_type is not None:
                if isinstance(arg_type, torch.OptionalType):
                    arg_type = arg_type.getElementType()  # type: ignore[assignment]
                assert isinstance(arg_type, torch.ListType)
                elem_type = arg_type.getElementType()
                if isinstance(elem_type, torch.OptionalType):
                    elem_type = elem_type.getElementType()

                if isinstance(elem_type, torch.BoolType):
                    return Argument.create(as_bools=[])
                elif isinstance(elem_type, torch.IntType):
                    return Argument.create(as_ints=[])
                elif isinstance(elem_type, torch.FloatType):
                    return Argument.create(as_floats=[])
                elif isinstance(elem_type, torch.StringType):
                    return Argument.create(as_strings=[])
                elif isinstance(elem_type, torch.TensorType):
                    return Argument.create(as_tensors=[])
                else:
                    # I believe empty symint lists default to ints, but
                    # please file an issue if this is not the case
                    raise SerializeError(f"Empty list with type {elem_type} nyi.")
            else:
                # We could serialize this by default to a tensor list. This
                # is needed in the HOO case
                log.warning(
                    "Unsure how to serialize the given empty list, "
                    "as we don't know what is the type of this argument. "
                    "Serializing it as a tensor list by default."
                )
                return Argument.create(as_tensors=[])

        # Must check bool first, as bool is also treated as int
        if all(isinstance(a, bool) for a in arg):
            return Argument.create(as_bools=list(arg))
        elif all(isinstance(a, int) for a in arg):
            return Argument.create(as_ints=list(arg))
        elif all(isinstance(a, float) for a in arg):
            return Argument.create(as_floats=list(arg))
        elif all(isinstance(a, str) for a in arg):
            return Argument.create(as_strings=list(arg))
        elif all(isinstance(a, torch.SymInt) for a in arg):
            # This is a special branch for handling SymInt args in inductor's
            # ExternalFallbackNode.
            # For regular FX graph, SymInt arg should be a fx.Node with
            # self.is_sym_int_arg(arg) being true
            return Argument.create(
                as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg]
            )
        elif all(self.is_sym_int_arg(a) for a in arg):
            # list of sym_ints (mix of SymInt-valued nodes and plain ints)
            values = []
            for a in arg:
                if isinstance(a, torch.fx.Node):
                    values.append(SymIntArgument.create(as_name=a.name))
                elif isinstance(a, int):
                    values.append(SymIntArgument.create(as_int=a))
            return Argument.create(as_sym_ints=values)
        elif all(self.is_sym_bool_arg(a) for a in arg):
            # list of sym_bools (mix of SymBool-valued nodes and plain bools)
            values = []
            for a in arg:
                if isinstance(a, torch.fx.Node):
                    values.append(SymBoolArgument.create(as_name=a.name))
                elif isinstance(a, bool):
                    values.append(SymBoolArgument.create(as_bool=a))
            return Argument.create(as_sym_bools=values)
        elif all(isinstance(a, torch.fx.Node) for a in arg):
            # list of tensors
            arguments = []
            for a in arg:
                if a.op == "get_attr":
                    raise SerializeError(
                        "getattr nodes containing tensors should not appear in the graph"
                    )
                arguments.append(TensorArgument(name=a.name))
            return Argument.create(as_tensors=arguments)
        elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg):
            # list of optional tensors
            def serialize_optional_tensor_args(a):
                if a is None:
                    return OptionalTensorArgument.create(as_none=())
                elif isinstance(a, torch.fx.Node):
                    return OptionalTensorArgument.create(
                        as_tensor=TensorArgument(name=a.name)
                    )
                else:
                    raise SerializeError(f"Unsupported list/tuple argument: {a}")

            return Argument.create(
                as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
            )
        elif all(isinstance(a, inductor_tensor_buffers) for a in arg):
            # list of inductor buffers
            return Argument.create(
                as_tensors=[TensorArgument(name=a.get_name()) for a in arg],
            )
        elif all(
            isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg
        ):
            # list of inductor buffers as optional tensors
            def serialize_optional_tensor_args(a):
                if a is None:
                    return OptionalTensorArgument.create(as_none=())
                elif isinstance(a, inductor_tensor_buffers):
                    return OptionalTensorArgument.create(
                        as_tensor=TensorArgument(name=a.get_name())
                    )
                else:
                    raise SerializeError(f"Unsupported list/tuple argument: {a}")

            return Argument.create(
                as_optional_tensors=list(map(serialize_optional_tensor_args, arg))
            )
        else:
            raise SerializeError(
                f"Unsupported list/tuple argument type: {[type(a) for a in arg]}"
            )
    elif isinstance(arg, torch.dtype):
        return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg])
    elif isinstance(arg, torch.device):
        return Argument.create(as_device=Device(type=arg.type, index=arg.index))
    elif isinstance(arg, torch.memory_format):
        return Argument.create(
            as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]
        )
    elif isinstance(arg, torch.layout):
        return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg])
    elif isinstance(arg, torch._C.ScriptObject):
        if not (
            arg._has_method("__getstate__")  # type: ignore[attr-defined]
            and arg._has_method("__setstate__")  # type: ignore[attr-defined]
        ):
            raise SerializeError(
                f"Unable to serialize custom class {arg}. Please define "
                "serialization methods via def_pickle()."
            )
        # Custom objects through TorchBind are serializable with pickle,
        # through implementing the .def_pickle function. This should result
        # in the object containing a __getstate__ and __setstate__
        # serialize/deserialize function.
        # The object itself is stashed in self.custom_objs under a generated
        # name; only that name and the class FQN go into the graph.
        custom_obj_name = f"_custom_obj_{len(self.custom_objs)}"
        self.custom_objs[custom_obj_name] = arg
        class_fqn = arg._type().qualified_name()  # type: ignore[attr-defined]
        return Argument.create(
            as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)
        )
    elif isinstance(arg, torch._ops.OpOverload):
        return Argument.create(as_operator=self.serialize_operator(arg))
    else:
        raise SerializeError(f"Unsupported argument type: {type(arg)}")
def serialize_tensor_output(self, name, meta_val) -> TensorArgument:
    """Record ``meta_val``'s TensorMeta under ``name`` and return a TensorArgument for it.

    Raises:
        AssertionError: if a tensor value named ``name`` was already recorded.
    """
    tensor_values = self.graph_state.tensor_values
    assert name not in tensor_values
    tensor_values[name] = serialize_tensor_meta(meta_val)
    return TensorArgument(name=name)
def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument:
    """Record ``meta_val`` as the SymInt named ``name`` and return a reference to it.

    Raises:
        AssertionError: if a SymInt value named ``name`` was already recorded.
    """
    sym_ints = self.graph_state.sym_int_values
    assert name not in sym_ints
    sym_ints[name] = serialize_sym_int(meta_val)
    return SymIntArgument.create(as_name=name)
def serialize_sym_bool_output(self, name, meta_val) -> SymBoolArgument:
    """Record ``meta_val`` as the SymBool named ``name`` and return a reference to it.

    The return annotation was previously ``SymIntArgument``, but this method has
    always returned a ``SymBoolArgument``.

    Raises:
        AssertionError: if a SymBool value named ``name`` was already recorded.
    """
    assert name not in self.graph_state.sym_bool_values
    self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val)
    return SymBoolArgument.create(as_name=name)
def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
    """Convert an ``ep.InputSpec`` into the serialized schema ``InputSpec``.

    USER_INPUT specs whose argument is an ``ep.ConstantArgument`` become a
    ``ConstantInputSpec``. The constant must be a bool, int, str, float or None.
    Other USER_INPUTs become a ``UserInputSpec``. PARAMETER, BUFFER,
    CONSTANT_TENSOR, CUSTOM_OBJ and TOKEN specs map to their dedicated schema
    specs.

    Raises:
        SerializeError: for a USER_INPUT constant of an unsupported type.
        AssertionError: for an unknown input kind, or when a spec lacks the
            target / argument type its kind requires.
    """
    if spec.kind == ep.InputKind.USER_INPUT:
        if isinstance(spec.arg, ep.ConstantArgument):
            value = spec.arg.value
            # bool must be tested before int. bool is a subclass of int, so an
            # int check first would serialize True/False as as_int and make the
            # as_bool branch unreachable.
            if isinstance(value, bool):
                constant_spec = ConstantValue.create(as_bool=value)
            elif isinstance(value, int):
                constant_spec = ConstantValue.create(as_int=value)
            elif isinstance(value, str):
                constant_spec = ConstantValue.create(as_string=value)
            elif isinstance(value, float):
                constant_spec = ConstantValue.create(as_float=value)
            elif value is None:
                constant_spec = ConstantValue.create(as_none=())
            else:
                raise SerializeError(
                    f"Unhandled constant input {spec.arg.value} to serialize"
                )
            return InputSpec.create(
                constant_input=ConstantInputSpec(
                    name=spec.arg.name, value=constant_spec
                )
            )
        else:
            return InputSpec.create(
                user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg))
            )
    elif spec.kind == ep.InputKind.PARAMETER:
        assert spec.target is not None
        assert isinstance(spec.arg, ep.TensorArgument)
        return InputSpec.create(
            parameter=InputToParameterSpec(
                arg=TensorArgument(name=spec.arg.name),
                parameter_name=spec.target,
            )
        )
    elif spec.kind == ep.InputKind.BUFFER:
        assert spec.target is not None
        assert isinstance(spec.arg, ep.TensorArgument)
        assert spec.persistent is not None
        return InputSpec.create(
            buffer=InputToBufferSpec(
                arg=TensorArgument(name=spec.arg.name),
                buffer_name=spec.target,
                persistent=spec.persistent,
            )
        )
    elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
        assert spec.target is not None
        assert isinstance(spec.arg, ep.TensorArgument)
        return InputSpec.create(
            tensor_constant=InputToTensorConstantSpec(
                arg=TensorArgument(name=spec.arg.name),
                tensor_constant_name=spec.target,
            )
        )
    elif spec.kind == ep.InputKind.CUSTOM_OBJ:
        assert spec.target is not None
        assert isinstance(spec.arg, ep.CustomObjArgument)
        return InputSpec.create(
            custom_obj=InputToCustomObjSpec(
                arg=CustomObjArgument(
                    name=spec.arg.name, class_fqn=spec.arg.class_fqn
                ),
                custom_obj_name=spec.target,
            )
        )
    elif spec.kind == ep.InputKind.TOKEN:
        assert isinstance(spec.arg, ep.TokenArgument)
        return InputSpec.create(
            token=InputTokenSpec(
                arg=TokenArgument(name=spec.arg.name),
            )
        )
    else:
        raise AssertionError(f"Unknown argument kind: {spec}")
def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
if spec.kind == ep.OutputKind.USER_OUTPUT:
return OutputSpec.create(
user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg))
)
elif spec.kind == ep.OutputKind.LOSS_OUTPUT:
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name))
)