1818import sys
1919from contextlib import suppress
2020from dataclasses import dataclass
21+ from enum import Enum
2122from pathlib import Path
2223from time import sleep
2324from typing import TYPE_CHECKING , Any , Literal , Optional , Union
@@ -38,6 +39,11 @@ class Dir:
3839 url : Optional [str ] = None
3940
4041
class CloudProvider(str, Enum):
    """Cloud providers this module knows how to build bucket URLs for.

    Members also subclass ``str``, so a member compares equal to its raw
    string value (for example a value read from an environment variable).
    """

    # Amazon Web Services
    AWS = "aws"
    # Google Cloud Platform
    GCP = "gcp"
45+
46+
4147def _resolve_dir (dir_path : Optional [Union [str , Path , Dir ]]) -> Dir :
4248 if isinstance (dir_path , Dir ):
4349 return Dir (path = str (dir_path .path ) if dir_path .path else None , url = str (dir_path .url ) if dir_path .url else None )
@@ -46,7 +52,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
4652 return Dir ()
4753
4854 if not isinstance (dir_path , (str , Path )):
49- raise ValueError (f"`dir_path` must be either a string, Path, or Dir, got: { dir_path } " )
55+ raise ValueError (f"`dir_path` must be either a string, Path, or Dir, got: { type ( dir_path ) } " )
5056
5157 if isinstance (dir_path , str ):
5258 cloud_prefixes = ("s3://" , "gs://" , "azure://" , "hf://" )
@@ -61,6 +67,7 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
6167 dir_path_absolute = str (Path (dir_path ).absolute ().resolve ())
6268 dir_path = str (dir_path ) # Convert to string if it was a Path object
6369
70+ # Handle special teamspace paths
6471 if dir_path_absolute .startswith ("/teamspace/studios/this_studio" ):
6572 return Dir (path = dir_path_absolute , url = None )
6673
@@ -73,9 +80,15 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
7380 if dir_path_absolute .startswith ("/teamspace/s3_connections" ) and len (dir_path_absolute .split ("/" )) > 3 :
7481 return _resolve_s3_connections (dir_path_absolute )
7582
83+ if dir_path_absolute .startswith ("/teamspace/gcs_connections" ) and len (dir_path_absolute .split ("/" )) > 3 :
84+ return _resolve_gcs_connections (dir_path_absolute )
85+
7686 if dir_path_absolute .startswith ("/teamspace/s3_folders" ) and len (dir_path_absolute .split ("/" )) > 3 :
7787 return _resolve_s3_folders (dir_path_absolute )
7888
89+ if dir_path_absolute .startswith ("/teamspace/gcs_folders" ) and len (dir_path_absolute .split ("/" )) > 3 :
90+ return _resolve_gcs_folders (dir_path_absolute )
91+
7992 if dir_path_absolute .startswith ("/teamspace/datasets" ) and len (dir_path_absolute .split ("/" )) > 3 :
8093 return _resolve_datasets (dir_path_absolute )
8194
@@ -104,6 +117,7 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
104117 # Get the ids from env variables
105118 cluster_id = os .getenv ("LIGHTNING_CLUSTER_ID" , None )
106119 project_id = os .getenv ("LIGHTNING_CLOUD_PROJECT_ID" , None )
120+ provider = os .getenv ("LIGHTNING_CLOUD_PROVIDER" , CloudProvider .AWS )
107121
108122 if cluster_id is None :
109123 raise RuntimeError ("The `LIGHTNING_CLUSTER_ID` couldn't be found from the environment variables." )
@@ -126,12 +140,19 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option
126140 f"We didn't find a matching cluster associated with the id { target_cloud_space [0 ].cluster_id } ."
127141 )
128142
129- bucket_name = target_cluster [0 ].spec .aws_v1 .bucket_name
143+ if provider == CloudProvider .AWS :
144+ bucket_name = target_cluster [0 ].spec .aws_v1 .bucket_name
145+ scheme = "s3"
146+ elif provider == CloudProvider .GCP :
147+ bucket_name = target_cluster [0 ].spec .google_cloud_v1 .bucket_name
148+ scheme = "gs"
149+ else :
150+ raise ValueError (f"Unsupported cloud provider: { provider } . Supported providers are AWS and GCP." )
130151
131152 return Dir (
132153 path = dir_path ,
133154 url = os .path .join (
134- f"s3 ://{ bucket_name } /projects/{ project_id } /cloudspaces/{ target_cloud_space [0 ].id } /code/content" ,
155+ f"{ scheme } ://{ bucket_name } /projects/{ project_id } /cloudspaces/{ target_cloud_space [0 ].id } /code/content" ,
135156 * dir_path .split ("/" )[4 :],
136157 ),
137158 )
@@ -159,6 +180,28 @@ def _resolve_s3_connections(dir_path: str) -> Dir:
159180 return Dir (path = dir_path , url = os .path .join (data_connection [0 ].aws .source , * dir_path .split ("/" )[4 :]))
160181
161182
def _resolve_gcs_connections(dir_path: str) -> Dir:
    """Resolve a GCS connection path to a ``Dir`` carrying its remote URL.

    The component at index 3 of ``dir_path`` (split on ``/``) is used as the
    data connection name; every component after it is appended to the
    connection's ``gcp.source``.

    Args:
        dir_path: Path whose 4th ``/``-separated component names the data
            connection (callers pass ``/teamspace/gcs_connections/<name>/...``).

    Returns:
        A ``Dir`` with ``path`` set to ``dir_path`` and ``url`` pointing at the
        matching location under the connection's source.

    Raises:
        RuntimeError: If ``LIGHTNING_CLOUD_PROJECT_ID`` is not set.
        ValueError: If no data connection in the project has the given name.
    """
    from lightning_sdk.lightning_cloud.rest_client import LightningClient

    api = LightningClient(max_tries=2)

    # The project id is only available through the environment.
    project_id = os.environ.get("LIGHTNING_CLOUD_PROJECT_ID")
    if project_id is None:
        raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

    # Split once: index 3 is the connection name, the rest is the sub-path.
    segments = dir_path.split("/")
    target_name = segments[3]

    listing = api.data_connection_service_list_data_connections(project_id)
    matches = [conn for conn in listing.data_connections if conn.name == target_name]

    if not matches:
        raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")

    # When several connections share the name, the first one wins.
    remote_root = matches[0].gcp.source
    return Dir(path=dir_path, url=os.path.join(remote_root, *segments[4:]))
203+
204+
162205def _resolve_s3_folders (dir_path : str ) -> Dir :
163206 from lightning_sdk .lightning_cloud .rest_client import LightningClient
164207
@@ -181,6 +224,28 @@ def _resolve_s3_folders(dir_path: str) -> Dir:
181224 return Dir (path = dir_path , url = os .path .join (data_connection [0 ].s3_folder .source , * dir_path .split ("/" )[4 :]))
182225
183226
def _resolve_gcs_folders(dir_path: str) -> Dir:
    """Resolve a GCS folder path to a ``Dir`` carrying its remote URL.

    The component at index 3 of ``dir_path`` (split on ``/``) is used as the
    data connection name; every component after it is appended to the
    connection's ``gcs_folder.source``.

    Args:
        dir_path: Path whose 4th ``/``-separated component names the data
            connection (callers pass ``/teamspace/gcs_folders/<name>/...``).

    Returns:
        A ``Dir`` with ``path`` set to ``dir_path`` and ``url`` pointing at the
        matching location under the folder's source.

    Raises:
        RuntimeError: If ``LIGHTNING_CLOUD_PROJECT_ID`` is not set.
        ValueError: If no data connection in the project has the given name.
    """
    from lightning_sdk.lightning_cloud.rest_client import LightningClient

    api = LightningClient(max_tries=2)

    # The project id is only available through the environment.
    project_id = os.environ.get("LIGHTNING_CLOUD_PROJECT_ID")
    if project_id is None:
        raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")

    # Split once: index 3 is the folder connection name, the rest is the sub-path.
    segments = dir_path.split("/")
    target_name = segments[3]

    listing = api.data_connection_service_list_data_connections(project_id)
    matches = [conn for conn in listing.data_connections if conn.name == target_name]

    if not matches:
        raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")

    # When several connections share the name, the first one wins.
    remote_root = matches[0].gcs_folder.source
    return Dir(path=dir_path, url=os.path.join(remote_root, *segments[4:]))
247+
248+
184249def _resolve_datasets (dir_path : str ) -> Dir :
185250 from lightning_sdk .lightning_cloud .rest_client import LightningClient
186251
0 commit comments