Skip to content

Commit 68b74f3

Browse files
authored
Add GCP support for directory resolution in `_resolve_dir` (#659)
* add support for gcp cloud providers in studio resolution
* feat(streaming): add support for GCS folders resolution in _resolve_dir
* feat(streaming): add GCS connections resolution in _resolve_dir
* fix(resolver): improve error message for invalid dir_path type in _resolve_dir
* feat(streaming): add GCS connections and folders resolvers in test cases
1 parent 268f3c7 commit 68b74f3

File tree

2 files changed

+130
-3
lines changed

2 files changed

+130
-3
lines changed

src/litdata/streaming/resolver.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
from contextlib import suppress
2020
from dataclasses import dataclass
21+
from enum import Enum
2122
from pathlib import Path
2223
from time import sleep
2324
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
@@ -38,6 +39,11 @@ class Dir:
3839
url: Optional[str] = None
3940

4041

42+
class CloudProvider(str, Enum):
    """Cloud providers that studio path resolution can target.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (for example ``CloudProvider.AWS == "aws"``).
    """

    AWS = "aws"
    GCP = "gcp"
45+
46+
4147
def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
4248
if isinstance(dir_path, Dir):
4349
return Dir(path=str(dir_path.path) if dir_path.path else None, url=str(dir_path.url) if dir_path.url else None)
@@ -46,7 +52,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
4652
return Dir()
4753

4854
if not isinstance(dir_path, (str, Path)):
49-
raise ValueError(f"`dir_path` must be either a string, Path, or Dir, got: {dir_path}")
55+
raise ValueError(f"`dir_path` must be either a string, Path, or Dir, got: {type(dir_path)}")
5056

5157
if isinstance(dir_path, str):
5258
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
@@ -61,6 +67,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
6167
dir_path_absolute = str(Path(dir_path).absolute().resolve())
6268
dir_path = str(dir_path) # Convert to string if it was a Path object
6369

70+
# Handle special teamspace paths
6471
if dir_path_absolute.startswith("/teamspace/studios/this_studio"):
6572
return Dir(path=dir_path_absolute, url=None)
6673

@@ -73,9 +80,15 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
7380
if dir_path_absolute.startswith("/teamspace/s3_connections") and len(dir_path_absolute.split("/")) > 3:
7481
return _resolve_s3_connections(dir_path_absolute)
7582

83+
if dir_path_absolute.startswith("/teamspace/gcs_connections") and len(dir_path_absolute.split("/")) > 3:
84+
return _resolve_gcs_connections(dir_path_absolute)
85+
7686
if dir_path_absolute.startswith("/teamspace/s3_folders") and len(dir_path_absolute.split("/")) > 3:
7787
return _resolve_s3_folders(dir_path_absolute)
7888

89+
if dir_path_absolute.startswith("/teamspace/gcs_folders") and len(dir_path_absolute.split("/")) > 3:
90+
return _resolve_gcs_folders(dir_path_absolute)
91+
7992
if dir_path_absolute.startswith("/teamspace/datasets") and len(dir_path_absolute.split("/")) > 3:
8093
return _resolve_datasets(dir_path_absolute)
8194

@@ -104,6 +117,7 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
104117
# Get the ids from env variables
105118
cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
106119
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
120+
provider = os.getenv("LIGHTNING_CLOUD_PROVIDER", CloudProvider.AWS)
107121

108122
if cluster_id is None:
109123
raise RuntimeError("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables.")
@@ -126,12 +140,19 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
126140
f"We didn't find a matching cluster associated with the id {target_cloud_space[0].cluster_id}."
127141
)
128142

129-
bucket_name = target_cluster[0].spec.aws_v1.bucket_name
143+
if provider == CloudProvider.AWS:
144+
bucket_name = target_cluster[0].spec.aws_v1.bucket_name
145+
scheme = "s3"
146+
elif provider == CloudProvider.GCP:
147+
bucket_name = target_cluster[0].spec.google_cloud_v1.bucket_name
148+
scheme = "gs"
149+
else:
150+
raise ValueError(f"Unsupported cloud provider: {provider}. Supported providers are AWS and GCP.")
130151

131152
return Dir(
132153
path=dir_path,
133154
url=os.path.join(
134-
f"s3://{bucket_name}/projects/{project_id}/cloudspaces/{target_cloud_space[0].id}/code/content",
155+
f"{scheme}://{bucket_name}/projects/{project_id}/cloudspaces/{target_cloud_space[0].id}/code/content",
135156
*dir_path.split("/")[4:],
136157
),
137158
)
@@ -159,6 +180,28 @@ def _resolve_s3_connections(dir_path: str) -> Dir:
159180
return Dir(path=dir_path, url=os.path.join(data_connection[0].aws.source, *dir_path.split("/")[4:]))
160181

161182

183+
def _resolve_gcs_connections(dir_path: str) -> Dir:
    """Resolve a ``/teamspace/gcs_connections/<name>/...`` path to a GCS URL.

    The data connection name is the component at index 3 of ``dir_path`` once
    split on ``/``. It is looked up among the project's data connections, and
    any components after it are appended to the connection's ``gcp.source``.

    Args:
        dir_path: Path whose fourth ``/``-separated component names the data connection.

    Returns:
        A ``Dir`` whose ``path`` is ``dir_path`` and whose ``url`` points inside the connection's source.

    Raises:
        RuntimeError: If ``LIGHTNING_CLOUD_PROJECT_ID`` is not set in the environment.
        ValueError: If no data connection carries the requested name.
    """
    from lightning_sdk.lightning_cloud.rest_client import LightningClient

    client = LightningClient(max_tries=2)

    # The project id is only available through the environment.
    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
    if project_id is None:
        raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

    path_parts = dir_path.split("/")
    target_name = path_parts[3]
    sub_paths = path_parts[4:]

    response = client.data_connection_service_list_data_connections(project_id)

    # The first connection with a matching name wins.
    for connection in response.data_connections:
        if connection.name == target_name:
            return Dir(path=dir_path, url=os.path.join(connection.gcp.source, *sub_paths))

    raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")
203+
204+
162205
def _resolve_s3_folders(dir_path: str) -> Dir:
163206
from lightning_sdk.lightning_cloud.rest_client import LightningClient
164207

@@ -181,6 +224,28 @@ def _resolve_s3_folders(dir_path: str) -> Dir:
181224
return Dir(path=dir_path, url=os.path.join(data_connection[0].s3_folder.source, *dir_path.split("/")[4:]))
182225

183226

227+
def _resolve_gcs_folders(dir_path: str) -> Dir:
    """Resolve a ``/teamspace/gcs_folders/<name>/...`` path to a GCS URL.

    The folder name is the component at index 3 of ``dir_path`` once split on
    ``/``. It is looked up among the project's data connections, and any
    components after it are appended to the connection's ``gcs_folder.source``.

    Args:
        dir_path: Path whose fourth ``/``-separated component names the data connection.

    Returns:
        A ``Dir`` whose ``path`` is ``dir_path`` and whose ``url`` points inside the folder's source.

    Raises:
        RuntimeError: If ``LIGHTNING_CLOUD_PROJECT_ID`` is not set in the environment.
        ValueError: If no data connection carries the requested name.
    """
    from lightning_sdk.lightning_cloud.rest_client import LightningClient

    client = LightningClient(max_tries=2)

    # The project id is only available through the environment.
    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
    if project_id is None:
        raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

    path_parts = dir_path.split("/")
    target_name = path_parts[3]
    sub_paths = path_parts[4:]

    response = client.data_connection_service_list_data_connections(project_id)

    # The first connection with a matching name wins.
    for connection in response.data_connections:
        if connection.name == target_name:
            return Dir(path=dir_path, url=os.path.join(connection.gcs_folder.source, *sub_paths))

    raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")
247+
248+
184249
def _resolve_datasets(dir_path: str) -> Dir:
185250
from lightning_sdk.lightning_cloud.rest_client import LightningClient
186251

tests/streaming/test_resolver.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,65 @@ def test_resolve_time_template():
403403
assert resolver._resolve_time_template(path_1) == f"/logs/log_{curr_year}-{curr_month:02d}"
404404
assert resolver._resolve_time_template(path_2) == path_2
405405
assert resolver._resolve_time_template(path_3) == f"/logs/log_{curr_year}-{curr_month:02d}/important"
406+
407+
408+
@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
def test_src_resolver_gcs_connections(monkeypatch, lightning_cloud_mock):
    """Check that ``/teamspace/gcs_connections`` paths resolve to the connection's GCS URL."""
    connection_path = "/teamspace/gcs_connections/my_dataset"

    auth = login.Auth()
    auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")

    # This test has not set a project id yet, so resolution is expected to fail.
    missing_project_msg = "`LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables."
    with pytest.raises(RuntimeError, match=missing_project_msg):
        resolver._resolve_dir(connection_path)

    monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")

    # Fake client whose listing returns a single GCP-backed connection.
    gcs_connection = V1DataConnection(name="my_dataset", gcp=mock.MagicMock(source="gs://my-gcs-bucket"))
    fake_client = mock.MagicMock()
    fake_client.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse(
        data_connections=[gcs_connection],
    )
    lightning_cloud_mock.rest_client.LightningClient = mock.MagicMock(return_value=fake_client)

    expected_urls = (
        (connection_path, "gs://my-gcs-bucket"),
        (connection_path + "/train", "gs://my-gcs-bucket/train"),
    )
    for path, expected_url in expected_urls:
        assert resolver._resolve_dir(path).url == expected_url

    # With no data connections listed, the name lookup must fail.
    fake_client.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse(
        data_connections=[],
    )
    with pytest.raises(ValueError, match="name `my_dataset`"):
        resolver._resolve_dir(connection_path)

    auth.clear()
442+
443+
444+
@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
def test_src_resolver_gcs_folders(monkeypatch, lightning_cloud_mock):
    """Check that ``/teamspace/gcs_folders`` paths resolve to the folder's GCS URL."""
    folder_path = "/teamspace/gcs_folders/debug_folder"
    bucket_url = "gs://my-gcs-bucket"

    auth = login.Auth()
    auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")

    monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")

    # Fake client whose listing returns a single GCS-folder connection.
    folder_connection = V1DataConnection(name="debug_folder", gcs_folder=mock.MagicMock(source="gs://my-gcs-bucket"))
    fake_client = mock.MagicMock()
    fake_client.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse(
        data_connections=[folder_connection],
    )
    lightning_cloud_mock.rest_client.LightningClient = mock.MagicMock(return_value=fake_client)

    # Components after the folder name are appended to the source URL.
    for suffix in ("", "/a/b/c"):
        assert resolver._resolve_dir(folder_path + suffix).url == bucket_url + suffix

    auth.clear()

0 commit comments

Comments
 (0)