@@ -37,6 +37,22 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
3737 self ._config = config
3838 self ._chunks = chunks
3939 self ._serializers = serializers
40+ self ._data_format = self ._config ["data_format" ]
41+ self ._shift_idx = len (self ._data_format ) * 4
42+
43+ # setup the serializers on restart
44+ for data_format in self ._data_format :
45+ serializer = self ._serializers [self ._data_format_to_key (data_format )]
46+ serializer .setup (data_format )
47+
48+ @functools .lru_cache (maxsize = 128 )
49+ def _data_format_to_key (self , data_format : str ) -> str :
50+ if ":" in data_format :
51+ serialier , serializer_sub_type = data_format .split (":" )
52+ if serializer_sub_type in self ._serializers :
53+ return serializer_sub_type
54+ return serialier
55+ return data_format
4056
def state_dict(self) -> Dict:
    """Return a new, empty dict: no state entries are exposed here."""
    empty_state: Dict = {}
    return empty_state
@@ -109,21 +125,12 @@ def load_item_from_chunk(
109125
110126 return self .deserialize (data )
111127
112- @functools .lru_cache (maxsize = 128 )
113- def _data_format_to_key (self , data_format : str ) -> str :
114- if ":" in data_format :
115- serialier , serializer_sub_type = data_format .split (":" )
116- if serializer_sub_type in self ._serializers :
117- return serializer_sub_type
118- return serialier
119- return data_format
120-
121128 def deserialize (self , raw_item_data : bytes ) -> "PyTree" :
122129 """Deserialize the raw bytes into their python equivalent."""
123- idx = len ( self ._config [ "data_format" ]) * 4
130+ idx = self ._shift_idx
124131 sizes = np .frombuffer (raw_item_data [:idx ], np .uint32 )
125132 data = []
126- for size , data_format in zip (sizes , self ._config [ "data_format" ] ):
133+ for size , data_format in zip (sizes , self ._data_format ):
127134 serializer = self ._serializers [self ._data_format_to_key (data_format )]
128135 data_bytes = raw_item_data [idx : idx + size ]
129136 data .append (serializer .deserialize (data_bytes ))
0 commit comments