diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index a69493ab1c875..6d013415fd3a2 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -20,8 +20,6 @@ import pytest import torch import torch.nn as nn -from torch.utils.data import DataLoader - from lightning.fabric import Fabric from lightning.fabric.plugins import FSDPPrecision from lightning.fabric.strategies import FSDPStrategy @@ -34,6 +32,7 @@ from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType from torch.distributed.fsdp.wrap import always_wrap_policy, wrap from torch.nn import Parameter +from torch.utils.data import DataLoader from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf