-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathsimplenet_infer_python.py
80 lines (62 loc) · 2.36 KB
/
simplenet_infer_python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Load saved SimpleNet to TorchScript and run inference example."""
import os
import torch
def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
    """
    Load TorchScript SimpleNet and run inference with example Tensor.

    The model is explicitly mapped onto the requested device while it is
    loaded, so the call does not depend on which device the TorchScript file
    was originally saved from.

    Parameters
    ----------
    saved_model : str
        location of SimpleNet model saved to Torchscript
    device : str
        Torch device to run model on, 'cpu' or 'cuda'
    batch_size : int
        batch size to run (default 1)

    Returns
    -------
    output : torch.Tensor
        result of running inference on model with example Tensor input,
        always returned on the CPU

    Raises
    ------
    ValueError
        if ``device`` is neither 'cpu' nor 'cuda'
    """
    # Reject unsupported devices before doing any work.
    if device not in ("cpu", "cuda"):
        device_error = f"Device '{device}' not recognised."
        raise ValueError(device_error)

    target_device = torch.device(device)

    # Example input: the row [0, 1, 2, 3, 4] repeated once per batch entry.
    input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]).repeat(batch_size, 1)

    # Without map_location a saved module is restored onto the device it was
    # saved from, which would clash with the input tensor whenever that device
    # differs from the requested one. Remapping at load time avoids this.
    model = torch.jit.load(saved_model, map_location=target_device)

    # Inference on the requested device.
    output = model(input_tensor.to(target_device))

    # Hand the result back on the CPU (no-op when it is already there).
    return output.to(torch.device("cpu"))
if __name__ == "__main__":
    import argparse

    # Command-line interface. The --filepath value is used as the directory
    # in which the saved model file is looked up (it is joined with the file
    # name below); it defaults to the directory containing this script.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--filepath",
        help="Path to the file containing the PyTorch model",
        type=str,
        default=os.path.dirname(__file__),
    )
    parsed_args = parser.parse_args()
    filepath = parsed_args.filepath
    saved_model_file = os.path.join(filepath, "saved_simplenet_model_cpu.pt")

    device_to_run = "cpu"
    batch_size_to_run = 1

    # Gradients are not needed for inference.
    with torch.no_grad():
        result = deploy(saved_model_file, device_to_run, batch_size_to_run)

    print(result)

    # Reference values the inference result is checked against; built once
    # and reused for both the comparison and the error message.
    expected = torch.tensor([0.0, 2.0, 4.0, 6.0, 8.0])
    if not torch.allclose(result, expected):
        result_error = (
            f"result:\n{result}\ndoes not match expected value:\n"
            f"{expected}"
        )
        raise ValueError(result_error)