forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_deploy.py
75 lines (62 loc) · 2.74 KB
/
_deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import io
import torch
import importlib
from torch.package._custom_import_pickler import create_custom_import_pickler
from torch.package.package_importer import _UnpicklerWrapper
from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
from torch.serialization import _maybe_decode_ascii
from typing import Callable
from types import ModuleType
def _save_storages(importer, obj):
    """Pickle ``obj`` for torch::deploy, keeping storages out of the byte stream.

    Every storage met while pickling is appended to a side list. In the
    stream it is replaced by the persistent id ``('storage', index)``, where
    ``index`` is its position in that list. ``_load_storages`` resolves those
    ids against the list again.

    Args:
        importer: The importer ``obj`` came from. Only a ``PackageImporter``
            is honored: it is consulted before ``sys_importer`` when the
            pickler resolves modules. Any other value is ignored and plain
            ``sys_importer`` is used.
        obj: The object to pickle.

    Returns:
        A 4-tuple ``(data, storages, dtypes, zip_reader)``:
        - ``data``: the pickled bytes.
        - ``storages``: the storages referenced from the stream, in
          persistent-id order.
        - ``dtypes``: the ``dtype`` of each of those storages.
        - ``zip_reader``: the package importer's ``zip_reader``, or ``None``
          when no ``PackageImporter`` was supplied.
    """
    serialized_storages = []
    serialized_dtypes = []

    def persistent_id(candidate):
        # Persistent ids must be strings only for pickle protocol 0; the
        # binary protocols accept arbitrary objects, so a tuple is fine here.
        # https://docs.python.org/3/library/pickle.html#persistence-of-external-objects
        if torch.is_storage(candidate):
            serialized_storages.append(candidate)
            serialized_dtypes.append(candidate.dtype)
            return ('storage', len(serialized_storages) - 1)
        # None tells the pickler to serialize the object normally.
        return None

    # Only a PackageImporter carries a zip_reader; anything else is treated
    # as "no package".
    package_importer = importer if isinstance(importer, PackageImporter) else None
    importers: Importer
    if package_importer is not None:
        # Look names up in the package first, then fall back to sys_importer.
        importers = OrderedImporter(package_importer, sys_importer)
    else:
        importers = sys_importer

    # Write the pickle data for `obj`.
    data_buf = io.BytesIO()
    pickler = create_custom_import_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()

    zip_reader = package_importer.zip_reader if package_importer is not None else None
    return data_value, serialized_storages, serialized_dtypes, zip_reader
def _load_storages(id, zip_reader, obj_bytes, serialized_storages):
    """Unpickle an object produced by ``_save_storages`` and register it.

    Args:
        id: Key under which the unpickled object is stored in
            ``_deploy_objects``. It shadows the builtin ``id``; the name is
            kept for compatibility with existing callers.
        zip_reader: The ``zip_reader`` returned by ``_save_storages``, or
            ``None``. When given, modules are looked up first through the
            cached ``PackageImporter`` for that reader (via ``_get_package``).
            Modules it cannot find fall back to ``importlib.import_module``.
        obj_bytes: The pickled bytes returned by ``_save_storages``.
        serialized_storages: Indexable collection resolving each persistent
            id ``('storage', index)`` in ``obj_bytes`` to
            ``serialized_storages[index]``.

    Returns:
        The unpickled object. It is also stored in ``_deploy_objects[id]``.

    Raises:
        pickle.UnpicklingError: If the stream contains a persistent id that
            is not a tuple, or whose typename is not ``'storage'``.
    """
    import pickle  # only needed for pickle.UnpicklingError

    def persistent_load(saved_id):
        # Raise real exceptions instead of using ``assert`` so the check is
        # not stripped under ``python -O``. The pickle docs name
        # UnpicklingError as the error for an invalid persistent id.
        if not isinstance(saved_id, tuple):
            raise pickle.UnpicklingError(
                f"Expected a tuple persistent id, but got {type(saved_id).__name__}"
            )
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        if typename != 'storage':
            raise pickle.UnpicklingError(
                f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
            )
        return serialized_storages[data[0]]

    import_module: Callable[[str], ModuleType] = importlib.import_module
    if zip_reader is not None:
        importer = _get_package(zip_reader)

        def _import_from_package(name: str) -> ModuleType:
            # Prefer the package's own modules; fall back to a regular import
            # for anything the package does not contain.
            try:
                return importer.import_module(name)
            except ModuleNotFoundError:
                return importlib.import_module(name)

        import_module = _import_from_package

    unpickler = _UnpicklerWrapper(import_module, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load
    result = _deploy_objects[id] = unpickler.load()
    return result
def _get_package(zip_reader):
    """Return the ``PackageImporter`` for ``zip_reader``, building it on first use.

    Importers are memoized in the module-level ``_raw_packages`` dict, keyed
    by ``zip_reader``. Repeated requests for the same reader therefore return
    the same importer instance.
    """
    if zip_reader in _raw_packages:
        return _raw_packages[zip_reader]
    package = PackageImporter(zip_reader)
    _raw_packages[zip_reader] = package
    return package
# Memo of PackageImporter instances keyed by zip_reader; filled by _get_package.
_raw_packages: dict = dict()
# Objects unpickled by _load_storages, keyed by the caller-supplied id.
_deploy_objects: dict = dict()