Skip to content

Commit 58f7aeb

Browse files
authored
Added call to setup function of serializer class to set data format (#96)
1 parent c87662a commit 58f7aeb

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

src/litdata/streaming/item_loader.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
3737
self._config = config
3838
self._chunks = chunks
3939
self._serializers = serializers
40+
self._data_format = self._config["data_format"]
41+
self._shift_idx = len(self._data_format) * 4
42+
43+
# setup the serializers on restart
44+
for data_format in self._data_format:
45+
serializer = self._serializers[self._data_format_to_key(data_format)]
46+
serializer.setup(data_format)
47+
48+
@functools.lru_cache(maxsize=128)
49+
def _data_format_to_key(self, data_format: str) -> str:
50+
if ":" in data_format:
51+
serialier, serializer_sub_type = data_format.split(":")
52+
if serializer_sub_type in self._serializers:
53+
return serializer_sub_type
54+
return serialier
55+
return data_format
4056

4157
def state_dict(self) -> Dict:
4258
return {}
@@ -109,21 +125,12 @@ def load_item_from_chunk(
109125

110126
return self.deserialize(data)
111127

112-
@functools.lru_cache(maxsize=128)
113-
def _data_format_to_key(self, data_format: str) -> str:
114-
if ":" in data_format:
115-
serialier, serializer_sub_type = data_format.split(":")
116-
if serializer_sub_type in self._serializers:
117-
return serializer_sub_type
118-
return serialier
119-
return data_format
120-
121128
def deserialize(self, raw_item_data: bytes) -> "PyTree":
122129
"""Deserialize the raw bytes into their python equivalent."""
123-
idx = len(self._config["data_format"]) * 4
130+
idx = self._shift_idx
124131
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
125132
data = []
126-
for size, data_format in zip(sizes, self._config["data_format"]):
133+
for size, data_format in zip(sizes, self._data_format):
127134
serializer = self._serializers[self._data_format_to_key(data_format)]
128135
data_bytes = raw_item_data[idx : idx + size]
129136
data.append(serializer.deserialize(data_bytes))
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from unittest.mock import MagicMock
2+
3+
from litdata.streaming.item_loader import PyTreeLoader
4+
5+
6+
def test_serializer_setup():
7+
config_mock = MagicMock()
8+
config_mock.__getitem__.return_value = ["fake:12"]
9+
serializer_mock = MagicMock()
10+
item_loader = PyTreeLoader()
11+
item_loader.setup(config_mock, [], {"fake": serializer_mock})
12+
serializer_mock.setup._mock_mock_calls[0].args[0] == "fake:12"

0 commit comments

Comments
 (0)