Skip to content

Commit 968722f

Browse files
authored
Streamingdataset torch compatibility (#108)
1 parent 577e181 commit 968722f

File tree

12 files changed

+1565
-31
lines changed

12 files changed

+1565
-31
lines changed

.github/workflows/ci-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
artifact-name: dist-packages-${{ github.sha }}
3131
testing-matrix: |
3232
{
33-
"os": ["ubuntu-latest", "macos-latest", "windows-latest"],
33+
"os": ["ubuntu-latest", "macos-13", "windows-latest"],
3434
"python-version": ["3.8", "3.10"]
3535
}
3636

.github/workflows/ci-testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
os: [ubuntu-latest, macOS-latest, windows-latest]
19+
os: [ubuntu-latest, macos-13, windows-latest]
2020
python-version: [3.9]
2121
requires: ["oldest", "latest"]
2222

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ lint.ignore = [
7272
exclude = [
7373
".git",
7474
"docs",
75-
"_notebooks"
75+
"_notebooks",
76+
"src/litdata/utilities/_pytree.py",
7677
]
7778
lint.ignore-init-module-imports = true
7879

src/litdata/processing/data_processor.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
_INDEX_FILENAME,
4040
_IS_IN_STUDIO,
4141
_LIGHTNING_CLOUD_AVAILABLE,
42-
_TORCH_GREATER_EQUAL_2_1_0,
4342
)
4443
from litdata.imports import RequirementCache
4544
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
@@ -49,6 +48,7 @@
4948
from litdata.streaming.client import S3Client
5049
from litdata.streaming.dataloader import StreamingDataLoader
5150
from litdata.streaming.resolver import _resolve_dir
51+
from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
5252
from litdata.utilities.broadcast import broadcast_object
5353
from litdata.utilities.packing import _pack_greedily
5454

@@ -57,9 +57,6 @@
5757
if _TQDM_AVAILABLE:
5858
from tqdm.auto import tqdm as _tqdm
5959

60-
if _TORCH_GREATER_EQUAL_2_1_0:
61-
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads
62-
6360
if _LIGHTNING_CLOUD_AVAILABLE:
6461
from lightning_cloud.openapi import V1DatasetType
6562

src/litdata/processing/functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import torch
2424

25-
from litdata.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
25+
from litdata.constants import _IS_IN_STUDIO
2626
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
2727
from litdata.processing.readers import BaseReader
2828
from litdata.processing.utilities import optimize_dns_context
@@ -34,9 +34,7 @@
3434
_execute,
3535
_resolve_dir,
3636
)
37-
38-
if _TORCH_GREATER_EQUAL_2_1_0:
39-
from torch.utils._pytree import tree_flatten
37+
from litdata.utilities._pytree import tree_flatten
4038

4139

4240
def _get_indexed_paths(data: Any) -> Dict[int, str]:

src/litdata/streaming/cache.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from litdata.constants import (
1919
_INDEX_FILENAME,
20-
_TORCH_GREATER_EQUAL_2_1_0,
2120
)
2221
from litdata.streaming.item_loader import BaseItemLoader
2322
from litdata.streaming.reader import BinaryReader
@@ -56,9 +55,6 @@ def __init__(
5655
5756
"""
5857
super().__init__()
59-
if not _TORCH_GREATER_EQUAL_2_1_0:
60-
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
61-
6258
input_dir = _resolve_dir(input_dir)
6359
self._cache_dir = input_dir.path
6460
assert self._cache_dir

src/litdata/streaming/config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,13 @@
1515
import os
1616
from typing import Any, Dict, List, Optional, Tuple
1717

18-
from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
18+
from litdata.constants import _INDEX_FILENAME
1919
from litdata.streaming.compression import _COMPRESSORS, Compressor
2020
from litdata.streaming.downloader import get_downloader_cls
2121
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
2222
from litdata.streaming.sampler import ChunkedIndex
2323
from litdata.streaming.serializers import Serializer
24-
25-
if _TORCH_GREATER_EQUAL_2_1_0:
26-
from torch.utils._pytree import tree_unflatten, treespec_loads
24+
from litdata.utilities._pytree import tree_unflatten, treespec_loads
2725

2826

2927
class ChunksConfig:

src/litdata/streaming/dataloader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from torch.utils.data.sampler import BatchSampler, Sampler
3535

36-
from litdata.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
36+
from litdata.constants import _DEFAULT_CHUNK_BYTES, _VIZ_TRACKER_AVAILABLE
3737
from litdata.streaming import Cache
3838
from litdata.streaming.combined import (
3939
__NUM_SAMPLES_YIELDED_KEY__,
@@ -42,11 +42,9 @@
4242
)
4343
from litdata.streaming.dataset import StreamingDataset
4444
from litdata.streaming.sampler import CacheBatchSampler
45+
from litdata.utilities._pytree import tree_flatten
4546
from litdata.utilities.env import _DistributedEnv
4647

47-
if _TORCH_GREATER_EQUAL_2_1_0:
48-
from torch.utils._pytree import tree_flatten
49-
5048
logger = logging.Logger(__name__)
5149

5250

src/litdata/streaming/item_loader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,9 @@
2222

2323
from litdata.constants import (
2424
_TORCH_DTYPES_MAPPING,
25-
_TORCH_GREATER_EQUAL_2_1_0,
2625
)
2726
from litdata.streaming.serializers import Serializer
28-
29-
if _TORCH_GREATER_EQUAL_2_1_0:
30-
from torch.utils._pytree import PyTree, tree_unflatten
27+
from litdata.utilities._pytree import PyTree, tree_unflatten
3128

3229

3330
class BaseItemLoader(ABC):

src/litdata/streaming/writer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,14 @@
2121
import numpy as np
2222
import torch
2323

24-
from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
24+
from litdata.constants import _INDEX_FILENAME
2525
from litdata.processing.utilities import get_worker_rank
2626
from litdata.streaming.compression import _COMPRESSORS, Compressor
2727
from litdata.streaming.serializers import Serializer, _get_serializers
28+
from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps
2829
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
2930
from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes
3031

31-
if _TORCH_GREATER_EQUAL_2_1_0:
32-
from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps
33-
3432

3533
@dataclass
3634
class Item:

0 commit comments

Comments
 (0)