Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion requests_aws4auth/aws4auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,16 @@ def __init__(self, *args, **kwargs):
If refreshable_credentials is set, the following arguments
are ignored: access_id, secret_key, signing_key,
session_token.
refresh_credentials_when_needed
-- Must be supplied as keyword argument. If refreshable_credentials
is provided and refresh_credentials_when_needed is set to True,
then credentials will only be refreshed if needed instead of
on each request. Defaults to False.

"""
self.signing_key = None
self.refreshable_credentials = kwargs.get('refreshable_credentials', None)
self.refresh_credentials_when_needed = kwargs.get('refresh_credentials_when_needed', False)
if self.refreshable_credentials:
# instantiate from refreshable_credentials
self.service = kwargs.get('service', None)
Expand All @@ -269,6 +275,8 @@ def __init__(self, *args, **kwargs):
raise TypeError('region must be provided as keyword argument when using refreshable_credentials')
self.date = kwargs.get('date', None)
self.default_include_headers.add('x-amz-security-token')
if self.refresh_credentials_when_needed and not callable(getattr(self.refreshable_credentials, 'refresh_needed', None)):
raise TypeError(f'credentials acquired via {self.refreshable_credentials.method} which does not support refresh when needed')
else:
l = len(args)
if l not in [2, 4, 5]:
Expand Down Expand Up @@ -372,8 +380,10 @@ def __call__(self, req):

"""
if self.refreshable_credentials:
if not self.refresh_credentials_when_needed or self.refreshable_credentials.refresh_needed():
# generate per-request static credentials
self.refresh_credentials()
# alternatively generate only when needed if self.refresh_credentials_when_needed is True
self.refresh_credentials()
# check request date matches scope date
req_date = self.get_request_date(req)
if req_date is None:
Expand Down