Skip to content

Commit 800c572

Browse files
authored
ref(jpeg/deserialize): Replace PIL with torchvision's decode_image for robust JPEG deserialization (#660)
* feat(serializers): Improve JPEG deserialization to handle PNG images with JPEG extensions * remove `mode`.
1 parent fdc8290 commit 800c572

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

src/litdata/streaming/serializers.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,16 @@ def serialize(self, item: Any) -> tuple[bytes, Optional[str]]:
129129
raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.")
130130

131131
def deserialize(self, data: bytes) -> torch.Tensor:
132-
from torchvision.io import decode_jpeg
133-
from torchvision.transforms.functional import pil_to_tensor
132+
from torchvision.io import decode_image, decode_jpeg
134133

135134
array = torch.frombuffer(data, dtype=torch.uint8)
136-
# Note: Some datasets like Imagenet contains some PNG images with JPEG extension, so we fallback to PIL
135+
# Try decoding as JPEG. Some datasets (e.g., ImageNet) may have PNG images with a JPEG extension,
136+
# which will cause decode_jpeg to fail. In that case, fall back to a generic image decoder.
137137
with suppress(RuntimeError):
138138
return decode_jpeg(array)
139139

140-
# Fallback to PIL
141-
if not _PIL_AVAILABLE:
142-
raise ModuleNotFoundError("PIL is required. Run `pip install pillow`")
143-
from PIL import Image
144-
145-
img = Image.open(io.BytesIO(data))
146-
return pil_to_tensor(img)
140+
# Fallback: decode as a generic image (handles PNG, etc.)
141+
return decode_image(array)
147142

148143
def can_serialize(self, item: Any) -> bool:
149144
if not _PIL_AVAILABLE:

0 commit comments

Comments
 (0)