diff --git a/src/okta_mcp_server/server.py b/src/okta_mcp_server/server.py index caa81ba..d1e1792 100644 --- a/src/okta_mcp_server/server.py +++ b/src/okta_mcp_server/server.py @@ -28,18 +28,21 @@ class OktaAppContext: async def okta_authorisation_flow(server: FastMCP) -> AsyncIterator[OktaAppContext]: """ Manages the application lifecycle. It initializes the OktaManager on startup, - performs authorization, and yields the context for use in tools. + re-using a cached token from the OS keyring when one is still valid, and yields + the context for use in tools. """ logger.info("Starting Okta authorization flow") manager = OktaAuthManager() - await manager.authenticate() - logger.info("Okta authentication completed successfully") - - try: - yield OktaAppContext(okta_auth_manager=manager) - finally: - logger.debug("Clearing Okta tokens") - manager.clear_tokens() + + if manager.is_cached_token_valid(): + logger.info("Re-using cached Okta token from keyring; skipping interactive auth") + else: + if not await manager.is_valid_token(): + logger.error("Authentication failed: no token available after refresh and re-auth") + sys.exit(1) + logger.info("Okta authentication completed (refresh or new auth)") + + yield OktaAppContext(okta_auth_manager=manager) mcp = FastMCP("Okta IDaaS MCP Server", lifespan=okta_authorisation_flow) diff --git a/src/okta_mcp_server/utils/auth/auth_manager.py b/src/okta_mcp_server/utils/auth/auth_manager.py index 2dcdceb..0c2003a 100644 --- a/src/okta_mcp_server/utils/auth/auth_manager.py +++ b/src/okta_mcp_server/utils/auth/auth_manager.py @@ -21,6 +21,7 @@ from loguru import logger SERVICE_NAME = "OktaAuthManager" +_TOKEN_EXPIRY_SAFETY_MARGIN_SECONDS = 60 @dataclass @@ -29,7 +30,6 @@ class OktaAuthManager: org_url: str = field(init=False) client_id: str = field(init=False) - token_timestamp: int = 0 scopes: str = "openid profile email offline_access" private_key: str = field(init=False, default=None) key_id: str = field(init=False, default=None) @@ -138,7 +138,6 @@ def _browserless_authenticate(self) -> str | None: if access_token: logger.info("Successfully obtained access token via browserless authentication") keyring.set_password(SERVICE_NAME, "api_token", access_token) - self.token_timestamp = int(time.time()) # Note: Client credentials flow doesn't provide refresh tokens logger.debug("Note: Client credentials flow does not provide refresh tokens") @@ -211,7 +210,6 @@ def _poll_for_token(self, device_data): if response.status_code == 200 and "access_token" in resp_json: logger.info("Successfully obtained access token") keyring.set_password(SERVICE_NAME, "api_token", resp_json["access_token"]) - self.token_timestamp = int(time.time()) if "refresh_token" in resp_json: logger.debug("Refresh token received and stored") @@ -271,7 +269,6 @@ def refresh_access_token(self) -> bool: logger.debug("New refresh token received and stored") keyring.set_password(SERVICE_NAME, "refresh_token", resp_json["refresh_token"]) - self.token_timestamp = int(time.time()) logger.info("Token refreshed successfully") return True else: @@ -316,33 +313,71 @@ async def authenticate(self): else: logger.error("Authentication failed") - async def is_valid_token(self, expiry_duration: int = 3600) -> bool: - """Ensure that a valid token is available. Refresh or re-authenticate if needed.""" - logger.debug(f"Checking token validity (expiry duration: {expiry_duration}s)") + async def is_valid_token(self) -> bool: + """Ensure that a valid token is available. Refresh or re-authenticate if needed. + + Validity is determined by inspecting the JWT ``exp`` claim with a 60-second safety + margin to absorb clock skew. Tokens that are not parseable as JWTs (opaque tokens), + or JWTs without an ``exp`` claim, are treated as expired and fall through to the + refresh/reauth path. + """ + logger.debug("Checking token validity") api_token = keyring.get_password(SERVICE_NAME, "api_token") - token_age = time.time() - self.token_timestamp - if api_token and token_age < expiry_duration: - logger.debug(f"Token is valid (age: {token_age:.0f}s)") + if api_token and self._token_is_unexpired(api_token): + logger.debug("Cached token is valid") return True - logger.info(f"Token is expired or missing (age: {token_age:.0f}s)") + logger.info("Token is expired, missing, or unparseable; attempting refresh or re-auth") if self.use_browserless_auth: - # For browserless auth, we can't refresh, so re-authenticate + # Browserless flow has no refresh token; re-authenticate. logger.info("Re-authenticating using browserless flow") await self.authenticate() else: - # For device flow, try to refresh first refreshed = self.refresh_access_token() - - # If refresh token is not available or refresh failed, re-authenticate if not refreshed: - logger.warning("Token refresh failed, initiating re-authentication") + logger.warning("Token refresh failed or unavailable; initiating re-authentication") await self.authenticate() return keyring.get_password(SERVICE_NAME, "api_token") is not None + @staticmethod + def _token_is_unexpired(token: str) -> bool: + """Return True if ``token`` is a JWT whose ``exp`` claim is more than 60s in the future. + + Opaque tokens (non-JWT), JWTs missing an ``exp`` claim, or JWTs that fail to decode + return False, causing callers to fall through to refresh/reauth. + """ + try: + claims = jwt.decode(token, options={"verify_signature": False}) + except jwt.DecodeError: + logger.debug("Token is not a JWT (opaque); treating as expired") + return False + + exp = claims.get("exp") + if exp is None: + logger.debug("JWT has no exp claim; treating as expired") + return False + + seconds_remaining = exp - time.time() + if seconds_remaining > _TOKEN_EXPIRY_SAFETY_MARGIN_SECONDS: + logger.debug(f"JWT is valid for {seconds_remaining:.0f}s more") + return True + + logger.debug(f"JWT expires in {seconds_remaining:.0f}s, within safety margin; treating as expired") + return False + + def is_cached_token_valid(self) -> bool: + """Return True if a valid (unexpired) JWT api_token is cached in the keyring. + + Pure check with no side effects — does not run refresh or re-authentication. + Distinguishes a true cache hit from a refresh/re-auth that just minted a token, + so callers (e.g. the lifespan handler) can log accurately. + """ + api_token = keyring.get_password(SERVICE_NAME, "api_token") + return bool(api_token and self._token_is_unexpired(api_token)) + def clear_tokens(self): """Clear all stored tokens from keyring.""" logger.info("Clearing stored tokens") @@ -359,5 +394,4 @@ def clear_tokens(self): except keyring.backend.errors.KeyringError as e: logger.warning(f"Failed to delete refresh_token from keyring: {e}") - self.token_timestamp = 0 logger.info("Token cleanup completed") diff --git a/src/okta_mcp_server/utils/client.py b/src/okta_mcp_server/utils/client.py index dea45ed..cd1ff6b 100644 --- a/src/okta_mcp_server/utils/client.py +++ b/src/okta_mcp_server/utils/client.py @@ -15,11 +15,10 @@ async def get_okta_client(manager: OktaAuthManager) -> OktaClient: """Initialize and return an Okta client""" logger.debug("Initializing Okta client") - api_token = keyring.get_password(SERVICE_NAME, "api_token") if not await manager.is_valid_token(): logger.warning("Token is invalid or expired, re-authenticating") await manager.authenticate() - api_token = keyring.get_password(SERVICE_NAME, "api_token") + api_token = keyring.get_password(SERVICE_NAME, "api_token") config = { "orgUrl": manager.org_url, "token": api_token, diff --git a/tests/test_auth_manager.py b/tests/test_auth_manager.py new file mode 100644 index 0000000..3809602 --- /dev/null +++ b/tests/test_auth_manager.py @@ -0,0 +1,248 @@ +# The Okta software accompanied by this notice is provided pursuant to the following terms: +# Copyright © 2026-Present, Okta, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and limitations under the License. + +"""Unit tests for OktaAuthManager.""" + +from __future__ import annotations + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt +import keyring.errors +import pytest + +from okta_mcp_server.utils.auth.auth_manager import OktaAuthManager + + +@pytest.fixture(autouse=True) +def _okta_env(monkeypatch): + monkeypatch.setenv("OKTA_ORG_URL", "https://test.okta.com") + monkeypatch.setenv("OKTA_CLIENT_ID", "test-client-id") + monkeypatch.delenv("OKTA_PRIVATE_KEY", raising=False) + monkeypatch.delenv("OKTA_KEY_ID", raising=False) + monkeypatch.delenv("OKTA_SCOPES", raising=False) + + +def _jwt_with_exp(exp_offset_seconds: int) -> str: + return jwt.encode({"exp": int(time.time()) + exp_offset_seconds}, "test-secret", algorithm="HS256") + + +def _jwt_without_exp() -> str: + return jwt.encode({"sub": "x"}, "test-secret", algorithm="HS256") + + +def _keyring_returns(api_token=None, refresh_token=None): + def _side_effect(_service, key): + if key == "api_token": + return api_token + if key == "refresh_token": + return refresh_token + return None + + return _side_effect + + +class TestIsValidToken: + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_cold_start_with_valid_cached_jwt_skips_auth(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns(api_token=_jwt_with_exp(3600)) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock()) as mock_refresh, + ): + result = await manager.is_valid_token() + assert result is True + mock_auth.assert_not_called() + mock_refresh.assert_not_called() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_expired_jwt_with_refresh_token_uses_refresh(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns( + api_token=_jwt_with_exp(-60), refresh_token="refresh-abc" + ) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=True)) as mock_refresh, + ): + result = await manager.is_valid_token() + assert result is True + mock_refresh.assert_called_once() + mock_auth.assert_not_called() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_expired_jwt_without_refresh_token_invokes_authenticate(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns(api_token=_jwt_with_exp(-60), refresh_token=None) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=False)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + mock_auth.assert_awaited_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_expired_jwt_with_failed_refresh_invokes_authenticate(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns( + api_token=_jwt_with_exp(-60), refresh_token="refresh-abc" + ) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=False)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + mock_auth.assert_awaited_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_no_cached_token_triggers_device_flow(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns(api_token=None, refresh_token=None) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=False)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + mock_auth.assert_awaited_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_opaque_token_falls_through_to_refresh(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns( + api_token="opaque-abc-123", refresh_token="refresh-abc" + ) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()), + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=True)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_malformed_jwt_falls_through_to_refresh(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns(api_token="a.b.c", refresh_token="refresh-abc") + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()), + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=True)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_jwt_without_exp_claim_falls_through(self, mock_keyring): + mock_keyring.get_password.side_effect = _keyring_returns( + api_token=_jwt_without_exp(), refresh_token="refresh-abc" + ) + manager = OktaAuthManager() + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()), + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock(return_value=True)) as mock_refresh, + ): + await manager.is_valid_token() + mock_refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_browserless_with_valid_cached_jwt_skips_auth(self, mock_keyring, monkeypatch): + monkeypatch.setenv("OKTA_PRIVATE_KEY", "fake-key") + monkeypatch.setenv("OKTA_KEY_ID", "fake-kid") + mock_keyring.get_password.side_effect = _keyring_returns(api_token=_jwt_with_exp(3600)) + manager = OktaAuthManager() + assert manager.use_browserless_auth is True + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock()) as mock_refresh, + ): + result = await manager.is_valid_token() + assert result is True + mock_auth.assert_not_called() + mock_refresh.assert_not_called() + + @pytest.mark.asyncio + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + async def test_browserless_with_expired_jwt_reauths_without_refresh(self, mock_keyring, monkeypatch): + monkeypatch.setenv("OKTA_PRIVATE_KEY", "fake-key") + monkeypatch.setenv("OKTA_KEY_ID", "fake-kid") + mock_keyring.get_password.side_effect = _keyring_returns(api_token=_jwt_with_exp(-60)) + manager = OktaAuthManager() + assert manager.use_browserless_auth is True + with ( + patch.object(OktaAuthManager, "authenticate", new=AsyncMock()) as mock_auth, + patch.object(OktaAuthManager, "refresh_access_token", new=MagicMock()) as mock_refresh, + ): + await manager.is_valid_token() + mock_auth.assert_awaited_once() + mock_refresh.assert_not_called() + + +class TestTokenIsUnexpired: + @patch("okta_mcp_server.utils.auth.auth_manager.time.time") + def test_token_expiring_within_safety_margin_is_treated_as_expired(self, mock_time): + mock_time.return_value = 1_000_000.0 + token = jwt.encode({"exp": 1_000_030}, "test-secret", algorithm="HS256") + assert OktaAuthManager._token_is_unexpired(token) is False + + @patch("okta_mcp_server.utils.auth.auth_manager.time.time") + def test_token_expiring_outside_safety_margin_is_valid(self, mock_time): + mock_time.return_value = 1_000_000.0 + token = jwt.encode({"exp": 1_000_090}, "test-secret", algorithm="HS256") + assert OktaAuthManager._token_is_unexpired(token) is True + + def test_empty_string_returns_false(self): + assert OktaAuthManager._token_is_unexpired("") is False + + +class TestIsCachedTokenValid: + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_returns_true_for_valid_cached_jwt(self, mock_keyring): + mock_keyring.get_password.return_value = _jwt_with_exp(3600) + assert OktaAuthManager().is_cached_token_valid() is True + + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_returns_false_when_no_token_cached(self, mock_keyring): + mock_keyring.get_password.return_value = None + assert OktaAuthManager().is_cached_token_valid() is False + + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_returns_false_for_expired_jwt(self, mock_keyring): + mock_keyring.get_password.return_value = _jwt_with_exp(-60) + assert OktaAuthManager().is_cached_token_valid() is False + + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_returns_false_for_opaque_token(self, mock_keyring): + mock_keyring.get_password.return_value = "opaque-not-a-jwt" + assert OktaAuthManager().is_cached_token_valid() is False + + +class TestClearTokens: + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_swallows_password_delete_errors(self, mock_keyring): + mock_keyring.delete_password.side_effect = keyring.errors.PasswordDeleteError("missing") + mock_keyring.backend.errors.KeyringError = keyring.errors.KeyringError + OktaAuthManager().clear_tokens() + assert mock_keyring.delete_password.call_count == 2 + + @patch("okta_mcp_server.utils.auth.auth_manager.keyring") + def test_success_path(self, mock_keyring): + mock_keyring.delete_password.return_value = None + mock_keyring.backend.errors.KeyringError = keyring.errors.KeyringError + OktaAuthManager().clear_tokens() + assert mock_keyring.delete_password.call_count == 2 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..1d364b4 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,75 @@ +# The Okta software accompanied by this notice is provided pursuant to the following terms: +# Copyright © 2026-Present, Okta, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and limitations under the License. + +"""Tests for okta_mcp_server.utils.client.get_okta_client.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from okta_mcp_server.utils.auth.auth_manager import OktaAuthManager +from okta_mcp_server.utils.client import get_okta_client + + +def _build_manager_mock() -> MagicMock: + manager = MagicMock(spec=OktaAuthManager) + manager.org_url = "https://test.okta.com" + manager.authenticate = AsyncMock() + return manager + + +class TestGetOktaClient: + @pytest.mark.asyncio + async def test_uses_freshly_refreshed_token_not_stale_pre_refresh_value(self): + keyring_state = {"api_token": "stale-pre-refresh-token"} + + def refresh_then_return_true(): + keyring_state["api_token"] = "fresh-post-refresh-token" + return True + + manager = _build_manager_mock() + manager.is_valid_token = AsyncMock(side_effect=refresh_then_return_true) + + captured_config: dict = {} + + def fake_okta_client(config): + captured_config.update(config) + return MagicMock() + + with ( + patch("okta_mcp_server.utils.client.keyring") as mock_kr, + patch("okta_mcp_server.utils.client.OktaClient", side_effect=fake_okta_client), + ): + mock_kr.get_password.side_effect = lambda _s, k: keyring_state.get(k) + await get_okta_client(manager) + + assert captured_config["token"] == "fresh-post-refresh-token" + + @pytest.mark.asyncio + async def test_uses_cached_token_when_already_valid(self): + keyring_state = {"api_token": "valid-cached-token"} + + manager = _build_manager_mock() + manager.is_valid_token = AsyncMock(return_value=True) + + captured_config: dict = {} + + def fake_okta_client(config): + captured_config.update(config) + return MagicMock() + + with ( + patch("okta_mcp_server.utils.client.keyring") as mock_kr, + patch("okta_mcp_server.utils.client.OktaClient", side_effect=fake_okta_client), + ): + mock_kr.get_password.side_effect = lambda _s, k: keyring_state.get(k) + await get_okta_client(manager) + + assert captured_config["token"] == "valid-cached-token" + manager.authenticate.assert_not_awaited() diff --git a/tests/test_server_lifespan.py b/tests/test_server_lifespan.py new file mode 100644 index 0000000..481e35a --- /dev/null +++ b/tests/test_server_lifespan.py @@ -0,0 +1,78 @@ +# The Okta software accompanied by this notice is provided pursuant to the following terms: +# Copyright © 2026-Present, Okta, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and limitations under the License. + +"""Tests for the okta_authorisation_flow lifespan handler.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from okta_mcp_server.server import OktaAppContext, okta_authorisation_flow + + +@pytest.fixture(autouse=True) +def _okta_env(monkeypatch): + monkeypatch.setenv("OKTA_ORG_URL", "https://test.okta.com") + monkeypatch.setenv("OKTA_CLIENT_ID", "test-client-id") + monkeypatch.delenv("OKTA_PRIVATE_KEY", raising=False) + monkeypatch.delenv("OKTA_KEY_ID", raising=False) + + +def _make_manager_mock(*, cached_valid: bool = True, auth_succeeds: bool = True) -> MagicMock: + manager = MagicMock() + manager.is_cached_token_valid = MagicMock(return_value=cached_valid) + manager.is_valid_token = AsyncMock(return_value=auth_succeeds) + manager.authenticate = AsyncMock() + manager.clear_tokens = MagicMock() + return manager + + +class TestOktaAuthorisationFlow: + @pytest.mark.asyncio + @patch("okta_mcp_server.server.OktaAuthManager") + async def test_skips_authenticate_when_cache_is_valid(self, mock_cls): + manager = _make_manager_mock(cached_valid=True) + mock_cls.return_value = manager + async with okta_authorisation_flow(MagicMock()) as ctx: + assert isinstance(ctx, OktaAppContext) + assert ctx.okta_auth_manager is manager + manager.is_cached_token_valid.assert_called_once() + manager.is_valid_token.assert_not_called() + manager.authenticate.assert_not_called() + + @pytest.mark.asyncio + @patch("okta_mcp_server.server.OktaAuthManager") + async def test_yields_context_after_successful_reauth(self, mock_cls): + manager = _make_manager_mock(cached_valid=False, auth_succeeds=True) + mock_cls.return_value = manager + async with okta_authorisation_flow(MagicMock()) as ctx: + assert ctx.okta_auth_manager is manager + manager.is_cached_token_valid.assert_called_once() + manager.is_valid_token.assert_awaited_once() + + @pytest.mark.asyncio + @patch("okta_mcp_server.server.OktaAuthManager") + async def test_does_not_clear_tokens_on_teardown(self, mock_cls): + manager = _make_manager_mock(cached_valid=True) + mock_cls.return_value = manager + async with okta_authorisation_flow(MagicMock()): + pass + manager.clear_tokens.assert_not_called() + + @pytest.mark.asyncio + @patch("okta_mcp_server.server.OktaAuthManager") + async def test_exits_with_code_1_when_auth_fails(self, mock_cls): + manager = _make_manager_mock(cached_valid=False, auth_succeeds=False) + mock_cls.return_value = manager + with pytest.raises(SystemExit) as exc_info: + async with okta_authorisation_flow(MagicMock()): + pytest.fail("Lifespan must not yield when no token is available") + assert exc_info.value.code == 1 + manager.is_cached_token_valid.assert_called_once() + manager.is_valid_token.assert_awaited_once()