Skip to content

Commit 4d9bdec

Browse files
authored
Simplify dependencies (#87)
1 parent 7737014 commit 4d9bdec

30 files changed

+415
-68
lines changed

.github/workflows/ci-testing.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,15 @@ jobs:
6262
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}
6363
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
6464

65+
- name: Install package & dependencies on Ubuntu
66+
if: matrix.os == 'ubuntu-latest'
67+
run: |
68+
pip --version
69+
pip install -e '.[extras]' -r requirements/test.txt -U -q --find-links $TORCH_URL
70+
pip list
71+
6572
- name: Install package & dependencies
73+
if: matrix.os != 'ubuntu-latest'
6674
run: |
6775
pip --version
6876
pip install -e . -r requirements/test.txt -U -q --find-links $TORCH_URL

requirements.txt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
1-
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
2-
lightning-utilities >=0.8.0, <0.12.0
31
torch >=2.1.0
42
filelock
5-
tqdm
63
numpy
7-
torchvision
8-
pillow
9-
viztracer
10-
pyarrow
114
boto3[crt]
5+
requests

requirements/extras.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
torchvision
2+
pillow
3+
viztracer
4+
pyarrow
5+
tqdm
6+
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility

requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ pytest-rerunfailures ==14.0
66
pytest-random-order ==1.1.1
77
pandas
88
lightning
9+
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
_PATH_ROOT = os.path.dirname(__file__)
1111
_PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
12-
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")
12+
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements")
1313

1414

1515
def _load_py_module(fname, pkg="litdata"):

src/litdata/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
from lightning_utilities.core.imports import RequirementCache
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
213

314
from litdata.__about__ import * # noqa: F403
15+
from litdata.imports import RequirementCache
416
from litdata.processing.functions import map, optimize, walk
517
from litdata.streaming.combined import CombinedStreamingDataset
618
from litdata.streaming.dataloader import StreamingDataLoader

src/litdata/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import numpy as np
1818
import torch
19-
from lightning_utilities.core.imports import RequirementCache
19+
20+
from litdata.imports import RequirementCache
2021

2122
_INDEX_FILENAME = "index.json"
2223
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
@@ -26,7 +27,7 @@
2627
# This is required for full pytree serialization / deserialization support
2728
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
2829
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
29-
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
30+
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
3031
_BOTO3_AVAILABLE = RequirementCache("boto3")
3132
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
3233
_ZSTD_AVAILABLE = RequirementCache("zstd")

src/litdata/imports.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import importlib
15+
from functools import lru_cache
16+
from importlib.util import find_spec
17+
from typing import Optional, TypeVar
18+
19+
import pkg_resources
20+
from typing_extensions import ParamSpec
21+
22+
T = TypeVar("T")
23+
P = ParamSpec("P")
24+
25+
26+
@lru_cache
def package_available(package_name: str) -> bool:
    """Tell whether ``package_name`` can be located by the import system.

    The lookup goes through ``importlib.util.find_spec`` and the answer is
    memoized per name by ``lru_cache``.

    Args:
        package_name: Name handed to ``find_spec``.

    Returns:
        ``True`` when a spec is found, ``False`` when ``find_spec`` returns
        ``None`` or raises ``ModuleNotFoundError``.

    >>> package_available('os')
    True
    >>> package_available('bla')
    False

    """
    try:
        spec = find_spec(package_name)
    except ModuleNotFoundError:
        # A failed lookup is reported as "not available" instead of propagating.
        return False
    return spec is not None
40+
41+
42+
@lru_cache
def module_available(module_path: str) -> bool:
    """Tell whether the dotted ``module_path`` can actually be imported.

    The top-level package is checked first with ``package_available``; only
    when it is found is the full path imported. Results are memoized per path
    by ``lru_cache``.

    Args:
        module_path: Dotted module path, e.g. ``"os.path"``.

    Returns:
        ``True`` when the import succeeds, ``False`` when the top-level package
        is missing or the import raises ``ImportError``.

    >>> module_available('os')
    True
    >>> module_available('os.bla')
    False
    >>> module_available('bla.bla')
    False

    """
    top_level_name = module_path.partition(".")[0]
    if package_available(top_level_name):
        try:
            importlib.import_module(module_path)
        except ImportError:
            return False
        return True
    return False
62+
63+
64+
class RequirementCache:
    """Lazy, boolean-like check that a requirement (or a module) is available.

    Nothing is evaluated at construction time. The first ``bool()``, ``str()``
    or ``repr()`` call runs the check and stores the outcome in the
    ``available`` and ``message`` attributes, which later calls reuse.

    Args:
        requirement: The requirement to check, version specifiers are allowed.
        module: The optional module to try to import if the requirement check fails.

    >>> RequirementCache("torch>=0.1")
    Requirement 'torch>=0.1' met
    >>> bool(RequirementCache("torch>=0.1"))
    True
    >>> bool(RequirementCache("torch>100.0"))
    False
    >>> RequirementCache("torch")
    Requirement 'torch' met
    >>> bool(RequirementCache("torch"))
    True
    >>> bool(RequirementCache("unknown_package"))
    False

    """

    def __init__(self, requirement: str, module: Optional[str] = None) -> None:
        self.requirement = requirement
        self.module = module

    def _check_requirement(self) -> None:
        """Run the availability check once and store ``available`` / ``message``."""
        # The presence of `available` marks that the check has already run.
        if hasattr(self, "available"):
            return

        try:
            # first try the pkg_resources requirement
            pkg_resources.require(self.requirement)
        except Exception as ex:
            self.available = False
            self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"

            has_version_specifier = any(char in self.requirement for char in "=<>")
            if has_version_specifier and self.module is None:
                # A versioned requirement without an explicit module has no import fallback.
                return

            # sometimes `pkg_resources.require()` fails but the module is importable
            fallback = self.module if self.module is not None else self.requirement
            self.available = module_available(fallback)
            if self.available:
                self.message = f"Module {fallback!r} available"
        else:
            self.available = True
            self.message = f"Requirement {self.requirement!r} met"

    def __bool__(self) -> bool:
        """Return whether the requirement (or fallback module) is available."""
        self._check_requirement()
        return self.available

    def __str__(self) -> str:
        """Return the message describing the outcome of the check."""
        self._check_requirement()
        return self.message

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()

src/litdata/processing/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.

src/litdata/processing/data_processor.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
114
import concurrent
215
import json
316
import logging
@@ -19,16 +32,16 @@
1932

2033
import numpy as np
2134
import torch
22-
from tqdm.auto import tqdm as _tqdm
2335

2436
from litdata.constants import (
2537
_BOTO3_AVAILABLE,
2638
_DEFAULT_FAST_DEV_RUN_ITEMS,
2739
_INDEX_FILENAME,
2840
_IS_IN_STUDIO,
29-
_LIGHTNING_CLOUD_LATEST,
41+
_LIGHTNING_CLOUD_AVAILABLE,
3042
_TORCH_GREATER_EQUAL_2_1_0,
3143
)
44+
from litdata.imports import RequirementCache
3245
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
3346
from litdata.processing.utilities import _create_dataset
3447
from litdata.streaming import Cache
@@ -39,10 +52,15 @@
3952
from litdata.utilities.broadcast import broadcast_object
4053
from litdata.utilities.packing import _pack_greedily
4154

55+
_TQDM_AVAILABLE = RequirementCache("tqdm")
56+
57+
if _TQDM_AVAILABLE:
58+
from tqdm.auto import tqdm as _tqdm
59+
4260
if _TORCH_GREATER_EQUAL_2_1_0:
4361
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads
4462

45-
if _LIGHTNING_CLOUD_LATEST:
63+
if _LIGHTNING_CLOUD_AVAILABLE:
4664
from lightning_cloud.openapi import V1DatasetType
4765

4866

@@ -944,15 +962,16 @@ def run(self, data_recipe: DataRecipe) -> None:
944962
print("Workers are ready ! Starting data processing...")
945963

946964
current_total = 0
947-
pbar = _tqdm(
948-
desc="Progress",
949-
total=num_items,
950-
smoothing=0,
951-
position=-1,
952-
mininterval=1,
953-
leave=True,
954-
dynamic_ncols=True,
955-
)
965+
if _TQDM_AVAILABLE:
966+
pbar = _tqdm(
967+
desc="Progress",
968+
total=num_items,
969+
smoothing=0,
970+
position=-1,
971+
mininterval=1,
972+
leave=True,
973+
dynamic_ncols=True,
974+
)
956975
num_nodes = _get_num_nodes()
957976
node_rank = _get_node_rank()
958977
total_num_items = len(user_items)
@@ -970,7 +989,8 @@ def run(self, data_recipe: DataRecipe) -> None:
970989
self.workers_tracker[index] = counter
971990
new_total = sum(self.workers_tracker.values())
972991

973-
pbar.update(new_total - current_total)
992+
if _TQDM_AVAILABLE:
993+
pbar.update(new_total - current_total)
974994

975995
current_total = new_total
976996
if current_total == num_items:
@@ -985,7 +1005,8 @@ def run(self, data_recipe: DataRecipe) -> None:
9851005
if all(not w.is_alive() for w in self.workers):
9861006
raise RuntimeError("One of the worker has failed")
9871007

988-
pbar.close()
1008+
if _TQDM_AVAILABLE:
1009+
pbar.close()
9891010

9901011
# TODO: Understand why it hangs.
9911012
if num_nodes == 1:

0 commit comments

Comments
 (0)