3939
4040class Downloader (ABC ):
4141 def __init__ (
42- self , remote_dir : str , cache_dir : str , chunks : List [Dict [str , Any ]], storage_options : Optional [Dict ] = {}
42+ self ,
43+ remote_dir : str ,
44+ cache_dir : str ,
45+ chunks : List [Dict [str , Any ]],
46+ storage_options : Optional [Dict ] = {},
47+ ** kwargs : Any ,
4348 ):
4449 self ._remote_dir = remote_dir
4550 self ._cache_dir = cache_dir
@@ -77,13 +82,20 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
7782
7883class S3Downloader (Downloader ):
def __init__(
    self,
    remote_dir: str,
    cache_dir: str,
    chunks: List[Dict[str, Any]],
    storage_options: Optional[Dict] = {},
    **kwargs: Any,
):
    """Initialise the S3 downloader and pick between s5cmd and the Python S3 client.

    Args:
        remote_dir: Passed through to the base ``Downloader`` initializer.
        cache_dir: Passed through to the base ``Downloader`` initializer.
        chunks: Passed through to the base ``Downloader`` initializer.
        storage_options: Passed through to the base ``Downloader`` initializer.
        **kwargs: Only ``session_options`` is read by this method; a missing or
            ``None`` value is treated as an empty dict.
    """
    super().__init__(remote_dir, cache_dir, chunks, storage_options)

    # s5cmd is considered available when running it through the shell exits with status 0.
    self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0

    # ``kwargs.get("session_options", {})`` would still yield None when the caller passes
    # ``session_options=None`` explicitly, so normalise both "missing" and "None" to a dict.
    session_options = kwargs.get("session_options")
    self.session_options = {} if session_options is None else session_options

    # The Python client is only built when s5cmd cannot be used or has been disabled.
    if not self._s5cmd_available or _DISABLE_S5CMD:
        # NOTE(review): self._storage_options is presumably set by the base initializer - confirm.
        self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)
8799
88100 def download_file (self , remote_filepath : str , local_filepath : str ) -> None :
89101 obj = parse .urlparse (remote_filepath )
@@ -156,7 +168,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
156168
157169class GCPDownloader (Downloader ):
158170 def __init__ (
159- self , remote_dir : str , cache_dir : str , chunks : List [Dict [str , Any ]], storage_options : Optional [Dict ] = {}
171+ self ,
172+ remote_dir : str ,
173+ cache_dir : str ,
174+ chunks : List [Dict [str , Any ]],
175+ storage_options : Optional [Dict ] = {},
176+ ** kwargs : Any ,
160177 ):
161178 if not _GOOGLE_STORAGE_AVAILABLE :
162179 raise ModuleNotFoundError (str (_GOOGLE_STORAGE_AVAILABLE ))
@@ -194,7 +211,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
194211
195212class AzureDownloader (Downloader ):
196213 def __init__ (
197- self , remote_dir : str , cache_dir : str , chunks : List [Dict [str , Any ]], storage_options : Optional [Dict ] = {}
214+ self ,
215+ remote_dir : str ,
216+ cache_dir : str ,
217+ chunks : List [Dict [str , Any ]],
218+ storage_options : Optional [Dict ] = {},
219+ ** kwargs : Any ,
198220 ):
199221 if not _AZURE_STORAGE_AVAILABLE :
200222 raise ModuleNotFoundError (str (_AZURE_STORAGE_AVAILABLE ))
@@ -247,7 +269,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
247269
248270class HFDownloader (Downloader ):
249271 def __init__ (
250- self , remote_dir : str , cache_dir : str , chunks : List [Dict [str , Any ]], storage_options : Optional [Dict ] = {}
272+ self ,
273+ remote_dir : str ,
274+ cache_dir : str ,
275+ chunks : List [Dict [str , Any ]],
276+ storage_options : Optional [Dict ] = {},
277+ ** kwargs : Any ,
251278 ):
252279 if not _HF_HUB_AVAILABLE :
253280 raise ModuleNotFoundError (
@@ -331,7 +358,11 @@ def unregister_downloader(prefix: str) -> None:
331358
332359
def get_downloader(
    remote_dir: str,
    cache_dir: str,
    chunks: List[Dict[str, Any]],
    storage_options: Optional[Dict] = None,
    session_options: Optional[Dict] = None,
) -> Downloader:
    """Get the appropriate downloader instance based on the remote directory prefix.

    The registered prefixes in ``_DOWNLOADERS`` are checked in iteration order and the
    first one that ``str(remote_dir)`` starts with selects the downloader class.

    Args:
        remote_dir (str): The remote directory; its prefix selects the downloader class.
        cache_dir (str): The local cache directory.
        chunks (List[Dict[str, Any]]): List of chunks to be managed by the downloader.
        storage_options (Optional[Dict], optional): Additional storage options.
            ``None`` (the default) is treated as an empty dict.
        session_options (Optional[Dict], optional): Additional S3 session options.
            ``None`` (the default) is treated as an empty dict. Forwarded by keyword
            to the matched registered downloader; not passed to ``LocalDownloader``.

    Returns:
        Downloader: An instance of the appropriate downloader class.
    """
    # Create fresh dicts per call: a literal ``{}`` default would be one shared object
    # handed to every downloader built without explicit options.
    storage_options = {} if storage_options is None else storage_options
    session_options = {} if session_options is None else session_options

    remote_prefix = str(remote_dir)
    for prefix, downloader_cls in _DOWNLOADERS.items():
        if remote_prefix.startswith(prefix):
            return downloader_cls(remote_dir, cache_dir, chunks, storage_options, session_options=session_options)

    # Default to LocalDownloader if no prefix is matched
    return LocalDownloader(remote_dir, cache_dir, chunks, storage_options)
0 commit comments