
Commit 0f4f9b2

fix: boto3 session options (#604)
* feat: Add s3_session_options to various components for enhanced S3 configuration
* update
* meow
* feat: Update S3 session options to a unified parameter across components
1 parent 7194311 commit 0f4f9b2

File tree

8 files changed: +90 -18 lines


README.md

Lines changed: 6 additions & 1 deletion
@@ -431,7 +431,12 @@ aws_storage_options={
     "aws_access_key_id": os.environ['AWS_ACCESS_KEY_ID'],
     "aws_secret_access_key": os.environ['AWS_SECRET_ACCESS_KEY'],
 }
-dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
+# You can also pass the session options. (for boto3 only)
+aws_session_options = {
+    "profile_name": os.environ['AWS_PROFILE_NAME'],  # Required only for custom profiles
+    "region_name": os.environ['AWS_REGION_NAME'],  # Required only for custom regions
+}
+dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options, session_options=aws_session_options)
 
 
 # Read data from GCS
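
Put together, the updated README example corresponds to this runnable snippet (a minimal sketch; the bucket and environment variable names are the README's own placeholders):

import os

import litdata as ld

# Credential options forwarded to the S3 client.
aws_storage_options = {
    "aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
    "aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
}

# Session options forwarded to boto3.Session (boto3 only).
aws_session_options = {
    "profile_name": os.environ["AWS_PROFILE_NAME"],  # required only for custom profiles
    "region_name": os.environ["AWS_REGION_NAME"],  # required only for custom regions
}

dataset = ld.StreamingDataset(
    "s3://my-bucket/my-data",
    storage_options=aws_storage_options,
    session_options=aws_session_options,
)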

src/litdata/streaming/cache.py

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,7 @@ def __init__(
         serializers: Optional[Dict[str, Serializer]] = None,
         writer_chunk_index: Optional[int] = None,
         storage_options: Optional[Dict] = {},
+        session_options: Optional[Dict] = {},
         max_pre_download: int = 2,
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
@@ -64,6 +65,7 @@ def __init__(
             serializers: Provide your own serializers.
             writer_chunk_index: The index of the chunk to start from when writing.
             storage_options: Additional connection options for accessing storage services.
+            session_options: Additional options for the S3 session.
             max_pre_download: Maximum number of chunks that can be pre-downloaded while filling up the cache.
 
         """
@@ -92,6 +94,7 @@ def __init__(
             item_loader=item_loader,
             serializers=serializers,
             storage_options=storage_options,
+            session_options=session_options,
             max_pre_download=max_pre_download,
         )
         self._is_done = False

src/litdata/streaming/client.py

Lines changed: 9 additions & 3 deletions
@@ -26,19 +26,25 @@
 class S3Client:
     # TODO: Generalize to support more cloud providers.
 
-    def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None:
+    def __init__(
+        self,
+        refetch_interval: int = 3300,
+        storage_options: Optional[Dict] = {},
+        session_options: Optional[Dict] = {},
+    ) -> None:
         self._refetch_interval = refetch_interval
         self._last_time: Optional[float] = None
         self._client: Optional[Any] = None
         self._storage_options: dict = storage_options or {}
+        self._session_options: dict = session_options or {}
 
     def _create_client(self) -> None:
         has_shared_credentials_file = (
             os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
         )
 
-        if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options:
-            session = boto3.Session()
+        if has_shared_credentials_file or not _IS_IN_STUDIO or self._storage_options or self._session_options:
+            session = boto3.Session(**self._session_options)  # If additional options are provided
             self._client = session.client(
                 "s3",
                 **{
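
For context, `session_options` is unpacked directly into `boto3.Session`, which accepts keyword arguments such as `profile_name` and `region_name`. A minimal sketch of what the new branch does (the profile and region values here are hypothetical):

import boto3

# Hypothetical options; boto3.Session also accepts aws_access_key_id,
# aws_secret_access_key, aws_session_token, etc.
session_options = {"profile_name": "ml-data", "region_name": "us-east-1"}

# Build the session from the options, then the S3 client from the session.
session = boto3.Session(**session_options)
s3 = session.client("s3")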

src/litdata/streaming/config.py

Lines changed: 16 additions & 3 deletions
@@ -39,6 +39,7 @@ def __init__(
         subsampled_files: Optional[List[str]] = None,
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
         storage_options: Optional[Dict] = {},
+        session_options: Optional[Dict] = {},
     ) -> None:
         """Reads the index files associated a chunked dataset and enables to map an index to its chunk.
 
@@ -51,6 +52,7 @@ def __init__(
             subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
             region_of_interest: List of tuples of {start,end} of region of interest for each chunk.
             storage_options: Additional connection options for accessing storage services.
+            session_options: Additional options for S3 session.
 
         """
         self._cache_dir = cache_dir
@@ -60,6 +62,7 @@ def __init__(
         self._remote_dir = remote_dir
         self._item_loader = item_loader or PyTreeLoader()
         self._storage_options = storage_options
+        self._session_options = session_options
 
         # load data from `index.json` file
         data = load_index_file(self._cache_dir)
@@ -84,7 +87,9 @@ def __init__(
         self._downloader = None
 
         if remote_dir:
-            self._downloader = get_downloader(remote_dir, cache_dir, self._chunks, self._storage_options)
+            self._downloader = get_downloader(
+                remote_dir, cache_dir, self._chunks, self._storage_options, self._session_options
+            )
 
         self._compressor_name = self._config["compression"]
         self._compressor: Optional[Compressor] = None
@@ -286,6 +291,7 @@ def load(
         subsampled_files: Optional[List[str]] = None,
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
         storage_options: Optional[dict] = {},
+        session_options: Optional[dict] = {},
     ) -> Optional["ChunksConfig"]:
         cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)
 
@@ -298,14 +304,21 @@ def load(
                 f"This should not have happened. No index.json file found in cache: {cache_index_filepath}"
             )
         else:
-            downloader = get_downloader(remote_dir, cache_dir, [], storage_options)
+            downloader = get_downloader(remote_dir, cache_dir, [], storage_options, session_options)
             downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)
 
         if not os.path.exists(cache_index_filepath):
             return None
 
         return ChunksConfig(
-            cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest, storage_options
+            cache_dir,
+            serializers,
+            remote_dir,
+            item_loader,
+            subsampled_files,
+            region_of_interest,
+            storage_options,
+            session_options,
         )
 
     def __len__(self) -> int:

src/litdata/streaming/dataset.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,7 @@ def __init__(
         subsample: float = 1.0,
         encryption: Optional[Encryption] = None,
         storage_options: Optional[Dict] = {},
+        session_options: Optional[Dict] = {},
         max_pre_download: int = 2,
         index_path: Optional[str] = None,
         force_override_state_dict: bool = False,
@@ -81,6 +82,7 @@ def __init__(
             subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset).
             encryption: The encryption object to use for decrypting the data.
             storage_options: Additional connection options for accessing storage services.
+            session_options: Additional connection options for accessing S3 services.
             max_pre_download: Maximum number of chunks that can be pre-downloaded by the StreamingDataset.
             index_path: Path to `index.json` for the Parquet dataset.
                 If `index_path` is a directory, the function will look for `index.json` within it.
@@ -128,6 +130,7 @@ def __init__(
             shuffle,
             seed,
             storage_options,
+            session_options,
             index_path,
             fnmatch_pattern,
         )
@@ -190,6 +193,7 @@ def __init__(
         self.batch_size: int = 1
         self._encryption = encryption
         self.storage_options = storage_options
+        self.session_options = session_options
         self.max_pre_download = max_pre_download
 
     def set_shuffle(self, shuffle: bool) -> None:
@@ -228,6 +232,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
             max_cache_size=self.max_cache_size,
             encryption=self._encryption,
             storage_options=self.storage_options,
+            session_options=self.session_options,
             max_pre_download=self.max_pre_download,
         )
         cache._reader._try_load_config()
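
Because `session_options` is a keyword separate from `storage_options`, it can also be used on its own, e.g. to pin a region while credentials come from boto3's default provider chain. A minimal sketch (the bucket and region are placeholders):

import litdata as ld

# Credentials resolve via the default chain (env vars, shared
# credentials file, IAM role); only the region is overridden here.
dataset = ld.StreamingDataset(
    "s3://my-bucket/my-data",
    session_options={"region_name": "us-east-1"},
)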

src/litdata/streaming/downloader.py

Lines changed: 40 additions & 8 deletions
@@ -39,7 +39,12 @@
 
 class Downloader(ABC):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
+        **kwargs: Any,
     ):
         self._remote_dir = remote_dir
         self._cache_dir = cache_dir
@@ -77,13 +82,20 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
 
 class S3Downloader(Downloader):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
+        **kwargs: Any,
     ):
         super().__init__(remote_dir, cache_dir, chunks, storage_options)
         self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
+        # check if kwargs contains session_options
+        self.session_options = kwargs.get("session_options", {})
 
         if not self._s5cmd_available or _DISABLE_S5CMD:
-            self._client = S3Client(storage_options=self._storage_options)
+            self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         obj = parse.urlparse(remote_filepath)
@@ -156,7 +168,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 
 class GCPDownloader(Downloader):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
+        **kwargs: Any,
     ):
         if not _GOOGLE_STORAGE_AVAILABLE:
             raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))
@@ -194,7 +211,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 
 class AzureDownloader(Downloader):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
+        **kwargs: Any,
     ):
         if not _AZURE_STORAGE_AVAILABLE:
             raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE))
@@ -247,7 +269,12 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
 
 class HFDownloader(Downloader):
     def __init__(
-        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+        self,
+        remote_dir: str,
+        cache_dir: str,
+        chunks: List[Dict[str, Any]],
+        storage_options: Optional[Dict] = {},
+        **kwargs: Any,
    ):
         if not _HF_HUB_AVAILABLE:
             raise ModuleNotFoundError(
@@ -331,7 +358,11 @@ def unregister_downloader(prefix: str) -> None:
 
 
 def get_downloader(
-    remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    remote_dir: str,
+    cache_dir: str,
+    chunks: List[Dict[str, Any]],
+    storage_options: Optional[Dict] = {},
+    session_options: Optional[Dict] = {},
 ) -> Downloader:
     """Get the appropriate downloader instance based on the remote directory prefix.
 
@@ -340,13 +371,14 @@ def get_downloader(
         cache_dir (str): The local cache directory.
         chunks (List[Dict[str, Any]]): List of chunks to managed by the downloader.
         storage_options (Optional[Dict], optional): Additional storage options. Defaults to {}.
+        session_options (Optional[Dict], optional): Additional S3 session options. Defaults to {}.
 
     Returns:
         Downloader: An instance of the appropriate downloader class.
     """
     for k, cls in _DOWNLOADERS.items():
         if str(remote_dir).startswith(k):
-            return cls(remote_dir, cache_dir, chunks, storage_options)
+            return cls(remote_dir, cache_dir, chunks, storage_options, session_options=session_options)
     else:
         # Default to LocalDownloader if no prefix is matched
         return LocalDownloader(remote_dir, cache_dir, chunks, storage_options)
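
The updated factory threads `session_options` through to the matched downloader (an `s3://` prefix selects `S3Downloader`, which picks it out of `**kwargs`). A minimal sketch with placeholder paths and options:

from litdata.streaming.downloader import get_downloader

downloader = get_downloader(
    "s3://my-bucket/my-data",  # remote_dir: the s3:// prefix selects S3Downloader
    "/tmp/litdata-cache",      # cache_dir
    [],                        # chunks
    {},                        # storage_options
    session_options={"region_name": "us-east-1"},
)
downloader.download_file("s3://my-bucket/my-data/index.json", "/tmp/litdata-cache/index.json")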

src/litdata/streaming/reader.py

Lines changed: 4 additions & 0 deletions
@@ -265,6 +265,7 @@ def __init__(
         item_loader: Optional[BaseItemLoader] = None,
         serializers: Optional[Dict[str, Serializer]] = None,
         storage_options: Optional[dict] = {},
+        session_options: Optional[dict] = {},
         max_pre_download: int = 2,
     ) -> None:
         """The BinaryReader enables to read chunked dataset in an efficient way.
@@ -281,6 +282,7 @@ def __init__(
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
             storage_options: Additional connection options for accessing storage services.
+            session_options: Additional options for the S3 session.
             max_pre_download: Maximum number of chunks that can be pre-downloaded by the reader.
 
         """
@@ -308,6 +310,7 @@ def __init__(
         self._chunks_queued_for_download = False
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
         self._storage_options = storage_options
+        self._session_options = session_options
         self._max_pre_download = max_pre_download
 
     def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
@@ -327,6 +330,7 @@ def _try_load_config(self) -> Optional[ChunksConfig]:
             self.subsampled_files,
             self.region_of_interest,
             self._storage_options,
+            self._session_options,
         )
         return self._config
 
src/litdata/utilities/dataset_utilities.py

Lines changed: 7 additions & 3 deletions
@@ -24,6 +24,7 @@ def subsample_streaming_dataset(
     shuffle: bool = False,
     seed: int = 42,
     storage_options: Optional[Dict] = {},
+    session_options: Optional[Dict] = {},
     index_path: Optional[str] = None,
     fnmatch_pattern: Optional[str] = None,
 ) -> Tuple[List[str], List[Tuple[int, int]]]:
@@ -46,6 +47,7 @@ def subsample_streaming_dataset(
             input_dir=input_dir.path if input_dir.path else input_dir.url,
             cache_dir=cache_dir.path if cache_dir else None,
             storage_options=storage_options,
+            session_options=session_options,
             index_path=index_path,
         )
         if cache_path is not None:
@@ -61,7 +63,7 @@ def subsample_streaming_dataset(
         if index_path is not None:
             copy_index_to_cache_index_filepath(index_path, cache_index_filepath)
         else:
-            downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options)
+            downloader = get_downloader(input_dir.url, input_dir.path, [], storage_options, session_options)
             downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)
 
         time.sleep(0.5)  # Give some time for the file to be available
@@ -141,6 +143,7 @@ def _should_replace_path(path: Optional[str]) -> bool:
 def _read_updated_at(
     input_dir: Optional[Dir],
     storage_options: Optional[Dict] = {},
+    session_options: Optional[Dict] = {},
     index_path: Optional[str] = None,
 ) -> str:
     """Read last updated timestamp from index.json file."""
@@ -160,7 +163,7 @@ def _read_updated_at(
         if index_path is not None:
             copy_index_to_cache_index_filepath(index_path, temp_index_filepath)
         else:
-            downloader = get_downloader(input_dir.url, tmp_directory, [], storage_options)
+            downloader = get_downloader(input_dir.url, tmp_directory, [], storage_options, session_options)
             downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), temp_index_filepath)
         index_json_content = load_index_file(tmp_directory)
 
@@ -213,11 +216,12 @@ def _try_create_cache_dir(
     input_dir: Optional[str],
     cache_dir: Optional[str] = None,
     storage_options: Optional[Dict] = {},
+    session_options: Optional[Dict] = {},
    index_path: Optional[str] = None,
 ) -> Optional[str]:
     """Prepare and return the cache directory for a dataset."""
     resolved_input_dir = _resolve_dir(input_dir)
-    updated_at = _read_updated_at(resolved_input_dir, storage_options, index_path)
+    updated_at = _read_updated_at(resolved_input_dir, storage_options, session_options, index_path)
 
     # Fallback to a hash of the input_dir if updated_at is "0"
     if updated_at == "0" and input_dir is not None:
