diff --git a/gcsfs/credentials.py b/gcsfs/credentials.py index ef4e81b8..2eaf26ad 100644 --- a/gcsfs/credentials.py +++ b/gcsfs/credentials.py @@ -18,6 +18,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow from gcsfs.retry import HttpError +from gcsfs.retry import NonRetryableError logger = logging.getLogger("gcsfs.credentials") @@ -38,6 +39,42 @@ } } +TOKEN_INFO_TIMEOUT_SECONDS = 10 +LOCAL_REEFRESH_BUFFER = 300 # Greater than google.auth._helpers.REFRESH_THRESHOLD + +def _get_creds_from_raw_token(token): + # Default to True. Only disable if user explicitly says 'false', '0', or 'off'. + env_val = os.environ.get("FETCH_RAW_TOKEN_EXPIRY", "true").lower() + should_fetch_expiry = env_val not in ("false", "0", "off", "no") + + if should_fetch_expiry: + response = requests.get( + 'https://oauth2.googleapis.com/tokeninfo', + params={'access_token': token}, + timeout=TOKEN_INFO_TIMEOUT_SECONDS + ) + + if response.status_code == 400: + # Token is likely expired or invalid format + raise ValueError("Provided token is either not valid, or expired.") + + response.raise_for_status() + expiry = datetime.utcfromtimestamp(float(response.json()['exp'])) + + time_remaining = max(0, (expiry.replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)).total_seconds()) + if time_remaining <= LOCAL_REEFRESH_BUFFER: + raise ValueError( + f"The provided raw token expires in {time_remaining} seconds, " + f"which is less than the safety buffer ({LOCAL_REEFRESH_BUFFER}). " + "This may cause immediate authentication failures. " + "To bypass this check and safety buffer, you can set the environment " + "variable FETCH_RAW_TOKEN_EXPIRY=false (expiry will be unknown)." + ) + else: + expiry = None + + return Credentials(token, expiry=expiry) + class GoogleCredentials: def __init__(self, project, access, token, check_credentials=None, on_google=True): @@ -161,7 +198,7 @@ def _connect_token(self, token): with open(token) as data: token = json.load(data) else: - token = Credentials(token) + token = _get_creds_from_raw_token(token) if isinstance(token, dict): credentials = self._dict_to_credentials(token) elif isinstance(token, google.auth.credentials.Credentials): @@ -190,7 +227,7 @@ def _credentials_valid(self, refresh_buffer): ) ) - def maybe_refresh(self, refresh_buffer=300): + def maybe_refresh(self, refresh_buffer=LOCAL_REEFRESH_BUFFER): """ Check and refresh credentials if needed """ @@ -210,6 +247,15 @@ def maybe_refresh(self, refresh_buffer=300): try: self.credentials.refresh(req) except gauth.exceptions.RefreshError as error: + # There may be scenarios where this error is raised from the client side due + # to missing dependencies, especially when the client doesn't know how to refresh + # or lacks the necessary information. In such cases, the request gets retried + # with backoff strategy, which can be avoided. + + # Check for client side errors (if any) + if 'credentials do not contain the necessary fields need to refresh' in str(error): + raise NonRetryableError("Got error while refreshing credentials.") from error + # Re-raise as HttpError with a 401 code and the expected message raise HttpError( {"code": 401, "message": "Invalid Credentials"} diff --git a/gcsfs/retry.py b/gcsfs/retry.py index c5062173..3c4f16b8 100644 --- a/gcsfs/retry.py +++ b/gcsfs/retry.py @@ -44,6 +44,11 @@ class ChecksumError(Exception): pass +class NonRetryableError(Exception): + """Raised when the underlying error can not be retried, or continued further.""" + pass + + RETRIABLE_EXCEPTIONS = ( requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError, @@ -69,6 +74,8 @@ class ChecksumError(Exception): def is_retriable(exception): """Returns True if this exception is retriable.""" + if isinstance(exception, NonRetryableError): + return False if isinstance(exception, HttpError): # Add 401 to retriable errors when it's an auth expiration issue diff --git a/gcsfs/tests/test_credentials.py b/gcsfs/tests/test_credentials.py index 5c1bb658..80c751e1 100644 --- a/gcsfs/tests/test_credentials.py +++ b/gcsfs/tests/test_credentials.py @@ -1,9 +1,15 @@ import pytest +import datetime +import os from gcsfs import GCSFileSystem from gcsfs.credentials import GoogleCredentials +from gcsfs.retry import NonRetryableError from gcsfs.retry import HttpError +from unittest.mock import patch, Mock +MOCK_TOKEN_STR = "ya29.valid_raw_token_string" +MOCK_EXP_TIMESTAMP = 1764620492 # 2025-12-01 20:21:32 UTC def test_googlecredentials_none(): credentials = GoogleCredentials(project="myproject", token=None, access="read_only") @@ -13,6 +19,84 @@ def test_googlecredentials_none(): @pytest.mark.parametrize("token", ["", "incorrect.token", "x" * 100]) def test_credentials_from_raw_token(token): - with pytest.raises(HttpError, match="Invalid Credentials"): - fs = GCSFileSystem(project="myproject", token=token) - fs.ls("/") + with patch.dict(os.environ, {"FETCH_RAW_TOKEN_EXPIRY": "false"}): + with pytest.raises(HttpError, match="Invalid Credentials"): + fs = GCSFileSystem(project="myproject", token=token) + fs.ls("/") + +@pytest.fixture +def mock_token_info_api_response(): + """Returns a mock response object that mimics a valid Google Token Info response""" + resp = Mock() + resp.status_code = 200 + resp.json.return_value = {'exp': str(MOCK_EXP_TIMESTAMP)} + return resp + +def test_raw_token_credentials_init_with_raw_token_fetches_expiry(mock_token_info_api_response): + """ + Test that initializing GoogleCredentials with a raw string token + triggers the API lookup and sets the expiry. + """ + future_time = int((datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=600)).timestamp()) + mock_token_info_api_response.json.return_value = {"exp": str(future_time)} + + with patch("gcsfs.credentials.requests.get", return_value=mock_token_info_api_response) as mock_get: + creds = GoogleCredentials(project="my-project", token=MOCK_TOKEN_STR, access="read_only") + mock_get.assert_called_once_with( + 'https://oauth2.googleapis.com/tokeninfo', + params={'access_token': MOCK_TOKEN_STR}, + timeout=10 + ) + + assert creds.credentials.token == MOCK_TOKEN_STR + assert creds.credentials.expiry is not None + assert creds.credentials.expiry == datetime.datetime.utcfromtimestamp(future_time) + +def test_raw_token_credentials_init_env_var_disables_fetch(mock_token_info_api_response): + """Test that the FETCH_RAW_TOKEN_EXPIRY environment variable stops the network call.""" + with patch.dict(os.environ, {"FETCH_RAW_TOKEN_EXPIRY": "false"}): + with patch("gcsfs.credentials.requests.get", return_value=mock_token_info_api_response) as mock_get: + creds = GoogleCredentials(project="my-project", token=MOCK_TOKEN_STR, access="read_only") + mock_get.assert_not_called() + assert creds.credentials.token == MOCK_TOKEN_STR + assert creds.credentials.expiry is None + +def test_raw_token_credentials_init_raises_on_invalid_token(mock_token_info_api_response): + """Test that if the API returns 400 (Bad Request), the class initialization fails.""" + mock_token_info_api_response.status_code = 400 + mock_token_info_api_response.json.return_value = {"error": "invalid_token"} + + with patch("gcsfs.credentials.requests.get", return_value=mock_token_info_api_response): + with pytest.raises(ValueError, match="Provided token is either not valid"): + GoogleCredentials(project="my-project", token="bad_token_string", access="read_only") + +def test_raw_token_credentials_refresh_throws_error_after_expiry(mock_token_info_api_response): + """Tests that raw token cred refresh throws error after expiry.""" + future_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=600) + mock_token_info_api_response.json.return_value = {"exp": str(int(future_time.timestamp()))} + + with patch("gcsfs.credentials.requests.get", return_value=mock_token_info_api_response) as _: + creds = GoogleCredentials(project="my-project", token="my_token", access="read_only") + + # Refresh before expiry + with patch("gcsfs.credentials.requests.Session") as mock_session: + creds.maybe_refresh() + mock_session.assert_not_called() + + creds.credentials.expiry = datetime.datetime.utcnow() - datetime.timedelta(minutes=10) + + # Refresh after expiry + with pytest.raises(NonRetryableError, match="Got error while refreshing credentials"): + creds.maybe_refresh() + +def test_raw_token_credentials_init_raises_on_short_lived_token(mock_token_info_api_response): + """ + Test that if the token expires too soon (less than the safety buffer), + we raise a ValueError immediately to warn the user. + """ + future_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=2) + mock_token_info_api_response.json.return_value = {"exp": str(int(future_time.timestamp()))} + + with patch("gcsfs.credentials.requests.get", return_value=mock_token_info_api_response): + with pytest.raises(ValueError, match="less than the safety buffer"): + GoogleCredentials(project="my-project", token="short_lived_token", access="read_only") \ No newline at end of file