-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathpt2ts.py
184 lines (154 loc) · 6.22 KB
/
pt2ts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Load a PyTorch model and convert it to TorchScript."""
import os
import sys
from typing import Optional
# FPTLIB-TODO
# Add a module import with your model here:
# This example assumes the model architecture is in an adjacent module `my_ml_model.py`
import resnet18
import torch
def script_to_torchscript(
    model: torch.nn.Module, filename: str = "scripted_model.pt"
) -> None:
    """
    Save a PyTorch model to TorchScript using scripting.

    Scripting compiles the model's Python source (including data-dependent
    control flow), so it is preferred over tracing when the model is
    scriptable.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model
    filename : str
        name of file to save to
    """
    print("Saving model using scripting...", end="")
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    scripted_model = torch.jit.script(model)
    scripted_model.save(filename)
    print("done.")
def trace_to_torchscript(
    model: torch.nn.Module,
    dummy_input: torch.Tensor,
    filename: str = "traced_model.pt",
) -> None:
    """
    Save a PyTorch model to TorchScript using tracing.

    Tracing records the operations executed for `dummy_input`, so the saved
    model only reflects the code paths taken for that input.

    Parameters
    ----------
    model : torch.nn.Module
        a PyTorch model; must be in eval mode, since `torch.jit.freeze`
        below requires an inference-mode module
    dummy_input : torch.Tensor
        appropriate size Tensor to act as input to model
    filename : str
        name of file to save to
    """
    print("Saving model using tracing...", end="")
    # FIXME: torch.jit.optimize_for_inference() when PyTorch issue #81085 is resolved
    traced_model = torch.jit.trace(model, dummy_input)
    # Freezing inlines parameters/attributes into the graph for faster,
    # self-contained inference deployment.
    frozen_model = torch.jit.freeze(traced_model)
    frozen_model.save(filename)
    print("done.")
def load_torchscript(filename: str = "saved_model.pt") -> torch.nn.Module:
    """
    Load a TorchScript model from file.

    Parameters
    ----------
    filename : str
        name of file containing TorchScript model

    Returns
    -------
    torch.nn.Module
        the deserialized TorchScript module, ready for inference
    """
    model = torch.jit.load(filename)
    return model
if __name__ == "__main__":
    # =====================================================
    # Load model and prepare for saving
    # =====================================================
    # Precision used to initialize the model below.
    precision = torch.float32

    # FPTLIB-TODO
    # Add a module import with your model here:
    # Load a pre-trained PyTorch model.
    # Insert code here to load your model as `trained_model`.
    # This example assumes my_ml_model has a method `initialize` to load
    # architecture, weights, and place in inference mode.
    trained_model = resnet18.initialize(precision)

    # Switch off specific layers/parts of the model that behave
    # differently during training and inference.
    # This may have been done by the user already, so just make sure here.
    trained_model.eval()

    # =====================================================
    # Prepare dummy input and check model runs
    # =====================================================

    # FPTLIB-TODO
    # Generate a dummy input Tensor `dummy_input` to the model of appropriate size.
    # This example assumes one input of size (1x3x224x224) -- one RGB (3) image
    # of resolution 224x224 in a batch size of 1.
    trained_model_dummy_input = torch.ones(1, 3, 224, 224)

    # FPTLIB-TODO
    # Uncomment the following lines to save for inference on GPU (rather than CPU):
    # device = torch.device('cuda')
    # trained_model = trained_model.to(device)
    # trained_model.eval()
    # trained_model_dummy_input = trained_model_dummy_input.to(device)

    # FPTLIB-TODO
    # Run the model on the dummy inputs as a smoke test; any problem with the
    # model or input shape will raise an error here, before anything is saved.
    trained_model_dummy_outputs = trained_model(
        trained_model_dummy_input,
    )

    # =====================================================
    # Save model
    # =====================================================

    # FPTLIB-TODO
    # Set the name of the file you want to save the torchscript model to:
    saved_ts_filename = "saved_resnet18_model_cpu.pt"
    # A filepath may also be provided. To do this, pass the filepath as an argument to
    # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
    # NOTE(review): the save below writes `saved_ts_filename` relative to the
    # current working directory, while the existence check at the bottom looks
    # in `dirname(__file__)` or `sys.argv[1]` -- these only agree when the
    # script is run from its own directory; verify against intended usage.

    # FPTLIB-TODO
    # Save the PyTorch model using either scripting (recommended if possible) or tracing.
    # -----------
    # Scripting
    # -----------
    script_to_torchscript(trained_model, filename=saved_ts_filename)

    # -----------
    # Tracing
    # -----------
    # trace_to_torchscript(
    #     trained_model, trained_model_dummy_input, filename=saved_ts_filename
    # )

    print(f"Saved model to TorchScript in '{saved_ts_filename}'.")

    # =====================================================
    # Check model saved OK
    # =====================================================

    # Load torchscript and run model as a test, using a scaled (different)
    # input so the comparison is not trivially the cached dummy output.
    # FPTLIB-TODO
    # Scale inputs as above and, if required, move inputs and model to GPU.
    trained_model_dummy_input = 2.0 * trained_model_dummy_input
    trained_model_testing_outputs = trained_model(
        trained_model_dummy_input,
    )
    ts_model = load_torchscript(filename=saved_ts_filename)
    ts_model_outputs = ts_model(
        trained_model_dummy_input,
    )

    # Normalize both results to tuples so single-output models compare
    # uniformly with multi-output ones.
    if not isinstance(ts_model_outputs, tuple):
        ts_model_outputs = (ts_model_outputs,)
    if not isinstance(trained_model_testing_outputs, tuple):
        trained_model_testing_outputs = (trained_model_testing_outputs,)
    # Require exact elementwise equality between the original and the
    # reloaded TorchScript model on the test input.
    for ts_output, output in zip(ts_model_outputs, trained_model_testing_outputs):
        if torch.all(ts_output.eq(output)):
            print("Saved TorchScript model working as expected in a basic test.")
            print("Users should perform further validation as appropriate.")
        else:
            model_error = (
                "Saved Torchscript model is not performing as expected.\n"
                "Consider using scripting if you used tracing, or investigate further."
            )
            raise RuntimeError(model_error)

    # Check that the model file was created on disk (see NOTE above about
    # which directory is checked).
    filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
    if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
        torchscript_file_error = (
            f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
            "cannot be found."
        )
        raise FileNotFoundError(torchscript_file_error)