`_ customer, you have the option to connect your existing cloud account to Lightning AI.
+This gives your organization the ability to keep all compute and data in your own cloud account and Virtual Private Cloud (VPC).
+
+
+----
+
**********
-Next steps
+Learn more
**********
.. raw:: html
@@ -103,8 +138,8 @@ Next steps
.. displayitem::
- :header: Lightning Platform
- :description: Develop, Train and Deploy models on the cloud
+ :header: Lightning Studios
+ :description: Code together. Prototype. Train. Deploy. Host AI web apps. From your browser - with zero setup.
:col_css: col-md-4
:button_link: https://lightning.ai
:height: 150
diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst
index a8f7205596604..23e5215e7cb97 100644
--- a/docs/source-pytorch/common/checkpointing_expert.rst
+++ b/docs/source-pytorch/common/checkpointing_expert.rst
@@ -136,4 +136,37 @@ Note that you can load the distributed checkpoint even if the world size has cha
Convert a distributed checkpoint
********************************
-Coming soon.
+It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility:
+
+.. code-block:: bash
+
+ python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint
+
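+The consolidated file is saved next to the input checkpoint folder, with the same name and a ``.consolidated`` suffix.
+A sketch of choosing a different location via the ``--output_file`` argument exposed by the consolidation CLI (the paths below are placeholders):
+
+.. code-block:: bash
+
+    python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint --output_file path/to/consolidated.ckpt
+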
+You will need to do this, for example, if you want to load the checkpoint into a script that doesn't use FSDP (see the example below), or if you need to export the checkpoint to a different format for deployment, evaluation, etc.
+
+.. note::
+
+ All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command.
+ This function assumes you have enough free CPU memory to hold the entire checkpoint in memory.
+
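+A script that doesn't use FSDP can then restore the weights from the consolidated file.
+Below is a minimal sketch; ``LitModel`` is a placeholder for your own ``LightningModule`` with a matching architecture:
+
+.. code-block:: python
+
+    import torch
+
+    checkpoint = torch.load("path/to/my/checkpoint.consolidated")
+    model = LitModel()  # hypothetical module matching the checkpointed architecture
+    model.load_state_dict(checkpoint["state_dict"])
+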
+.. collapse:: Full example
+
+ Assuming you have saved a checkpoint ``epoch=0-step=3.ckpt`` using the examples above, run the following command to convert it:
+
+ .. code-block:: bash
+
+ cd lightning_logs/version_0/checkpoints
+ python -m lightning.pytorch.utilities.consolidate_checkpoint epoch=0-step=3.ckpt
+
+    This saves a new file ``epoch=0-step=3.ckpt.consolidated`` next to the sharded checkpoint, which you can load normally in PyTorch:
+
+ .. code-block:: python
+
+ import torch
+
+ checkpoint = torch.load("epoch=0-step=3.ckpt.consolidated")
+ print(list(checkpoint.keys()))
+ print(checkpoint["state_dict"]["model.transformer.decoder.layers.31.norm1.weight"])
+
+
+|
diff --git a/requirements/app/app.txt b/requirements/app/app.txt
index 6a042ae7c573b..8c77f23e666b2 100644
--- a/requirements/app/app.txt
+++ b/requirements/app/app.txt
@@ -1,6 +1,6 @@
lightning-cloud == 0.5.61 # Must be pinned to ensure compatibility
packaging
-typing-extensions >=4.4.0, <4.8.0
+typing-extensions >=4.4.0, <4.10.0
deepdiff >=5.7.0, <6.6.0
fsspec[http] >=2022.5.0, <2023.11.0
croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust.
diff --git a/requirements/data/test.txt b/requirements/data/test.txt
index d30343b08a628..38439e2d6705a 100644
--- a/requirements/data/test.txt
+++ b/requirements/data/test.txt
@@ -4,3 +4,4 @@ pytest-cov ==4.1.0
pytest-timeout ==2.1.0
pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
+viztracer
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index d1496393a8b6f..64bd5cc6368bc 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -5,5 +5,5 @@ numpy >=1.17.2, <1.27.0
torch >=1.13.0, <2.2.0
fsspec[http] >=2022.5.0, <2023.11.0
packaging >=20.0, <=23.1
-typing-extensions >=4.4.0, <4.8.0
+typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.8.0, <0.10.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index 2b790dd277358..17af46699ece9 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -8,5 +8,5 @@ PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2023.11.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
-typing-extensions >=4.4.0, <4.8.0
+typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.8.0, <0.10.0
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index 4bdfdcd7ade80..fddaece918b4a 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
matplotlib>3.1, <3.9.0
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.26.1, <4.27.0
+jsonargparse[signatures] >=4.26.1, <4.28.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes ==0.41.0 # strict
diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py
index 9750dba22edaa..88b384e8a227b 100644
--- a/src/lightning/data/__init__.py
+++ b/src/lightning/data/__init__.py
@@ -1,7 +1,7 @@
from lightning.data.streaming.combined import CombinedStreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.functions import map, optimize
+from lightning.data.streaming.functions import map, optimize, walk
__all__ = [
"LightningDataset",
@@ -11,4 +11,5 @@
"LightningIterableDataset",
"map",
"optimize",
+ "walk",
]
diff --git a/src/lightning/data/streaming/client.py b/src/lightning/data/streaming/client.py
index cfe757537d218..c42d3af9f6ed1 100644
--- a/src/lightning/data/streaming/client.py
+++ b/src/lightning/data/streaming/client.py
@@ -7,8 +7,6 @@
if _BOTO3_AVAILABLE:
import boto3
import botocore
- from botocore.credentials import InstanceMetadataProvider
- from botocore.utils import InstanceMetadataFetcher
class S3Client:
@@ -31,14 +29,8 @@ def client(self) -> Any:
# Re-generate credentials for EC2
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
- provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
- credentials = provider.load()
self._client = boto3.client(
- "s3",
- aws_access_key_id=credentials.access_key,
- aws_secret_access_key=credentials.secret_key,
- aws_session_token=credentials.token,
- config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
+ "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
)
self._last_time = time()
diff --git a/src/lightning/data/streaming/combined.py b/src/lightning/data/streaming/combined.py
index aab5691acba9a..7b3373a3e800e 100644
--- a/src/lightning/data/streaming/combined.py
+++ b/src/lightning/data/streaming/combined.py
@@ -17,20 +17,29 @@
from torch.utils.data import IterableDataset
from lightning.data.streaming.dataset import StreamingDataset
+from lightning.data.utilities.env import _WorkerEnv
+
+__NUM_SAMPLES_YIELDED_KEY__ = "__NUM_SAMPLES_YIELDED__"
+__SAMPLES_KEY__ = "__SAMPLES__"
class CombinedStreamingDataset(IterableDataset):
"""The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of
your choice.
- Addtionally, the `CombinedStreamingDataset` keeps track of the number of
- samples fetched to enable resumability of the datasets.
+    Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability
+ of the datasets.
+
+ Note that due to the random sampling, the number of samples returned from the iterator is variable and a function
+ of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted.
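+
+    Example (a minimal sketch; the dataset locations below are placeholders)::
+
+        from lightning.data import CombinedStreamingDataset, StreamingDataset, StreamingDataLoader
+
+        datasets = [
+            StreamingDataset("s3://my-bucket/dataset-1"),
+            StreamingDataset("s3://my-bucket/dataset-2"),
+        ]
+        # roughly 70% of the samples are drawn from the first dataset and 30% from the second
+        combined_dataset = CombinedStreamingDataset(datasets, weights=[0.7, 0.3])
+        dataloader = StreamingDataLoader(combined_dataset, batch_size=8, num_workers=2)
+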
"""
def __init__(
self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
) -> None:
+ self._check_datasets(datasets)
+
self._seed = seed
self._datasets = datasets
self._weights = weights
@@ -43,34 +52,83 @@ def __init__(
self._weights = [w / sum(weights) for w in weights]
self._iterator: Optional[_CombinedDatasetIterator] = None
+ self._use_streaming_dataloader = False
+ self._num_samples_yielded: Optional[List[int]] = None
+ self._current_epoch = 0
- def __len__(self) -> int:
- assert self._weights
- return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0]))
+ def set_epoch(self, current_epoch: int) -> None:
+ """Set the current epoch to the datasets on epoch starts.
+
+ When using the StreamingDataLoader, this is done automatically
+
+ """
+ self._current_epoch = current_epoch
+ for dataset in self._datasets:
+ dataset.set_epoch(current_epoch)
+
+ def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
+ if any(not isinstance(d, StreamingDataset) for d in datasets):
+ raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
+
+ def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
+ # Used to prevent returning num_samples_yielded when using PyTorch DataLoader
+ self._use_streaming_dataloader = use_streaming_dataloader
def __iter__(self) -> Iterator[Any]:
assert self._weights
- self._iterator = _CombinedDatasetIterator(self._datasets, self._seed, self._weights)
+
+ worker_env = _WorkerEnv.detect()
+
+ num_samples_yielded = None
+
+ if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
+ num_samples_yielded = self._num_samples_yielded[worker_env.rank]
+
+ self._iterator = _CombinedDatasetIterator(
+ self._datasets,
+ self._seed,
+ self._weights,
+ self._use_streaming_dataloader,
+ num_samples_yielded,
+ )
return self._iterator
- def state_dict(self, num_workers: int, batch_size: int) -> Dict[str, Any]:
+ def state_dict(
+ self, num_workers: int, batch_size: int, num_samples_yielded: Optional[List[int]] = None
+ ) -> Dict[str, Any]:
if self._iterator is None:
- return {}
+ if num_samples_yielded is None:
+ return {}
+ return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size)
return self._iterator.state_dict(num_workers, batch_size)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
- if len(state_dict) != len(self._datasets):
+ if not state_dict:
+ return
+
+ if len(state_dict["dataset"]) != len(self._datasets):
raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.")
for dataset_idx, dataset in enumerate(self._datasets):
- if str(dataset_idx) not in state_dict:
+ if str(dataset_idx) not in state_dict["dataset"]:
raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")
- dataset.load_state_dict(state_dict[str(dataset_idx)])
+ dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)])
+
+ # Used to iterate over the sampler to avoid sampling the same samples
+ if self._use_streaming_dataloader:
+ self._num_samples_yielded = state_dict["num_samples_yielded"]
class _CombinedDatasetIterator(Iterator):
- def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequence[float]) -> None:
+ def __init__(
+ self,
+ datasets: List[StreamingDataset],
+ seed: int,
+ weights: Sequence[float],
+ use_streaming_dataloader: bool,
+ num_samples_yielded: Optional[Any] = None,
+ ) -> None:
self._datasets = datasets
self._dataset_iters = [iter(dataset) for dataset in datasets]
self._dataset_indexes = list(range(len(datasets)))
@@ -78,6 +136,13 @@ def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequenc
self._weights = weights
self._rng = random.Random(seed)
+ if num_samples_yielded is not None:
+ self._num_samples_yielded = num_samples_yielded
+ for _ in range(sum(num_samples_yielded)):
+ self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
+
+ self._use_streaming_dataloader = use_streaming_dataloader
+
def __next__(self) -> Any:
# randomly select a dataset index
(dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
@@ -85,11 +150,26 @@ def __next__(self) -> Any:
# keep track the sample was fetched
self._num_samples_yielded[dataset_index] += 1
+ sample = next(self._dataset_iters[dataset_index])
+
# return a new sample
- return next(self._dataset_iters[dataset_index])
+ if self._use_streaming_dataloader:
+ return {
+ __SAMPLES_KEY__: sample,
+ __NUM_SAMPLES_YIELDED_KEY__: self._num_samples_yielded,
+ }
+ return sample
def state_dict(self, num_workers: int = 0, batch_size: int = 1) -> Dict[str, Any]:
- return {
- str(dataset_idx): dataset.state_dict(self._num_samples_yielded[dataset_idx], num_workers, batch_size)
- for dataset_idx, dataset in enumerate(self._datasets)
- }
+ return _state_dict(self._datasets, self._num_samples_yielded, num_workers, batch_size)
+
+
+def _state_dict(
+ datasets: List[StreamingDataset], num_samples_yielded: List[int], num_workers: int = 0, batch_size: int = 1
+) -> Dict[str, Any]:
+ return {
+ str(dataset_idx): dataset.state_dict(
+ num_samples_yielded=num_samples_yielded[dataset_idx], num_workers=num_workers, batch_size=batch_size
+ )
+ for dataset_idx, dataset in enumerate(datasets)
+ }
diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index a589a5906e26c..a6e100e3deb64 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -563,7 +563,7 @@ def _handle_data_chunk_recipe_end(self) -> None:
def _handle_data_transform_recipe(self, index: int) -> None:
# Don't use a context manager to avoid deleting files that are being uploaded.
output_dir = tempfile.mkdtemp()
- item_data = self.data_recipe.prepare_item(str(output_dir), self.items[index])
+ item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir))
if item_data is not None:
raise ValueError(
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
@@ -753,7 +753,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
"""
@abstractmethod
- def prepare_item(self, output_dir: str, item_metadata: T) -> None: # type: ignore
+ def prepare_item(self, item_metadata: T, output_dir: str) -> None: # type: ignore
"""Use your item metadata to process your files and save the file outputs into `output_dir`."""
diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index bd49e83349a0e..acd0ebef19af6 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -15,7 +15,9 @@
import inspect
import logging
import os
+from copy import deepcopy
from importlib import reload
+from itertools import cycle
from typing import Any, Callable, Dict, List, Optional, Union
import torch
@@ -32,7 +34,11 @@
from torch.utils.data.sampler import BatchSampler, Sampler
from lightning.data.streaming import Cache
-from lightning.data.streaming.combined import CombinedStreamingDataset
+from lightning.data.streaming.combined import (
+ __NUM_SAMPLES_YIELDED_KEY__,
+ __SAMPLES_KEY__,
+ CombinedStreamingDataset,
+)
from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.sampler import CacheBatchSampler
@@ -341,6 +347,137 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
return _MultiProcessingDataLoaderIterPatch(self)
+def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable:
+ counter = 0
+
+ def wrap(*args: Any, **kwargs: Any) -> Any:
+ nonlocal counter
+ result = func(*args, **kwargs)
+
+ if tracer.enable and counter == profile:
+ tracer.stop()
+ tracer.save()
+ print(
+ f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches."
+ "Use chrome://tracing/ to view it."
+ )
+ fetcher.fetch = func
+
+ counter += 1
+ return result
+
+ return wrap
+
+
+class _ProfileWorkerLoop:
+ """Wrap the PyTorch DataLoader WorkerLoop to add profiling."""
+
+ def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None):
+ self._profile = profile
+ self._profile_dir = profile_dir if profile_dir else os.getcwd()
+
+ def __call__(
+ self,
+ dataset_kind: Any,
+ dataset: Any,
+ index_queue: Any,
+ data_queue: Any,
+ done_event: Any,
+ auto_collation: Any,
+ collate_fn: Any,
+ drop_last: Any,
+ base_seed: Any,
+ init_fn: Any,
+ worker_id: Any,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ from torch.utils.data._utils import worker
+ from viztracer import VizTracer
+
+ if worker_id == 0:
+ output_file = os.path.join(self._profile_dir, "result.json")
+
+ if os.path.exists(output_file):
+ os.remove(output_file)
+
+ tracer = VizTracer(output_file=output_file, verbose=0)
+ tracer.start()
+
+ # Reload to remove the patching
+ reloaded_worker = reload(worker)
+ create_fetcher = _DatasetKind.create_fetcher
+ fetcher = None
+
+ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
+ nonlocal fetcher
+ fetcher = create_fetcher(*args, **kwargs)
+
+ if worker_id == 0 and isinstance(self._profile, int):
+ fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir)
+ return fetcher
+
+ _DatasetKind.create_fetcher = create_fetcher_fn # type: ignore
+
+ reloaded_worker._worker_loop(
+ dataset_kind,
+ dataset,
+ index_queue,
+ data_queue,
+ done_event,
+ auto_collation,
+ collate_fn,
+ drop_last,
+ base_seed,
+ init_fn,
+ worker_id,
+ *args,
+ **kwargs,
+ )
+
+ if worker_id == 0 and isinstance(self._profile, bool):
+ tracer.stop()
+ tracer.save()
+
+
+class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
+ def __init__(self, loader: DataLoader) -> None:
+ self._loader = loader
+ self._indexes = (
+ list(range(self._loader._latest_worker_idx, self._loader.num_workers))
+ if self._loader._latest_worker_idx > 0
+ else []
+ )
+ self._num_workers = loader.num_workers
+
+ distributed_env = _DistributedEnv.detect()
+
+        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
+            from torch.utils.data._utils import worker
+
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
+
+ super().__init__(loader)
+
+ def _try_put_index(self) -> None:
+ # Used to restart on the right DataLoader worker
+ if self._loader.restore and self._indexes:
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+
+ try:
+ index = self._next_index()
+ except StopIteration:
+ return
+ worker_queue_idx = self._indexes.pop(0)
+
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
+ self._task_info[self._send_idx] = (worker_queue_idx,)
+ self._tasks_outstanding += 1
+ self._send_idx += 1
+ else:
+ super()._try_put_index()
+
+
class StreamingDataLoader(DataLoader):
"""The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
dataset."""
@@ -353,29 +490,102 @@ def __init__(
*args: Any,
batch_size: int = 1,
num_workers: int = 0,
+        profile_batches: Union[bool, int] = False,
+ profile_dir: Optional[str] = None,
+ prefetch_factor: Optional[int] = None,
**kwargs: Any,
) -> None: # pyright: ignore
+ if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
+ raise RuntimeError(
+ "The provided dataset should be either an instance of StreamingDataset or CombinedStreamingDataset."
+ f" Found {dataset}."
+ )
+
+        if profile_batches and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
+
+        if profile_batches and num_workers == 0:
+ raise ValueError("Profiling is supported only with num_workers >= 1.")
+
+ self.current_epoch = 0
self.batch_size = batch_size
self.num_workers = num_workers
- self.num_samples_yielded = 0
- super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs) # type: ignore
+        self._profile_batches = profile_batches
+ self._profile_dir = profile_dir
+ self._num_samples_yielded_streaming = 0
+ self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
+ self.rng_state: Optional[Any] = None
+ self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
+ self._worker_idx_iter: Optional[Any] = None
+ self._latest_worker_idx = 0
+ self.restore = False
+ super().__init__(
+ dataset,
+ *args,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
+ **kwargs,
+ ) # type: ignore
def __iter__(self) -> Any:
+ if not self.restore:
+ self._latest_worker_idx = 0
+ self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
+ self._worker_idx_iter = iter(self._worker_idx)
+ self.current_epoch += 1
+ self._num_samples_yielded_combined = {}
+ self._num_samples_yielded_streaming = 0
+
+ self.dataset.set_epoch(self.current_epoch)
+
if isinstance(self.dataset, StreamingDataset):
assert self.batch_size
- self.num_samples_yielded = 0
for batch in super().__iter__():
- self.num_samples_yielded += self.batch_size
+ self._latest_worker_idx = next(self._worker_idx_iter) # type: ignore
+ self._num_samples_yielded_streaming += self.batch_size
yield batch
else:
- yield from super().__iter__()
+ self.dataset._set_use_streaming_dataloader(True)
+ assert self.batch_size
+ # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
+ for batch in super().__iter__():
+ self._latest_worker_idx = next(self._worker_idx_iter) # type: ignore
+ if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
+ self._num_samples_yielded_combined[self._latest_worker_idx] = [
+ sample[-1].item() if self.batch_size > 1 else sample.item()
+ for sample in batch[__NUM_SAMPLES_YIELDED_KEY__]
+ ]
+
+ yield batch[__SAMPLES_KEY__]
+ else:
+ yield batch
+
+ self.restore = False
def state_dict(self) -> Dict[str, Any]:
if isinstance(self.dataset, StreamingDataset):
assert self.batch_size
- num_samples = self.num_samples_yielded
- return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size)
- return self.dataset.state_dict(self.num_workers, self.batch_size)
+ return {
+ "dataset": self.dataset.state_dict(
+ self._num_samples_yielded_streaming, self.num_workers, self.batch_size
+ ),
+ "current_epoch": self.current_epoch,
+ "num_samples_yielded": self._num_samples_yielded_streaming,
+ "latest_worker_idx": self._latest_worker_idx,
+ }
+
+        num_samples_yielded = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
+        for worker_idx in self._num_samples_yielded_combined:
+            for dataset_idx, samples_yielded in enumerate(self._num_samples_yielded_combined[worker_idx]):
+                num_samples_yielded[dataset_idx] += samples_yielded
+
+        return {
+            "dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yielded),
+ "current_epoch": self.current_epoch if self.restore else self.current_epoch - 1,
+ "latest_worker_idx": self._latest_worker_idx,
+ "num_samples_yielded": deepcopy(self._num_samples_yielded_combined),
+ }
def load_state_dict(self, obj: Dict[str, Any]) -> None:
"""Load a dict containing training state (called from non-worker process).
@@ -386,7 +596,34 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
obj (Any): The state.
"""
- if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)):
+ self.current_epoch = obj["current_epoch"]
+
+ if isinstance(self.dataset, StreamingDataset):
+ self._num_samples_yielded_streaming = obj["num_samples_yielded"]
+ else:
+ self._num_samples_yielded_combined = obj["num_samples_yielded"]
+
+ # Used to restart on the next DataLoader worker from the previous run.
+ self._latest_worker_idx = obj["latest_worker_idx"] + 1
+ self._worker_idx_iter = iter(self._worker_idx)
+ for _ in range(self._latest_worker_idx):
+ next(self._worker_idx_iter)
+
+        # Indicate that we are resuming and disable resetting the StreamingDataLoader state.
+        # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
+ self.restore = True
+
+ if isinstance(self.dataset, CombinedStreamingDataset):
+ self.dataset._set_use_streaming_dataloader(True)
self.dataset.load_state_dict(obj)
+ elif isinstance(self.dataset, StreamingDataset):
+ self.dataset.load_state_dict(obj["dataset"])
else:
raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
+
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
+ """Overriden to ensure the `Cache.done()` method is triggered on iteration done."""
+ if self.num_workers == 0:
+ return _SingleProcessDataLoaderIter(self)
+ self.check_worker_number_rationality()
+ return _StreamingMultiProcessingDataLoaderIter(self)
diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py
index ec331a92dec95..e281bbc0a1e2d 100644
--- a/src/lightning/data/streaming/dataset.py
+++ b/src/lightning/data/streaming/dataset.py
@@ -17,7 +17,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
-from torch.utils.data import IterableDataset, get_worker_info
+from torch.utils.data import IterableDataset
from lightning.data.streaming import Cache
from lightning.data.streaming.constants import (
@@ -29,7 +29,7 @@
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
-from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
+from lightning.data.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
class StreamingDataset(IterableDataset):
@@ -90,6 +90,17 @@ def __init__(
self.serializers = serializers
self._state_dict: Optional[Dict[str, Any]] = None
+ def set_epoch(self, current_epoch: int) -> None:
+ """Set the current epoch to the dataset on epoch starts.
+
+ When using the StreamingDataLoader, this is done automatically
+
+ """
+ # If the state dict has been reloaded, don't override the current epoch
+        # The StreamingDataLoader would clean this out
+ if self._state_dict is None:
+ self.current_epoch = current_epoch
+
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if _should_replace_path(self.input_dir.path):
cache_path = _try_create_cache_dir(
@@ -119,8 +130,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
seed = self.seed
drop_last = self.drop_last
if self._state_dict is not None:
- restart_keys = sorted(self._state_dict)
- state: Dict[str, Any] = self._state_dict[restart_keys[-1]]
+ state: Dict[str, Any] = self._state_dict
seed = state["seed"]
drop_last = state["drop_last"]
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)
@@ -139,8 +149,7 @@ def __iter__(self) -> "StreamingDataset":
# Handle restart
if self._state_dict:
self._validate_state_dict()
- restart_keys = sorted(self._state_dict)
- state: Dict[str, Any] = self._state_dict[restart_keys[-1]]
+ state: Dict[str, Any] = self._state_dict
self.current_epoch = state["current_epoch"]
chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
@@ -187,16 +196,13 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
assert self.worker_env
assert self.shuffler
- restart_keys = sorted(self._state_dict)
-
- # Get the state from the previous run
- state: Dict[str, Any] = self._state_dict[restart_keys[-1]]
+ state: Dict[str, Any] = self._state_dict
num_workers = state["num_workers"]
batch_size = state["batch_size"]
# TODO: Implement elastic sampling where the number of workers, ranks can change.
- num_samples_yielded = sum([state["num_samples_yielded"] for state in self._state_dict.values()])
+ num_samples_yielded = self._state_dict["num_samples_yielded"]
# replay sampling from each worker / chunks using the batch size
workers_chunks, workers_intervals = _associate_chunks_to_workers(
@@ -213,7 +219,7 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
self.worker_intervals = workers_intervals[worker_rank]
# replay the indexes for the current chunks
- interval = workers_intervals[worker_rank][self.chunk_index]
+ interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
# re-shuffle the indexes
@@ -223,6 +229,8 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
current_indexes = current_indexes[indexes[worker_rank] :]
self.current_indexes = current_indexes
+ self.global_index = num_samples_yielded
+
# bump the chunk_index
self.chunk_index += 1
@@ -283,6 +291,10 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
if _is_in_dataloader_worker():
raise RuntimeError("The method `state_dict` should only be called in the main process.")
+ if self._state_dict is not None:
+ self._state_dict["num_samples_yielded"] = num_samples_yielded
+ return self._state_dict
+
state = {
"num_samples_yielded": num_samples_yielded,
"num_workers": num_workers,
@@ -297,10 +309,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
"shuffle": self.shuffle,
}
- if self._state_dict:
- num_restarts = len(self._state_dict)
- return {**self._state_dict, f"{num_restarts}": state}
- return {"0": state}
+ return state
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
@@ -312,8 +321,7 @@ def _validate_state_dict(self) -> None:
assert self.worker_env
assert self.cache
- restart_keys = sorted(self._state_dict)
- state: Dict[str, Any] = self._state_dict[restart_keys[-1]]
+ state: Dict[str, Any] = self._state_dict
if state["shuffle"] != self.shuffle:
raise ValueError(
@@ -327,7 +335,18 @@ def _validate_state_dict(self) -> None:
f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
)
- if state["input_dir_path"] != self.input_dir.path:
+ # Note: We need to check whether the path has been resolved to its associated cache.
+ # In this case, validate the cache folder is the same.
+ if _should_replace_path(state["input_dir_path"]):
+ cache_path = _try_create_cache_dir(
+ input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"]
+ )
+ if cache_path != self.input_dir.path:
+ raise ValueError(
+ "The provided `input_dir` path state doesn't match the current one. "
+ f"Found `{self.input_dir.path}` instead of `{cache_path}`."
+ )
+ elif state["input_dir_path"] != self.input_dir.path:
raise ValueError(
"The provided `input_dir` path state doesn't match the current one. "
f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
@@ -374,11 +393,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
if path is None or path == "":
return True
- return "/datasets/" in path or "_connections/" in path
-
-
-def _is_in_dataloader_worker() -> bool:
- return get_worker_info() is not None
+ return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")
def is_integer(value: str) -> bool:
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index b9097c843e66e..982b4d25142b3 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -12,6 +12,7 @@
# limitations under the License.
import os
import shutil
+import subprocess
from abc import ABC
from typing import Any, Dict, List
from urllib import parse
@@ -19,6 +20,7 @@
from filelock import FileLock, Timeout
from lightning.data.streaming.client import S3Client
+from lightning.data.streaming.constants import _INDEX_FILENAME
class Downloader(ABC):
@@ -40,7 +42,10 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
class S3Downloader(Downloader):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
super().__init__(remote_dir, cache_dir, chunks)
- self._client = S3Client()
+ self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
+
+ if not self._s5cmd_available:
+ self._client = S3Client()
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
obj = parse.urlparse(remote_filepath)
@@ -48,21 +53,34 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
if obj.scheme != "s3":
raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
- from boto3.s3.transfer import TransferConfig
-
- extra_args: Dict[str, Any] = {}
+ if os.path.exists(local_filepath):
+ return
try:
- with FileLock(local_filepath + ".lock", timeout=1):
- if not os.path.exists(local_filepath):
- # Issue: https://github.com/boto/boto3/issues/3113
- self._client.client.download_file(
- obj.netloc,
- obj.path.lstrip("/"),
- local_filepath,
- ExtraArgs=extra_args,
- Config=TransferConfig(use_threads=False),
+ with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0):
+ if self._s5cmd_available:
+ proc = subprocess.Popen(
+ f"s5cmd cp {remote_filepath} {local_filepath}",
+ shell=True,
+ stdout=subprocess.PIPE,
)
+ proc.wait()
+ else:
+ from boto3.s3.transfer import TransferConfig
+
+ extra_args: Dict[str, Any] = {}
+
+ if not os.path.exists(local_filepath):
+ # Issue: https://github.com/boto/boto3/issues/3113
+ self._client.client.download_file(
+ obj.netloc,
+ obj.path.lstrip("/"),
+ local_filepath,
+ ExtraArgs=extra_args,
+ Config=TransferConfig(use_threads=False),
+ )
except Timeout:
# another process is responsible to download that file, continue
pass
diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py
index 1d06f6a1a3885..f6aefcbcd0b33 100644
--- a/src/lightning/data/streaming/functions.py
+++ b/src/lightning/data/streaming/functions.py
@@ -11,13 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import concurrent.futures
import inspect
import os
from datetime import datetime
from functools import partial
from pathlib import Path
from types import FunctionType
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
@@ -81,24 +82,24 @@ def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
params = inspect.signature(_fn).parameters
self._contains_device = "device" in params
- def prepare_structure(self, input_dir: Optional[str]) -> Any:
+ def prepare_structure(self, _: Optional[str]) -> Any:
return self._inputs
- def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore
+ def prepare_item(self, item_metadata: Any, output_dir: str) -> None: # type: ignore
if self._contains_device and self._device is None:
self._find_device()
if isinstance(self._fn, (FunctionType, partial)):
if self._contains_device:
- self._fn(output_dir, item_metadata, self._device)
+ self._fn(item_metadata, output_dir, self._device)
else:
- self._fn(output_dir, item_metadata)
+ self._fn(item_metadata, output_dir)
elif callable(self._fn):
if self._contains_device:
- self._fn.__call__(output_dir, item_metadata, self._device) # type: ignore
+ self._fn.__call__(item_metadata, output_dir, self._device) # type: ignore
else:
- self._fn.__call__(output_dir, item_metadata) # type: ignore
+ self._fn.__call__(item_metadata, output_dir) # type: ignore
else:
raise ValueError(f"The provided {self._fn} isn't supported.")
@@ -286,3 +287,51 @@ def optimize(
num_nodes,
machine,
)
+
+
+def _listdir(folder: str) -> Tuple[str, List[str]]:
+ return folder, os.listdir(folder)
+
+
+class walk:
+ """This class is an optimized version of os.walk for listing files and folders from cloud filesystem.
+
+ Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call.
+
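+    Example (a minimal sketch; the folder path is a placeholder)::
+
+        from lightning.data import walk
+
+        for root, dirs, files in walk("/teamspace/datasets/my-dataset"):
+            print(root, len(dirs), len(files))
+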
+ """
+
+ def __init__(self, folder: str, max_workers: Optional[int] = os.cpu_count()) -> None:
+ self.folders = [folder]
+ self.max_workers = max_workers or 1
+ self.futures: List[concurrent.futures.Future] = []
+
+ def __iter__(self) -> Any:
+ """This function queues the folders to perform listdir across multiple workers."""
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ while len(self.folders):
+ folder = self.folders.pop(0)
+ future = executor.submit(_listdir, folder)
+ self.futures.append(future)
+
+ while self.futures:
+ for future in concurrent.futures.as_completed(self.futures):
+ filenames = []
+ folders = []
+
+ folder, files_or_folders = future.result()
+ self.futures = [f for f in self.futures if f != future]
+
+ for file_or_folder in files_or_folders:
+ if os.path.isfile(os.path.join(folder, file_or_folder)):
+ filenames.append(file_or_folder)
+ else:
+ folders.append(file_or_folder)
+ self.folders.append(os.path.join(folder, file_or_folder))
+
+ yield folder, folders, filenames
+
+ while len(self.folders) and len(self.futures) <= self.max_workers * 2:
+ folder = self.folders.pop(0)
+ future = executor.submit(_listdir, folder)
+ self.futures.append(future)
+ return
diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py
index 6ce919e2804c8..779a683146182 100644
--- a/src/lightning/data/streaming/item_loader.py
+++ b/src/lightning/data/streaming/item_loader.py
@@ -87,15 +87,11 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
del self._chunk_filepaths[chunk_filepath]
if chunk_filepath not in self._chunk_filepaths:
- first_exists = exists = os.path.exists(chunk_filepath)
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
while not exists:
sleep(0.1)
- exists = os.path.exists(chunk_filepath)
-
- # Wait to avoid any corruption when the file appears
- if not first_exists:
- sleep(0.001)
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
self._chunk_filepaths[chunk_filepath] = True
@@ -189,16 +185,12 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
del self._chunk_filepaths[chunk_filepath]
if chunk_filepath not in self._chunk_filepaths:
- first_exists = exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
+ exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
while not exists:
sleep(0.1)
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
- # Wait to avoid any corruption when the file appears
- if not first_exists:
- sleep(0.1)
-
self._chunk_filepaths[chunk_filepath] = True
self._load_chunk(chunk_index, chunk_filepath)
diff --git a/src/lightning/data/streaming/writer.py b/src/lightning/data/streaming/writer.py
index 68b2f6c762d2e..e58194289b208 100644
--- a/src/lightning/data/streaming/writer.py
+++ b/src/lightning/data/streaming/writer.py
@@ -105,12 +105,12 @@ def filled(self) -> bool:
return True
files = os.listdir(self._cache_dir)
index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
- worker_end = _WorkerEnv.detect()
+ worker_env = _WorkerEnv.detect()
data_optimiser_num_workers = os.getenv("DATA_OPTIMIZER_NUM_WORKERS", None)
if data_optimiser_num_workers is not None:
self._is_done = len(index_files) == int(data_optimiser_num_workers)
else:
- self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size
+ self._is_done = len(index_files) == self._distributed_env.world_size * worker_env.world_size
return self._is_done
@property
diff --git a/src/lightning/data/utilities/env.py b/src/lightning/data/utilities/env.py
index fa91e714f8666..c9406963d909b 100644
--- a/src/lightning/data/utilities/env.py
+++ b/src/lightning/data/utilities/env.py
@@ -163,3 +163,7 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return repr(self)
+
+
+def _is_in_dataloader_worker() -> bool:
+ return torch_get_worker_info() is not None
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 808d8f3ee2746..32344efc17610 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -24,6 +24,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for clipping gradients by value with FSDP ([#19236](https://github.com/Lightning-AI/lightning/pull/19236))
+- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213))
+
+
+- (Experimental) Added support for re-compiling the model inside `Fabric.setup()` over the FSDP/DDP wrappers ([#19280](https://github.com/Lightning-AI/lightning/pull/19280))
+
+
### Changed
- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 4afed65356e62..3eb12f2afad67 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -76,6 +76,7 @@
_FabricDataLoader,
_FabricModule,
_FabricOptimizer,
+ _to_compiled,
_unwrap_compiled,
_unwrap_objects,
)
@@ -213,6 +214,7 @@ def setup(
module: nn.Module,
*optimizers: Optimizer,
move_to_device: bool = True,
+ _reapply_compile: Optional[bool] = None,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
r"""Set up a model and its optimizers for accelerated training.
@@ -221,12 +223,17 @@ def setup(
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
+ _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+ corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
+ same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
+ FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
"""
self._validate_setup(module, optimizers)
+ module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module
module = self._precision.convert_module(module)
@@ -242,6 +249,8 @@ def setup(
else:
module = self._strategy.setup_module(module)
+ if compile_kwargs is not None:
+ module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)
# Update the _DeviceDtypeModuleMixin's device parameter
@@ -258,8 +267,8 @@ def setup(
self._models_setup += 1
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
- original_module._fabric = self # type: ignore[assignment]
- original_module._fabric_optimizers = optimizers # type: ignore[assignment]
+ original_module._fabric = self
+ original_module._fabric_optimizers = optimizers
if original_module not in self._callbacks:
self._callbacks.append(original_module)
@@ -270,7 +279,9 @@ def setup(
return (module, *optimizers)
return module
- def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _FabricModule:
+ def setup_module(
+ self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
+ ) -> _FabricModule:
r"""Set up a model for accelerated training or inference.
This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
@@ -281,12 +292,17 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
module: A :class:`torch.nn.Module` to set up
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
+ _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+ corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
+ same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
+ FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
Returns:
The wrapped model.
"""
self._validate_setup_module(module)
+ module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
original_module = module
module = self._precision.convert_module(module)
@@ -296,6 +312,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
# Let strategy wrap and connect the module alone
module = self._strategy.setup_module(module)
+
+ if compile_kwargs is not None:
+ module = _to_compiled(module, compile_kwargs)
module = _FabricModule(module, self._precision, original_module=original_module)
# Update the _DeviceDtypeModuleMixin's device parameter
@@ -305,7 +324,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
)
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
- original_module._fabric = self # type: ignore[assignment]
+ original_module._fabric = self
if original_module not in self._callbacks:
self._callbacks.append(original_module)
@@ -410,6 +429,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =
"""
module = model._forward_module if model is not None else model
+ module, _ = _unwrap_compiled(module)
if isinstance(self._strategy, DeepSpeedStrategy):
if model is None:
if self._models_setup == 0:
@@ -641,7 +661,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
skip.
"""
- module = _unwrap_compiled(module)
+ module, _ = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
raise TypeError(
"You need to set up the model first before you can call `fabric.no_backward_sync()`:"
@@ -656,7 +676,9 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
category=PossibleUserWarning,
)
return nullcontext()
- return self._strategy._backward_sync_control.no_backward_sync(module._forward_module)
+
+ forward_module, _ = _unwrap_compiled(module._forward_module)
+ return self._strategy._backward_sync_control.no_backward_sync(forward_module)
def sharded_model(self) -> ContextManager:
r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
@@ -772,7 +794,7 @@ def load(
# We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
# (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
for k in list(unwrapped_state.keys()):
- obj = _unwrap_compiled(state[k])
+ obj, _ = _unwrap_compiled(state[k])
if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
continue
state[k] = unwrapped_state[k]
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index f429a364b5618..420d49605fcba 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -68,7 +68,7 @@
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.init import _EmptyInit
-from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _move_state_into
+from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, _Stateful
@@ -86,7 +86,6 @@
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
-_METADATA_FILENAME = "meta.pt"
class FSDPStrategy(ParallelStrategy, _Sharded):
diff --git a/src/lightning/fabric/utilities/consolidate_checkpoint.py b/src/lightning/fabric/utilities/consolidate_checkpoint.py
new file mode 100644
index 0000000000000..b41e8f8a1312e
--- /dev/null
+++ b/src/lightning/fabric/utilities/consolidate_checkpoint.py
@@ -0,0 +1,79 @@
+import logging
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+import torch
+
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint
+
+_log = logging.getLogger(__name__)
+
+
+def _parse_cli_args() -> Namespace:
+ parser = ArgumentParser(
+ description=(
+ "Converts a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`."
+ " Only supports FSDP sharded checkpoints at the moment."
+ ),
+ )
+ parser.add_argument(
+ "checkpoint_folder",
+ type=str,
+ help=(
+ "Path to a checkpoint folder, containing the sharded checkpoint files saved using the"
+ " `torch.distributed.checkpoint` API."
+ ),
+ )
+ parser.add_argument(
+ "--output_file",
+ type=str,
+ help=(
+ "Path to the file where the converted checkpoint should be saved. The file should not already exist."
+ " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
+ " and a '.consolidated' suffix."
+ ),
+ )
+ return parser.parse_args()
+
+
+def _process_cli_args(args: Namespace) -> Namespace:
+ if not _TORCH_GREATER_EQUAL_2_1:
+ _log.error("Processing distributed checkpoints requires PyTorch >= 2.1.")
+ exit(1)
+
+ checkpoint_folder = Path(args.checkpoint_folder)
+ if not checkpoint_folder.exists():
+ _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}")
+ exit(1)
+ if not checkpoint_folder.is_dir():
+ _log.error(
+ f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}"
+ )
+ exit(1)
+ if not (checkpoint_folder / _METADATA_FILENAME).is_file():
+ _log.error(
+ "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder"
+ f" is not in that format: {checkpoint_folder}"
+ )
+ exit(1)
+
+ if args.output_file is None:
+ output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated")
+ else:
+ output_file = Path(args.output_file)
+ if output_file.exists():
+ _log.error(
+ "The path for the converted checkpoint already exists. Choose a different path by providing"
+ f" `--output_file` or move/delete the file first: {output_file}"
+ )
+ exit(1)
+
+ return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
+
+
+if __name__ == "__main__":
+ args = _parse_cli_args()
+ config = _process_cli_args(args)
+ checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
+ torch.save(checkpoint, config.output_file)
diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py
index 6925ccfabfd7a..bab8e29823903 100644
--- a/src/lightning/fabric/utilities/load.py
+++ b/src/lightning/fabric/utilities/load.py
@@ -15,7 +15,8 @@
import warnings
from functools import partial
from io import BytesIO
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
+from pathlib import Path
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@@ -24,9 +25,16 @@
from torch.nn import Parameter
from typing_extensions import override
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+ _TORCH_GREATER_EQUAL_2_0,
+ _TORCH_GREATER_EQUAL_2_1,
+ _TORCH_GREATER_EQUAL_2_2,
+)
from lightning.fabric.utilities.types import _PATH, _Stateful
+_METADATA_FILENAME = "meta.pt"
+
+
if TYPE_CHECKING:
from torch.storage import TypedStorage
@@ -227,3 +235,76 @@ def _move_state_into(
destination[key].load_state_dict(state)
else:
destination[key] = state
+
+
+def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
+ """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict.
+
+ The current implementation assumes that the entire checkpoint fits in CPU memory.
+
+ """
+ if not _TORCH_GREATER_EQUAL_2_1:
+ raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+
+ from torch.distributed.checkpoint import FileSystemReader
+ from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+
+ if _TORCH_GREATER_EQUAL_2_2:
+ from torch.distributed.checkpoint import load
+ else:
+ from torch.distributed.checkpoint import load_state_dict as load # deprecated
+
+ reader = FileSystemReader(checkpoint_folder)
+ metadata = reader.read_metadata()
+
+ # TODO: Add sequential save to avoid storing the entire checkpoint in memory
+ checkpoint: Dict[str, Any] = {}
+ for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
+ if isinstance(sd_metadata, BytesStorageMetadata):
+ checkpoint[tensor_name] = ""
+ elif isinstance(sd_metadata, TensorStorageMetadata):
+ checkpoint[tensor_name] = torch.empty(
+ size=sd_metadata.size,
+ dtype=sd_metadata.properties.dtype,
+ device=torch.device("cpu"),
+ memory_format=sd_metadata.properties.memory_format,
+ layout=sd_metadata.properties.layout,
+ requires_grad=sd_metadata.properties.requires_grad,
+ pin_memory=sd_metadata.properties.pin_memory,
+ )
+
+ load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
+ checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+
+ # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
+ extra_file = checkpoint_folder / _METADATA_FILENAME
+ extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
+ checkpoint.update(extra)
+
+ return checkpoint
+
+
+def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
+ """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
+
+ Args:
+ checkpoint: The flat checkpoint dictionary.
+ key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
+            the index path into the nested dictionary that this function should construct.
+
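+    Example (illustrative)::
+
+        >>> _unflatten_dict({"a.b": 1, "c": 2}, key_map={"a.b": ("a", "b"), "c": ("c",)})
+        {'a': {'b': 1}, 'c': 2}
+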
+ """
+ assert checkpoint.keys() == key_map.keys()
+ converted: Dict[str, Any] = {}
+ for flat_key in checkpoint:
+ key_path = key_map[flat_key]
+ _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
+ return converted
+
+
+def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None:
+ result = nested_dict
+ for key in key_path[:-1]:
+ if key not in result:
+ result[key] = {}
+ result = result[key]
+ result[key_path[-1]] = value
diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py
index 6b971a58ada9d..d96673df6451d 100644
--- a/src/lightning/fabric/utilities/throughput.py
+++ b/src/lightning/fabric/utilities/throughput.py
@@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
else:
from torch_xla.experimental import tpu
- device_name = tpu.get_tpu_env()["TYPE"]
+ tpu_env = tpu.get_tpu_env()
+            # not all TPU generations define the "TYPE" env var. Example: TYPE="V4", ACCELERATOR_TYPE="v4-8"
+ device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
chip = device_name.lower()
assert isinstance(device_name, str)
if chip not in _TPU_FLOPS:
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 16611eb4c754b..1aabb718d9abd 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -12,8 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+from copy import deepcopy
from functools import wraps
-from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import torch
from lightning_utilities.core.apply_func import apply_to_collection
@@ -32,6 +47,9 @@
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
+if TYPE_CHECKING:
+ from torch._dynamo import OptimizedModule
+
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
@@ -285,8 +303,8 @@ def _unwrap_objects(collection: Any) -> Any:
def _unwrap(
obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
) -> Union[nn.Module, Optimizer, DataLoader]:
- if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
- return unwrapped._forward_module
+ if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule):
+ return _unwrap_compiled(unwrapped._forward_module)[0]
if isinstance(obj, _FabricOptimizer):
return obj.optimizer
if isinstance(obj, _FabricDataLoader):
@@ -302,19 +320,33 @@ def _unwrap(
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)
-def _unwrap_compiled(obj: Any) -> Any:
+def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
Use this function before instance checks against e.g. :class:`_FabricModule`.
"""
if not _TORCH_GREATER_EQUAL_2_0:
- return obj
+ # obj can't be an `OptimizedModule` anyway
+ return obj, None
+
from torch._dynamo import OptimizedModule
if isinstance(obj, OptimizedModule):
- return obj._orig_mod
- return obj
+ if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None:
+ raise RuntimeError(
+ "Failed to determine the arguments that were used to compile the module. Make sure to import"
+ " lightning before `torch.compile` is used."
+ )
+ return obj._orig_mod, compile_kwargs
+ return obj, None
+
+
+def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "OptimizedModule":
+ if not _TORCH_GREATER_EQUAL_2_0:
+ raise RuntimeError("Converting to a compiled module is only supported in PyTorch >= 2.0.0")
+
+ return torch.compile(module, **compile_kwargs) # type: ignore[return-value]
def is_wrapped(obj: object) -> bool:
@@ -328,5 +360,30 @@ def is_wrapped(obj: object) -> bool:
obj: The object to test.
"""
- obj = _unwrap_compiled(obj)
+ obj, _ = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
+
+
+def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
+ """Wraps the ``torch.compile`` function and captures the compile arguments.
+
+ We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
+ same arguments as the user passed to the original call. The arguments get stored in a dictionary
+ ``_compile_kwargs`` on the returned compiled module.
+
+ """
+ # Limitation: Currently, the global compile config does not get captured on a per-model basis.
+ # PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575
+
+ @wraps(compile_fn)
+ def _capture(model: Any, **kwargs: Any) -> Any:
+ compiled_model = compile_fn(model, **kwargs)
+ if isinstance(model, nn.Module):
+ compiled_model._compile_kwargs = deepcopy(kwargs)
+ return compiled_model
+
+ return _capture
+
+
+if _TORCH_GREATER_EQUAL_2_0:
+ torch.compile = _capture_compile_kwargs(torch.compile)
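
A short sketch of the capture/unwrap round trip introduced above, assuming PyTorch >= 2.0; importing ``lightning.fabric.wrappers`` is what applies the ``torch.compile`` patch, which is why the error message tells users to import lightning before compiling:

.. code-block:: python

    import torch
    import torch.nn as nn

    # Importing the wrappers module patches torch.compile so the kwargs get recorded
    from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled

    model = nn.Linear(1, 1)
    compiled = torch.compile(model, dynamic=True)
    assert compiled._compile_kwargs == {"dynamic": True}

    # Strip the OptimizedModule wrapper but keep the kwargs ...
    unwrapped, kwargs = _unwrap_compiled(compiled)
    assert unwrapped is model

    # ... so compilation can be reapplied later (e.g. over a strategy-wrapped module)
    recompiled = _to_compiled(unwrapped, kwargs)
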
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index dcb056d59dfac..fb6a835509946 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the option `ModelCheckpoint(save_last='link')` to create a symbolic link for the 'last.ckpt' file ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))
+- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213))
+
+
### Changed
- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -77,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303))
+- Fixed an issue where the `_restricted_classmethod_impl` would incorrectly raise a TypeError on inspection rather than on call ([#19332](https://github.com/Lightning-AI/lightning/pull/19332))
+
+
## [2.1.3] - 2023-12-21
### Changed
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index a46a8a2ccf8c6..acfd17c308e6b 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -688,9 +688,11 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
Return:
- :class:`~torch.Tensor` - The loss tensor
- - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
- - ``None`` - Skip to the next batch. This is only supported for automatic optimization.
- This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
+ - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
+ automatic optimization.
+ - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
+ multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
+ the loss is not required.
In this step you'd normally do the forward pass and calculate the loss for a batch.
You can also do fancier things like multiple forward passes or something model specific.
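
To make the return options above concrete, here is a minimal self-contained module sketch under automatic optimization (layer sizes and the optimizer are arbitrary):

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from lightning.pytorch import LightningModule


    class LitClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            # Equivalent alternatives: `return loss` (plain tensor), or `return None`
            # to skip this batch (automatic optimization only).
            return {"loss": loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)
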
diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py
index 9554439da2fa9..b685fb403497a 100644
--- a/src/lightning/pytorch/profilers/pytorch.py
+++ b/src/lightning/pytorch/profilers/pytorch.py
@@ -240,9 +240,8 @@ def __init__(
table_kwargs: Optional[Dict[str, Any]] = None,
**profiler_kwargs: Any,
) -> None:
- r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
-
- different operators inside your model - both on the CPU and GPU
+ r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
+ different operators inside your model - both on the CPU and GPU.
Args:
dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
diff --git a/src/lightning/pytorch/utilities/consolidate_checkpoint.py b/src/lightning/pytorch/utilities/consolidate_checkpoint.py
new file mode 100644
index 0000000000000..6f150bab0f23c
--- /dev/null
+++ b/src/lightning/pytorch/utilities/consolidate_checkpoint.py
@@ -0,0 +1,30 @@
+import re
+from typing import Any, Dict
+
+import torch
+
+from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args
+from lightning.fabric.utilities.load import _load_distributed_checkpoint
+
+
+def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+ """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load."""
+ # Rename the model key
+ checkpoint["state_dict"] = checkpoint.pop("model")
+
+ optimizer_keys = [key for key in checkpoint if re.match("optimizer_[0-9]+", key)]
+ if not optimizer_keys:
+ return checkpoint
+
+ # Optimizers are saved in special keys named `optimizer_0`, `optimizer_1`, etc.
+ # These need to be merged back into a Python list
+ checkpoint["optimizer_states"] = [checkpoint.pop(f"optimizer_{opt_idx}") for opt_idx in range(len(optimizer_keys))]
+ return checkpoint
+
+
+if __name__ == "__main__":
+ args = _parse_cli_args()
+ config = _process_cli_args(args)
+ checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
+ checkpoint = _format_checkpoint(checkpoint)
+ torch.save(checkpoint, config.output_file)
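
The reshaping performed by ``_format_checkpoint`` in a nutshell; the input dictionary below is a hypothetical stand-in for a consolidated FSDP checkpoint:

.. code-block:: python

    from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint

    raw = {
        "model": {"layer.weight": 0.0},  # weights are saved under the "model" key
        "optimizer_0": {"state": {}},  # one entry per optimizer
        "epoch": 3,  # other trainer/user state is left untouched
    }
    formatted = _format_checkpoint(raw)

    assert formatted["state_dict"] == {"layer.weight": 0.0}
    assert formatted["optimizer_states"] == [{"state": {}}]
    assert formatted["epoch"] == 3
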
diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py
index bb307be1db88f..dde6cf2a33b02 100644
--- a/src/lightning/pytorch/utilities/model_helpers.py
+++ b/src/lightning/pytorch/utilities/model_helpers.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import inspect
import logging
import os
-from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar
from lightning_utilities.core.imports import RequirementCache
@@ -108,18 +108,23 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]):
"""Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance
instead of a class type."""
- def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None:
+ def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None:
self.method = method
def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]:
- # Workaround for https://github.com/pytorch/pytorch/issues/67146
- is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
- if instance is not None and not is_scripting:
- raise TypeError(
- f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
- " Please call it on the class type and make sure the return value is used."
- )
- return MethodType(self.method, cls)
+ # The wrapper ensures that the method can be inspected, but not called on an instance
+ @functools.wraps(self.method)
+ def wrapper(*args: Any, **kwargs: Any) -> _R_co:
+ # Workaround for https://github.com/pytorch/pytorch/issues/67146
+ is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack())
+ if instance is not None and not is_scripting:
+ raise TypeError(
+ f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
+ " Please call it on the class type and make sure the return value is used."
+ )
+ return self.method(cls, *args, **kwargs)
+
+ return wrapper
# trick static type checkers into thinking it's a @classmethod
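
A minimal sketch of the behavior this enables, using the implementation class directly for illustration:

.. code-block:: python

    import inspect

    from lightning.pytorch.utilities.model_helpers import _restricted_classmethod_impl


    class MyModel:
        @_restricted_classmethod_impl
        def load(cls, path):
            return f"{cls.__name__} loaded from {path}"


    inspect.signature(MyModel().load)  # inspection no longer raises
    MyModel.load("weights.ckpt")  # OK: called on the class type
    # MyModel().load("weights.ckpt")  # raises TypeError: must be called on the class
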
diff --git a/tests/tests_data/streaming/test_client.py b/tests/tests_data/streaming/test_client.py
index 0b18a2ae98270..e4d9d80cbdbe9 100644
--- a/tests/tests_data/streaming/test_client.py
+++ b/tests/tests_data/streaming/test_client.py
@@ -31,12 +31,6 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
botocore = mock.MagicMock()
monkeypatch.setattr(client, "botocore", botocore)
- instance_metadata_provider = mock.MagicMock()
- monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
-
- instance_metadata_fetcher = mock.MagicMock()
- monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
-
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
s3 = client.S3Client(1)
diff --git a/tests/tests_data/streaming/test_combined.py b/tests/tests_data/streaming/test_combined.py
index 1b72ac545e011..5d8d4baa16581 100644
--- a/tests/tests_data/streaming/test_combined.py
+++ b/tests/tests_data/streaming/test_combined.py
@@ -1,22 +1,36 @@
+import os
+import sys
+from unittest.mock import ANY
+
import pytest
+import torch
+from lightning.data.streaming.cache import Cache
from lightning.data.streaming.combined import CombinedStreamingDataset
+from lightning.data.streaming.dataloader import StreamingDataLoader
+from lightning.data.streaming.dataset import Dir, StreamingDataset
from torch.utils.data import IterableDataset
+from torch.utils.data.dataloader import DataLoader
+
+
+class TestCombinedStreamingDataset(CombinedStreamingDataset):
+ def _check_datasets(self, datasets) -> None:
+ pass
def test_combined_dataset_num_samples_yield():
- dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5))
+ dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5))
dataset_iter = iter(dataset)
data = list(dataset_iter)
assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8]
- dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5))
+ dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5))
dataset_iter = iter(dataset)
data = list(dataset_iter)
assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5]
- dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5))
+ dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5))
dataset_iter = iter(dataset)
data = [next(dataset_iter) for _ in range(5)]
@@ -54,14 +68,14 @@ def load_state_dict(self, state_dict):
def test_combined_dataset_state_dict():
- dataset = CombinedStreamingDataset(
+ dataset = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
)
assert dataset.state_dict(0, 1) == {}
dataset_iter = iter(dataset)
assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}}
- dataset2 = CombinedStreamingDataset(
+ dataset2 = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
)
assert dataset2.state_dict(0, 1) == {}
@@ -96,7 +110,7 @@ def test_combined_dataset_state_dict():
{"0": {"counter": 10}, "1": {"counter": 9}},
]
- dataset2 = CombinedStreamingDataset(
+ dataset2 = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
)
assert dataset2.state_dict(0, 1) == {}
@@ -104,7 +118,7 @@ def test_combined_dataset_state_dict():
data_2 = []
for state in states:
- dataset.load_state_dict(state)
+ dataset.load_state_dict({"dataset": state})
data_2.append(next(dataset2_iter))
assert data == data_2
@@ -122,7 +136,7 @@ def test_combined_dataset_state_dict():
],
)
def test_combined_dataset_normalizes_weights(weights, expected):
- combined_dataset = CombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1)
+ combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1)
assert combined_dataset._weights == expected
@@ -135,28 +149,687 @@ def __init__(self, start, end):
def __iter__(self):
return iter(range(self._start, self._end))
+ def state_dict(self, **kwargs):
+ return kwargs
+
+ def set_epoch(self, current_epoch):
+ pass
+
def test_combined_dataset():
dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
- dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
+ dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
res = list(dataset)
assert res == list(range(0, 10))
dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
- dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
+ dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
res = list(dataset)
assert res == list(range(10, 20))
dataset1 = SimpleDataset(0, 10)
dataset2 = SimpleDataset(10, 20)
- dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+ dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
res = list(dataset)
assert 9 in res or 19 in res
if len(res) > 10:
assert 0 in res
assert 10 in res
+
+ dataset1 = SimpleDataset(0, 10)
+ dataset2 = SimpleDataset(10, 20)
+ dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+ dataloader = DataLoader(dataset, batch_size=2, num_workers=1)
+ dataloader_iter = iter(dataloader)
+ assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
+ dataset1 = SimpleDataset(0, 10)
+ dataset2 = SimpleDataset(10, 20)
+ dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+ dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1)
+ dataloader_iter = iter(dataloader)
+
+ if batch_size == 2:
+ assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([10, 2]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([3, 4]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([11, 5]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([6, 7]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([12, 8]))
+
+ else:
+ assert torch.equal(next(dataloader_iter), torch.Tensor([0]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([1]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([10]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([2]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([3]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([4]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([11]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([5]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([6]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([7]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([12]))
+ assert torch.equal(next(dataloader_iter), torch.Tensor([8]))
+
+ assert dataloader.state_dict() == {
+ "dataset": {
+ "0": {"num_samples_yielded": 9, "num_workers": 1, "batch_size": batch_size},
+ "1": {"num_samples_yielded": 3, "num_workers": 1, "batch_size": batch_size},
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [9, 3]},
+ }
+
+
+@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow in CI")
+def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
+ data_dir_1 = os.path.join(tmpdir, "data_1")
+ data_dir_2 = os.path.join(tmpdir, "data_2")
+ cache_dir_1 = os.path.join(tmpdir, "cache_dir_1")
+ cache_dir_2 = os.path.join(tmpdir, "cache_dir_2")
+
+ os.makedirs(data_dir_1)
+ os.makedirs(data_dir_2)
+ os.makedirs(cache_dir_1)
+ os.makedirs(cache_dir_2)
+
+ cache = Cache(input_dir=str(data_dir_1), chunk_size=2)
+
+ for i in range(10):
+ cache[i] = i
+
+ cache.done()
+ cache.merge()
+
+ cache = Cache(input_dir=str(data_dir_2), chunk_size=2)
+
+ for i in range(10):
+ cache[i] = i + 5
+
+ cache.done()
+ cache.merge()
+
+ dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
+ dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
+ dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+ dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2)
+
+ assert dataset1.current_epoch == 1
+ assert dataset2.current_epoch == 1
+
+ batches_1 = []
+ states_1 = []
+ for batch in dataloader:
+ batches_1.append(batch)
+ states_1.append(dataloader.state_dict())
+
+ assert dataset1.current_epoch == 1
+ assert dataset2.current_epoch == 1
+
+ batches_2 = []
+ states_2 = []
+ for batch in dataloader:
+ batches_2.append(batch)
+ states_2.append(dataloader.state_dict())
+ assert dataset1.current_epoch == 2
+ assert dataset2.current_epoch == 2
+
+ assert sum(torch.equal(b1, b2) for b1, b2 in zip(batches_1, batches_2)) != len(batches_1)
+
+ assert states_1 == [
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 2,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 4,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [2, 0], 1: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 6,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 2,
+ "num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 7,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 1,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 8,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 2,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 9,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 2,
+ "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 11,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 13,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 1,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 0,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
+ },
+ ]
+
+ assert states_2 == [
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 2,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 4,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [2, 0], 1: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 6,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 0,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 2,
+ "num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 7,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 1,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 8,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 2,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 9,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 2,
+ "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 11,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
+ },
+ {
+ "dataset": {
+ "0": {
+ "num_samples_yielded": 13,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ "1": {
+ "num_samples_yielded": 3,
+ "num_workers": 3,
+ "batch_size": 2,
+ "current_epoch": 2,
+ "input_dir_path": ANY,
+ "input_dir_url": ANY,
+ "item_loader": None,
+ "drop_last": False,
+ "seed": 42,
+ "world_size": 1,
+ "shuffle": True,
+ },
+ },
+ "current_epoch": 1,
+ "latest_worker_idx": 1,
+ "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
+ },
+ ]
+
+ dataloader.load_state_dict(states_2[1])
+
+ assert dataloader.restore
+
+ batches_23 = []
+ states_23 = []
+ for batch in dataloader:
+ batches_23.append(batch)
+ states_23.append(dataloader.state_dict())
+
+ assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0
+ assert states_23[0]["current_epoch"] == 1
+
+ assert not dataloader.restore
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index 9c8ba63391d86..33edc1d724b14 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -581,7 +581,7 @@ def prepare_structure(self, input_dir: str):
filepaths = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
return [filepath for filepath in filepaths if os.path.isfile(filepath)]
- def prepare_item(self, output_dir: str, filepath: Any) -> None:
+ def prepare_item(self, filepath: Any, output_dir: str) -> None:
from PIL import Image
img = Image.open(filepath)
@@ -628,7 +628,7 @@ def test_data_process_transform(monkeypatch, tmpdir):
assert img.size == (12, 12)
-def map_fn(output_dir, filepath):
+def map_fn(filepath, output_dir):
from PIL import Image
img = Image.open(filepath)
@@ -833,7 +833,7 @@ def fn(output_dir, item, device):
data_recipe = LambdaDataTransformRecipe(fn, range(1))
- data_recipe.prepare_item("", 1)
+ data_recipe.prepare_item(1, "")
assert called
@@ -847,13 +847,13 @@ def test_lambda_transform_recipe_class(monkeypatch):
called = False
class Transform:
- def __call__(self, output_dir, item, device):
+ def __call__(self, item, output_dir, device):
nonlocal called
assert device == "cuda:2"
called = True
data_recipe = LambdaDataTransformRecipe(Transform(), range(1))
- data_recipe.prepare_item("", 1)
+ data_recipe.prepare_item(1, "")
assert called
@@ -894,7 +894,7 @@ def test_get_item_filesizes(tmp_path):
_get_item_filesizes([str(tmp_path / "empty_file")])
-def map_fn_index(output_dir, index):
+def map_fn_index(index, output_dir):
with open(os.path.join(output_dir, f"{index}.JPEG"), "w") as f:
f.write("Hello")
diff --git a/tests/tests_data/streaming/test_dataloader.py b/tests/tests_data/streaming/test_dataloader.py
index 88c2f84def0ff..293a96636adae 100644
--- a/tests/tests_data/streaming/test_dataloader.py
+++ b/tests/tests_data/streaming/test_dataloader.py
@@ -1,5 +1,9 @@
+import os
+
+import pytest
import torch
from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader
+from lightning.data.streaming import dataloader as streaming_dataloader_module
from torch import tensor
@@ -29,9 +33,17 @@ def state_dict(self, *args, **kwargs):
def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]
+ def set_epoch(self, current_epoch):
+ pass
+
+
+class TestCombinedStreamingDataset(CombinedStreamingDataset):
+ def _check_datasets(self, datasets) -> None:
+ pass
+
def test_streaming_dataloader():
- dataset = CombinedStreamingDataset(
+ dataset = TestCombinedStreamingDataset(
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
)
dataloader = StreamingDataLoader(dataset, batch_size=2)
@@ -56,4 +68,27 @@ def test_streaming_dataloader():
for exp, gen in zip(expected, batches):
assert torch.equal(exp, gen)
- assert dataloader.state_dict() == {"0": {"counter": 10}, "1": {"counter": 9}}
+ assert dataloader.state_dict() == {
+ "dataset": {"0": {"counter": 10}, "1": {"counter": 9}},
+ "current_epoch": 0,
+ "latest_worker_idx": 0,
+ "num_samples_yielded": {0: [11, 9]},
+ }
+
+
+@pytest.mark.parametrize("profile", [2, True])
+def test_dataloader_profiling(profile, tmpdir, monkeypatch):
+ monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True)
+
+ dataset = TestCombinedStreamingDataset(
+ [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+ )
+ dataloader = StreamingDataLoader(
+ dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1
+ )
+ dataloader_iter = iter(dataloader)
+ batches = []
+ for batch in dataloader_iter:
+ batches.append(batch)
+
+ assert os.path.exists(os.path.join(tmpdir, "result.json"))
diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py
index f3f26595ec44f..db294f6ea564f 100644
--- a/tests/tests_data/streaming/test_dataset.py
+++ b/tests/tests_data/streaming/test_dataset.py
@@ -69,8 +69,10 @@ def test_streaming_dataset(tmpdir, monkeypatch):
def test_should_replace_path():
assert _should_replace_path(None)
assert _should_replace_path("")
- assert _should_replace_path(".../datasets/...")
- assert _should_replace_path(".../_connections/...")
+ assert not _should_replace_path(".../datasets/...")
+ assert not _should_replace_path(".../s3__connections/...")
+ assert _should_replace_path("/teamspace/datasets/...")
+ assert _should_replace_path("/teamspace/s3_connections/...")
assert not _should_replace_path("something_else")
@@ -647,21 +649,22 @@ def test_resumable_dataset_two_workers(tmpdir):
_ = next(dataloader_iter)
state_dict_0 = dataloader.state_dict()
- assert state_dict_0["0"]["num_samples_yielded"] == 2
- assert state_dict_0["0"]["num_workers"] == 2
- assert state_dict_0["0"]["batch_size"] == 2
+
+ assert state_dict_0["dataset"]["num_samples_yielded"] == 2
+ assert state_dict_0["dataset"]["num_workers"] == 2
+ assert state_dict_0["dataset"]["batch_size"] == 2
_ = next(dataloader_iter)
state_dict_1 = dataloader.state_dict()
- assert state_dict_1["0"]["num_samples_yielded"] == 4
- assert state_dict_1["0"]["num_workers"] == 2
- assert state_dict_1["0"]["batch_size"] == 2
+ assert state_dict_1["dataset"]["num_samples_yielded"] == 4
+ assert state_dict_1["dataset"]["num_workers"] == 2
+ assert state_dict_1["dataset"]["batch_size"] == 2
batch_2 = next(dataloader_iter)
state_dict_2 = dataloader.state_dict()
- assert state_dict_2["0"]["num_samples_yielded"] == 6
- assert state_dict_2["0"]["num_workers"] == 2
- assert state_dict_2["0"]["batch_size"] == 2
+ assert state_dict_2["dataset"]["num_samples_yielded"] == 6
+ assert state_dict_2["dataset"]["num_workers"] == 2
+ assert state_dict_2["dataset"]["batch_size"] == 2
dataset = EmulateS3StreamingDataset(
input_dir=Dir(cache_dir, data_dir),
@@ -669,21 +672,17 @@ def test_resumable_dataset_two_workers(tmpdir):
shuffle=True,
)
- dataset.load_state_dict(state_dict_1)
dataloader = StreamingDataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)
+ dataloader.load_state_dict(state_dict_1)
dataloader_iter = iter(dataloader)
batch_0_restart = next(dataloader_iter)
- state_dict_2 = dataloader.state_dict()
- assert len(state_dict_2) == 2
- assert state_dict_2["0"]["num_samples_yielded"] == 4
- assert state_dict_2["0"]["num_workers"] == 2
- assert state_dict_2["0"]["batch_size"] == 2
+ state_dict_2 = dataloader.state_dict()["dataset"]
- assert state_dict_2["1"]["num_samples_yielded"] == 2
- assert state_dict_2["1"]["num_workers"] == 2
- assert state_dict_2["1"]["batch_size"] == 2
+ assert state_dict_2["num_samples_yielded"] == 6
+ assert state_dict_2["num_workers"] == 2
+ assert state_dict_2["batch_size"] == 2
assert torch.equal(batch_2, batch_0_restart)
@@ -738,7 +737,7 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
-def test_dataset_valid_state(tmpdir):
+def test_dataset_valid_state(tmpdir, monkeypatch):
seed_everything(42)
data_dir = os.path.join(tmpdir, "data")
@@ -776,7 +775,7 @@ def test_dataset_valid_state(tmpdir):
dataset._validate_state_dict()
- state_dict["0"]["drop_last"] = True
+ state_dict["drop_last"] = True
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -784,7 +783,7 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["item_loader"] = {}
+ state_dict["item_loader"] = {}
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -792,7 +791,7 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["seed"] = 12
+ state_dict["seed"] = 12
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -800,7 +799,7 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["input_dir_url"] = "toto"
+ state_dict["input_dir_url"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -808,7 +807,7 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["input_dir_path"] = "toto"
+ state_dict["input_dir_path"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -816,7 +815,15 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["num_workers"] = "8"
+ state_dict["input_dir_path"] = "/teamspace/datasets/coco"
+ dataset.load_state_dict(state_dict)
+ with pytest.raises(
+ ValueError,
+ match=f"The provided `input_dir` path state doesn't match the current one. Found `{cache_dir}` instead of ", # noqa E501
+ ):
+ dataset._validate_state_dict()
+
+ state_dict["num_workers"] = "8"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
@@ -824,7 +831,7 @@ def test_dataset_valid_state(tmpdir):
):
dataset._validate_state_dict()
- state_dict["0"]["shuffle"] = True
+ state_dict["shuffle"] = True
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
diff --git a/tests/tests_data/streaming/test_downloader.py b/tests/tests_data/streaming/test_downloader.py
new file mode 100644
index 0000000000000..3d4e5421711c2
--- /dev/null
+++ b/tests/tests_data/streaming/test_downloader.py
@@ -0,0 +1,13 @@
+import os
+from unittest.mock import MagicMock
+
+from lightning.data.streaming.downloader import S3Downloader, subprocess
+
+
+def test_s3_downloader_fast(tmpdir, monkeypatch):
+ monkeypatch.setattr(os, "system", MagicMock(return_value=0))
+ popen_mock = MagicMock()
+ monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock))
+ downloader = S3Downloader(tmpdir, tmpdir, [])
+ downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
+ popen_mock.wait.assert_called()
diff --git a/tests/tests_data/streaming/test_functions.py b/tests/tests_data/streaming/test_functions.py
index b5d1b737b5c43..10bf40caf7c2f 100644
--- a/tests/tests_data/streaming/test_functions.py
+++ b/tests/tests_data/streaming/test_functions.py
@@ -1,8 +1,10 @@
+import os
import sys
from unittest import mock
import pytest
-from lightning.data.streaming.functions import _get_input_dir, os
+from lightning.data import walk
+from lightning.data.streaming.functions import _get_input_dir
@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@@ -19,3 +21,17 @@ def fn(*_, **__):
with pytest.raises(ValueError, match="The provided item didn't contain any filepaths."):
assert _get_input_dir(["", "/teamspace/studios/asd/b"])
+
+
+def test_walk(tmpdir):
+ for i in range(5):
+ folder_path = os.path.join(tmpdir, str(i))
+ os.makedirs(folder_path, exist_ok=True)
+ for j in range(5):
+ filepath = os.path.join(folder_path, f"{j}.txt")
+ with open(filepath, "w") as f:
+ f.write("hello world !")
+
+ walks_os = sorted(os.walk(tmpdir))
+ walks_function = sorted(walk(tmpdir))
+ assert walks_os == walks_function
diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py
index cc9c687487f92..c4f4a856007bc 100644
--- a/tests/tests_fabric/plugins/precision/test_amp_integration.py
+++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Automatic Mixed Precision (AMP) training."""
+import sys
+
import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from tests_fabric.helpers.runif import RunIf
@@ -37,6 +40,11 @@ def forward(self, x):
return output
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.parametrize(
("accelerator", "precision", "expected_dtype"),
[
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
index ee1bee5cc5c81..c986ee0db91bb 100644
--- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import sys
+
import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from tests_fabric.helpers.runif import RunIf
@@ -28,6 +31,11 @@ def __init__(self):
self.register_buffer("buffer", torch.ones(3))
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
def test_memory_sharing_disabled(strategy):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py
index a0c810669d8da..251d1870d90ac 100644
--- a/tests/tests_fabric/strategies/test_ddp_integration.py
+++ b/tests/tests_fabric/strategies/test_ddp_integration.py
@@ -11,16 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import sys
from copy import deepcopy
+from unittest import mock
+from unittest.mock import Mock
import pytest
import torch
from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2
+from torch.nn.parallel.distributed import DistributedDataParallel
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
+from tests_fabric.test_fabric import BoringModel
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.parametrize(
"accelerator",
[
@@ -64,6 +76,40 @@ def assert_params_equal(params0, params1):
assert_params_equal(params_before, wrapped_model.parameters())
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
+@mock.patch(
+ "lightning.fabric.wrappers.torch.compile",
+ Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
+)
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+ """Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
+ from torch._dynamo import OptimizedModule
+
+ fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
+ fabric.launch()
+
+ model = BoringModel()
+ compile_kwargs = {"mode": "reduce-overhead"}
+ compiled_model = torch.compile(model, **compile_kwargs)
+ torch.compile.reset_mock()
+
+ fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
+
+ assert isinstance(fabric_model._forward_module, OptimizedModule)
+ assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
+ # Assert we called compile again with the same arguments, but on the DDP-wrapped module
+ torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
+
+ assert fabric_model._original_module == model
+ assert fabric_model._forward_module._orig_mod.module == model
+ assert fabric_model.device == fabric.device
+
+ # Smoke-testing forward to ensure we don't get compilation errors
+ for _ in range(3):
+ fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()
+
+
@pytest.mark.parametrize(
("clip_type", "accelerator", "precision"),
[
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index dc7b18095e0c3..704a86a997bc5 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -15,13 +15,18 @@
from copy import deepcopy
from pathlib import Path
from unittest import mock
+from unittest.mock import Mock
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import (
+ _TORCH_GREATER_EQUAL_2_0,
+ _TORCH_GREATER_EQUAL_2_1,
+)
+from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
@@ -344,33 +349,40 @@ def test_setup_with_orig_params_and_multiple_param_groups():
assert not isinstance(layer.weight, FlatParameter)
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, dynamo=True)
-@mock.patch.dict(os.environ, {})
-@pytest.mark.parametrize(
- "compile_after_setup",
- [
- False,
- # https://github.com/pytorch/pytorch/issues/97811
- pytest.param(True, marks=RunIf(min_python="3.9")),
- ],
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
+@mock.patch(
+ "lightning.fabric.wrappers.torch.compile",
+ Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
)
-def test_compile(compile_after_setup):
- """Test that the model can be compiled before and after the model is wrapped in FSDP."""
- model = BoringModel()
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+ """Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
+ from torch._dynamo import OptimizedModule
+
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
- if not compile_after_setup:
- model = torch.compile(model)
+ model = BoringModel()
+ compile_kwargs = {"mode": "reduce-overhead"}
+ compiled_model = torch.compile(model, **compile_kwargs)
+ torch.compile.reset_mock()
+
+ fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
- model = fabric.setup(model)
+ assert isinstance(fabric_model._forward_module, OptimizedModule)
+ assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)
- if compile_after_setup:
- model = torch.compile(model)
+ # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
+ torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
+ assert fabric_model._original_module == model
+ assert fabric_model._forward_module._orig_mod.module == model
+ assert fabric_model.device == fabric.device
+
+ # Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
- model(torch.rand(2, 32, device=fabric.device)).sum().backward()
+ fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@@ -549,3 +561,52 @@ def test_clip_gradients(clip_type, precision):
optimizer.step()
optimizer.zero_grad()
+
+
+# TODO: Support checkpoint consolidation with PyTorch >= 2.2
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+def test_save_sharded_and_consolidate_and_load(tmp_path):
+ """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
+
+ fabric = Fabric(
+ accelerator="cuda",
+ strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"),
+ devices=2,
+ )
+ fabric.launch()
+
+ model = BoringModel()
+ optimizer = torch.optim.Adam(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+
+ # run one iteration to init the state of the optimizer
+ model(torch.rand(1, 32, device=fabric.device)).sum().backward()
+ optimizer.step()
+
+ checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
+ fabric.save(checkpoint_path_sharded, state)
+ assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
+
+ # consolidate the checkpoint to a single file
+ checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
+ if fabric.global_rank == 0:
+ checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
+ torch.save(checkpoint, checkpoint_path_full)
+ fabric.barrier()
+
+ # re-init and load from full checkpoint
+ fabric = Fabric(
+ accelerator="cuda",
+ strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+ devices=2,
+ )
+
+ # Hack: we already called launch() on another Fabric instance above
+ fabric._launched = True
+
+ model = BoringModel()
+ optimizer = torch.optim.Adam(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+ fabric.load(checkpoint_path_full, state)
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index b9e7a9a3430d5..2c860eb45d78d 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -89,18 +89,28 @@ def test_setup_module(ddp_mock, setup_method):
@RunIf(skip_windows=True, dynamo=True)
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
-def test_setup_compiled_module(setup_method):
+@pytest.mark.parametrize("reapply_compile", [True, False, None])
+def test_setup_compiled_module(reapply_compile, setup_method):
"""Test that an `OptimizedModule` can be passed to the setup method."""
from torch._dynamo.eval_frame import OptimizedModule
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
compiled_model = torch.compile(model)
+ assert compiled_model._compile_kwargs is not None
assert isinstance(compiled_model, OptimizedModule)
setup_method = getattr(fabric, setup_method)
- fabric_model = setup_method(compiled_model)
-
- assert fabric_model.module == compiled_model
+ fabric_model = setup_method(compiled_model, _reapply_compile=reapply_compile)
+
+ assert isinstance(fabric_model._forward_module, OptimizedModule)
+ if reapply_compile:
+ # The forward_module got rewrapped into a new OptimizedModule
+ assert fabric_model._forward_module != fabric_model._original_module
+ # The original_module points to the pure module
+ assert fabric_model._original_module is model
+ assert fabric_model._forward_module._orig_mod is model
+ else:
+ assert fabric_model._forward_module is fabric_model._original_module
# Attributes get passed through
assert fabric_model.weight is model.weight
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index 820d9032f5b99..09d731c064d39 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -24,6 +24,7 @@
_FabricDataLoader,
_FabricModule,
_FabricOptimizer,
+ _unwrap_compiled,
_unwrap_objects,
is_wrapped,
)
@@ -593,3 +594,23 @@ def normal_method(self):
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
assert fabric_module.training_step == original_module.training_step
assert fabric_module.validation_step == original_module.validation_step
+
+
+@RunIf(dynamo=True)
+def test_unwrap_compiled():
+ model = torch.nn.Linear(1, 1)
+
+ with mock.patch("lightning.fabric.wrappers", "_TORCH_GREATER_EQUAL_2_0", False):
+ unwrapped, compile_kwargs = _unwrap_compiled(model)
+ assert unwrapped is model
+ assert compile_kwargs is None
+
+ compiled = torch.compile(model, fullgraph=True, dynamic=True, disable=False)
+ assert compiled._compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
+ unwrapped, compile_kwargs = _unwrap_compiled(compiled)
+ assert unwrapped is compiled._orig_mod
+ assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
+
+ del compiled._compile_kwargs
+ with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
+ _unwrap_compiled(compiled)
diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py
new file mode 100644
index 0000000000000..16feb3d3c1014
--- /dev/null
+++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py
@@ -0,0 +1,97 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from argparse import Namespace
+from pathlib import Path
+from unittest import mock
+
+import lightning.fabric
+import pytest
+from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args
+from lightning.fabric.utilities.load import _METADATA_FILENAME
+
+
+@pytest.mark.parametrize(
+ ("args", "expected"),
+ [
+ (["path/to/checkpoint"], {"checkpoint_folder": "path/to/checkpoint", "output_file": None}),
+ (
+ ["path/to/checkpoint", "--output_file", "path/to/output"],
+ {"checkpoint_folder": "path/to/checkpoint", "output_file": "path/to/output"},
+ ),
+ ],
+)
+def test_parse_cli_args(args, expected):
+ with mock.patch("sys.argv", ["any.py", *args]):
+ args = _parse_cli_args()
+ assert vars(args) == expected
+
+
+def test_process_cli_args(tmp_path, caplog, monkeypatch):
+ # PyTorch version < 2.1
+ monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", False)
+ with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
+ SystemExit
+ ):
+ _process_cli_args(Namespace())
+ assert "requires PyTorch >= 2.1." in caplog.text
+ caplog.clear()
+ monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", True)
+
+ # Checkpoint does not exist
+ checkpoint_folder = Path("does/not/exist")
+ with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
+ SystemExit
+ ):
+ _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder))
+ assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text
+ caplog.clear()
+
+ # Checkpoint exists but is not a folder
+ file = tmp_path / "checkpoint_file"
+ file.touch()
+ with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
+ SystemExit
+ ):
+ _process_cli_args(Namespace(checkpoint_folder=file))
+ assert "checkpoint path must be a folder" in caplog.text
+ caplog.clear()
+
+ # Checkpoint exists but is not an FSDP checkpoint
+ folder = tmp_path / "checkpoint_folder"
+ folder.mkdir()
+ with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
+ SystemExit
+ ):
+ _process_cli_args(Namespace(checkpoint_folder=folder))
+ assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text
+ caplog.clear()
+
+    # Checkpoint is an FSDP folder, output file not specified
+ (folder / _METADATA_FILENAME).touch()
+ config = _process_cli_args(Namespace(checkpoint_folder=folder, output_file=None))
+ assert vars(config) == {
+ "checkpoint_folder": folder,
+ "output_file": folder.with_suffix(folder.suffix + ".consolidated"),
+ }
+
+    # Checkpoint is an FSDP folder, output file already exists
+    file = tmp_path / "output_file"
+ file.touch()
+ with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
+ SystemExit
+ ):
+ _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file))
+ assert "path for the converted checkpoint already exists" in caplog.text
+ caplog.clear()
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index ad32c936a1c4d..4badff47403eb 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -1,5 +1,6 @@
import functools
import os
+import sys
from functools import partial
from pathlib import Path
from unittest import mock
@@ -17,6 +18,7 @@
_sync_ddp,
is_shared_filesystem,
)
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from tests_fabric.helpers.runif import RunIf
@@ -118,6 +120,11 @@ def test_collective_operations(devices, process):
spawn_launch(process, devices)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
# In the non-distributed case, every location is interpreted as 'shared'
diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py
index 3dccf2258e74b..eb534bf1cdca1 100644
--- a/tests/tests_fabric/utilities/test_load.py
+++ b/tests/tests_fabric/utilities/test_load.py
@@ -14,7 +14,13 @@
import pytest
import torch
import torch.nn as nn
-from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _move_state_into, _NotYetLoadedTensor
+from lightning.fabric.utilities.load import (
+ _lazy_load,
+ _materialize_tensors,
+ _move_state_into,
+ _NotYetLoadedTensor,
+ _unflatten_dict,
+)
from tests_fabric.helpers.runif import RunIf
@@ -139,3 +145,24 @@ def load_state_dict(self, state_dict):
assert source == {}
assert destination["cocofruit"] == 2
assert destination["banana"].count == 100
+
+
+def test_unflatten_dict():
+ assert _unflatten_dict({}, {}) == {}
+
+ tensor0 = torch.rand(2, 2)
+ tensor1 = torch.tensor(3.0)
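+    # each flat key expands into the nested structure described by the key map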
+ data = {
+ "model.layer.weight": tensor0,
+ "optimizer.state.layer.weight.exp_avg": {"test": tensor1},
+ "optimizer.param_groups": "param_groups",
+ }
+ key_map = {
+ "model.layer.weight": ("model", "layer.weight"),
+ "optimizer.state.layer.weight.exp_avg": ("optimizer", "state", "layer.weight", "exp_avg"),
+ "optimizer.param_groups": ("optimizer", "param_groups"),
+ }
+ assert _unflatten_dict(data, key_map) == {
+ "model": {"layer.weight": tensor0},
+ "optimizer": {"state": {"layer.weight": {"exp_avg": {"test": tensor1}}}, "param_groups": "param_groups"},
+ }
diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py
index 9739540af7f18..ab7d9f474383d 100644
--- a/tests/tests_fabric/utilities/test_spike.py
+++ b/tests/tests_fabric/utilities/test_spike.py
@@ -4,6 +4,7 @@
import pytest
import torch
from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
@@ -28,6 +29,11 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py
index 2a66ce25e9861..f2c3de30a325b 100644
--- a/tests/tests_fabric/utilities/test_throughput.py
+++ b/tests/tests_fabric/utilities/test_throughput.py
@@ -49,8 +49,8 @@ def test_get_available_flops(xla_available):
from torch_xla.experimental import tpu
assert isinstance(tpu, Mock)
- tpu.get_tpu_env.return_value = {"TYPE": "V4"}
+ tpu.get_tpu_env.return_value = {"TYPE": "V4"}
flops = get_available_flops(torch.device("xla"), torch.bfloat16)
assert flops == 275e12
@@ -58,6 +58,10 @@ def test_get_available_flops(xla_available):
with pytest.warns(match="not found for TPU 'V1'"):
assert get_available_flops(torch.device("xla"), torch.bfloat16) is None
+ tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
+ flops = get_available_flops(torch.device("xla"), torch.bfloat16)
+ assert flops == 123e12
+
tpu.reset_mock()
diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py
index f4d0c946cefa0..b3bc54bd6e836 100644
--- a/tests/tests_pytorch/callbacks/test_spike.py
+++ b/tests/tests_pytorch/callbacks/test_spike.py
@@ -3,6 +3,7 @@
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.spike import SpikeDetection
@@ -46,6 +47,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py
index 2f4b281836510..5f6b9d6328d55 100644
--- a/tests/tests_pytorch/loops/test_prediction_loop.py
+++ b/tests/tests_pytorch/loops/test_prediction_loop.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
+import sys
import pytest
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
@@ -50,6 +52,11 @@ def predict_step(self, batch, batch_idx):
assert trainer.predict_loop.predictions == []
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py
index d51b139d6118d..5608c02026072 100644
--- a/tests/tests_pytorch/models/test_amp.py
+++ b/tests/tests_pytorch/models/test_amp.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import sys
from unittest import mock
import pytest
import torch
from lightning.fabric.plugins.environments import SLURMEnvironment
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader
@@ -53,7 +55,16 @@ def _assert_autocast_enabled(self):
[
("single_device", "16-mixed", 1),
("single_device", "bf16-mixed", 1),
- ("ddp_spawn", "16-mixed", 2),
+ pytest.param(
+ "ddp_spawn",
+ "16-mixed",
+ 2,
+ marks=pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+ ),
+ ),
pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
],
)
diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py
index 578619ece75eb..9d299da460770 100644
--- a/tests/tests_pytorch/serve/test_servable_module_validator.py
+++ b/tests/tests_pytorch/serve/test_servable_module_validator.py
@@ -1,7 +1,9 @@
+import sys
from typing import Dict
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
@@ -36,6 +38,11 @@ def test_servable_module_validator():
callback.on_train_start(Trainer(accelerator="cpu"), model)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.flaky(reruns=3)
def test_servable_module_validator_with_trainer(tmpdir):
callback = ServableModuleValidator()
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index beab1e20f46b3..d70a4776b81bd 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import sys
from multiprocessing import Process
from unittest import mock
from unittest.mock import ANY, Mock, call, patch
@@ -19,6 +20,7 @@
import pytest
import torch
from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import DDPStrategy
@@ -194,6 +196,11 @@ def on_fit_start(self) -> None:
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
def test_memory_sharing_disabled():
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
@@ -214,6 +221,11 @@ def test_check_for_missing_main_guard():
launcher.launch(function=Mock())
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
def test_fit_twice_raises():
model = BoringModel()
trainer = Trainer(
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 7be116456154a..bcd46ebde7526 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -18,6 +18,7 @@
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
+from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -25,6 +26,7 @@
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
@@ -991,3 +993,40 @@ def _run_setup_assertions(empty_init, expected_device):
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
+
+
+# TODO: Support checkpoint consolidation with PyTorch >= 2.2
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+def test_save_sharded_and_consolidate_and_load(tmp_path):
+ """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
+
+ model = BoringModel()
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="cuda",
+ devices=2,
+ strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"),
+ max_steps=3,
+ )
+ trainer.fit(model)
+
+ checkpoint_path_sharded = trainer.strategy.broadcast(str(trainer.checkpoint_callback.best_model_path))
+ assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
+
+ # consolidate the checkpoint to a single file
+ checkpoint_path_full = trainer.strategy.broadcast(str(tmp_path / "checkpoint_full.ckpt"))
+ if trainer.global_rank == 0:
+ checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
+ checkpoint = _format_checkpoint(checkpoint)
+ torch.save(checkpoint, checkpoint_path_full)
+ trainer.strategy.barrier()
+
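+    # the consolidated single-file checkpoint can now be loaded with a non-FSDP strategy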
+ model = BoringModel()
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="cuda",
+ devices=2,
+ strategy="ddp",
+ max_steps=4,
+ )
+ trainer.fit(model, ckpt_path=checkpoint_path_full)
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 5f77604a7a6d4..e9684657dd3c5 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import sys
from re import escape
from typing import Sized
from unittest import mock
@@ -19,6 +20,7 @@
import lightning.fabric
import pytest
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
@@ -123,6 +125,11 @@ def on_train_end(self):
self.ctx.__exit__(None, None, None)
+@pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+)
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
"""Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index bc2abdacb0e04..bb17220610366 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -15,6 +15,7 @@
import collections
import itertools
+import sys
from re import escape
from unittest import mock
from unittest.mock import call
@@ -22,6 +23,7 @@
import numpy as np
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer, callbacks
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.core.module import LightningModule
@@ -346,7 +348,15 @@ def validation_step(self, batch, batch_idx):
("devices", "accelerator"),
[
(1, "cpu"),
- (2, "cpu"),
+ pytest.param(
+ 2,
+ "cpu",
+ marks=pytest.mark.xfail(
+ # https://github.com/pytorch/pytorch/issues/116056
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
+ reason="Windows + DDP issue in PyTorch 2.2",
+ ),
+ ),
pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
],
)
diff --git a/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py b/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py
new file mode 100644
index 0000000000000..5bbaa4181ecdc
--- /dev/null
+++ b/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py
@@ -0,0 +1,33 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint
+
+
+def test_format_checkpoint():
+ # The 'model' key gets renamed to 'state_dict'
+ model = Mock()
+ checkpoint = {"model": model}
+ assert _format_checkpoint(checkpoint) == {"state_dict": model}
+
+ # Optimizer states with keys 'optimizer_0', 'optimizer_1', etc. get converted to a list
+ optimizer0 = Mock()
+ optimizer1 = Mock()
+ checkpoint = {"model": model, "optimizer_1": optimizer1, "optimizer_0": optimizer0, "optimizer_abc": "other"}
+ assert _format_checkpoint(checkpoint) == {
+ "state_dict": model,
+ "optimizer_states": [optimizer0, optimizer1],
+ "optimizer_abc": "other",
+ }
diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py
index 248a40cb38989..78a63a7e9d2a7 100644
--- a/tests/tests_pytorch/utilities/test_model_helpers.py
+++ b/tests/tests_pytorch/utilities/test_model_helpers.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import pytest
@@ -66,10 +67,12 @@ def cmethod(cls):
def test_restricted_classmethod():
+ restricted_method = RestrictedClass().restricted_cmethod # no exception when getting restricted method
+
with pytest.raises(TypeError, match="cannot be called on an instance"):
- RestrictedClass().restricted_cmethod()
+ restricted_method()
- RestrictedClass.restricted_cmethod() # no exception
+ _ = inspect.getmembers(RestrictedClass()) # no exception on inspecting instance
def test_module_mode():