Skip to content

Commit

Permalink
Add walk operator (#19333)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored and lexierule committed Jan 31, 2024
1 parent 9c76a14 commit 9bc1664
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/lightning/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lightning.data.streaming.combined import CombinedStreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.functions import map, optimize
from lightning.data.streaming.functions import map, optimize, walk

__all__ = [
"LightningDataset",
Expand All @@ -11,4 +11,5 @@
"LightningIterableDataset",
"map",
"optimize",
"walk",
]
51 changes: 50 additions & 1 deletion src/lightning/data/streaming/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import inspect
import os
from datetime import datetime
from functools import partial
from pathlib import Path
from types import FunctionType
from typing import Any, Callable, Dict, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch

Expand Down Expand Up @@ -286,3 +287,51 @@ def optimize(
num_nodes,
machine,
)


def _listdir(folder: str) -> Tuple[str, List[str]]:
return folder, os.listdir(folder)


class walk:
"""This class is an optimized version of os.walk for listing files and folders from cloud filesystem.
Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call.
"""

def __init__(self, folder: str, max_workers: Optional[int] = os.cpu_count()) -> None:
self.folders = [folder]
self.max_workers = max_workers or 1
self.futures: List[concurrent.futures.Future] = []

def __iter__(self) -> Any:
"""This function queues the folders to perform listdir across multiple workers."""
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
while len(self.folders):
folder = self.folders.pop(0)
future = executor.submit(_listdir, folder)
self.futures.append(future)

while self.futures:
for future in concurrent.futures.as_completed(self.futures):
filenames = []
folders = []

folder, files_or_folders = future.result()
self.futures = [f for f in self.futures if f != future]

for file_or_folder in files_or_folders:
if os.path.isfile(os.path.join(folder, file_or_folder)):
filenames.append(file_or_folder)
else:
folders.append(file_or_folder)
self.folders.append(os.path.join(folder, file_or_folder))

yield folder, folders, filenames

while len(self.folders) and len(self.futures) <= self.max_workers * 2:
folder = self.folders.pop(0)
future = executor.submit(_listdir, folder)
self.futures.append(future)
return
18 changes: 17 additions & 1 deletion tests/tests_data/streaming/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import sys
from unittest import mock

import pytest
from lightning.data.streaming.functions import _get_input_dir, os
from lightning.data import walk
from lightning.data.streaming.functions import _get_input_dir


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
Expand All @@ -19,3 +21,17 @@ def fn(*_, **__):

with pytest.raises(ValueError, match="The provided item didn't contain any filepaths."):
assert _get_input_dir(["", "/teamspace/studios/asd/b"])


def test_walk(tmpdir):
for i in range(5):
folder_path = os.path.join(tmpdir, str(i))
os.makedirs(folder_path, exist_ok=True)
for j in range(5):
filepath = os.path.join(folder_path, f"{j}.txt")
with open(filepath, "w") as f:
f.write("hello world !")

walks_os = sorted(os.walk(tmpdir))
walks_function = sorted(walk(tmpdir))
assert walks_os == walks_function

0 comments on commit 9bc1664

Please sign in to comment.