Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/okta_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ class OktaAppContext:
async def okta_authorisation_flow(server: FastMCP) -> AsyncIterator[OktaAppContext]:
"""
Manages the application lifecycle. It initializes the OktaManager on startup,
performs authorization, and yields the context for use in tools.
re-using a cached token from the OS keyring when one is still valid, and yields
the context for use in tools.
"""
logger.info("Starting Okta authorization flow")
manager = OktaAuthManager()
await manager.authenticate()
logger.info("Okta authentication completed successfully")

try:
yield OktaAppContext(okta_auth_manager=manager)
finally:
logger.debug("Clearing Okta tokens")
manager.clear_tokens()

if manager.is_cached_token_valid():
logger.info("Re-using cached Okta token from keyring; skipping interactive auth")
else:
if not await manager.is_valid_token():
logger.error("Authentication failed: no token available after refresh and re-auth")
sys.exit(1)
logger.info("Okta authentication completed (refresh or new auth)")

yield OktaAppContext(okta_auth_manager=manager)


mcp = FastMCP("Okta IDaaS MCP Server", lifespan=okta_authorisation_flow)
Expand Down
68 changes: 51 additions & 17 deletions src/okta_mcp_server/utils/auth/auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from loguru import logger

SERVICE_NAME = "OktaAuthManager"
_TOKEN_EXPIRY_SAFETY_MARGIN_SECONDS = 60


@dataclass
Expand All @@ -29,7 +30,6 @@ class OktaAuthManager:

org_url: str = field(init=False)
client_id: str = field(init=False)
token_timestamp: int = 0
scopes: str = "openid profile email offline_access"
private_key: str = field(init=False, default=None)
key_id: str = field(init=False, default=None)
Expand Down Expand Up @@ -138,7 +138,6 @@ def _browserless_authenticate(self) -> str | None:
if access_token:
logger.info("Successfully obtained access token via browserless authentication")
keyring.set_password(SERVICE_NAME, "api_token", access_token)
self.token_timestamp = int(time.time())

# Note: Client credentials flow doesn't provide refresh tokens
logger.debug("Note: Client credentials flow does not provide refresh tokens")
Expand Down Expand Up @@ -211,7 +210,6 @@ def _poll_for_token(self, device_data):
if response.status_code == 200 and "access_token" in resp_json:
logger.info("Successfully obtained access token")
keyring.set_password(SERVICE_NAME, "api_token", resp_json["access_token"])
self.token_timestamp = int(time.time())

if "refresh_token" in resp_json:
logger.debug("Refresh token received and stored")
Expand Down Expand Up @@ -271,7 +269,6 @@ def refresh_access_token(self) -> bool:
logger.debug("New refresh token received and stored")
keyring.set_password(SERVICE_NAME, "refresh_token", resp_json["refresh_token"])

self.token_timestamp = int(time.time())
logger.info("Token refreshed successfully")
return True
else:
Expand Down Expand Up @@ -316,33 +313,71 @@ async def authenticate(self):
else:
logger.error("Authentication failed")

async def is_valid_token(self, expiry_duration: int = 3600) -> bool:
"""Ensure that a valid token is available. Refresh or re-authenticate if needed."""
logger.debug(f"Checking token validity (expiry duration: {expiry_duration}s)")
async def is_valid_token(self) -> bool:
"""Ensure that a valid token is available. Refresh or re-authenticate if needed.

Validity is determined by inspecting the JWT ``exp`` claim with a 60-second safety
margin to absorb clock skew. Tokens that are not parseable as JWTs (opaque tokens),
or JWTs without an ``exp`` claim, are treated as expired and fall through to the
refresh/reauth path.
"""
logger.debug("Checking token validity")

api_token = keyring.get_password(SERVICE_NAME, "api_token")
token_age = time.time() - self.token_timestamp

if api_token and token_age < expiry_duration:
logger.debug(f"Token is valid (age: {token_age:.0f}s)")
if api_token and self._token_is_unexpired(api_token):
logger.debug("Cached token is valid")
return True

logger.info(f"Token is expired or missing (age: {token_age:.0f}s)")
logger.info("Token is expired, missing, or unparseable; attempting refresh or re-auth")
if self.use_browserless_auth:
# For browserless auth, we can't refresh, so re-authenticate
# Browserless flow has no refresh token; re-authenticate.
logger.info("Re-authenticating using browserless flow")
await self.authenticate()
else:
# For device flow, try to refresh first
refreshed = self.refresh_access_token()

# If refresh token is not available or refresh failed, re-authenticate
if not refreshed:
logger.warning("Token refresh failed, initiating re-authentication")
logger.warning("Token refresh failed or unavailable; initiating re-authentication")
await self.authenticate()

return keyring.get_password(SERVICE_NAME, "api_token") is not None

@staticmethod
def _token_is_unexpired(token: str) -> bool:
"""Return True if ``token`` is a JWT whose ``exp`` claim is more than 60s in the future.

Opaque tokens (non-JWT), JWTs missing an ``exp`` claim, or JWTs that fail to decode
return False, causing callers to fall through to refresh/reauth.
"""
try:
claims = jwt.decode(token, options={"verify_signature": False})
except jwt.DecodeError:
logger.debug("Token is not a JWT (opaque); treating as expired")
return False

exp = claims.get("exp")
if exp is None:
logger.debug("JWT has no exp claim; treating as expired")
return False

seconds_remaining = exp - time.time()
if seconds_remaining > _TOKEN_EXPIRY_SAFETY_MARGIN_SECONDS:
logger.debug(f"JWT is valid for {seconds_remaining:.0f}s more")
return True

logger.debug(f"JWT expires in {seconds_remaining:.0f}s, within safety margin; treating as expired")
return False

def is_cached_token_valid(self) -> bool:
"""Return True if a valid (unexpired) JWT api_token is cached in the keyring.

Pure check with no side effects — does not run refresh or re-authentication.
Distinguishes a true cache hit from a refresh/re-auth that just minted a token,
so callers (e.g. the lifespan handler) can log accurately.
"""
api_token = keyring.get_password(SERVICE_NAME, "api_token")
return bool(api_token and self._token_is_unexpired(api_token))

def clear_tokens(self):
"""Clear all stored tokens from keyring."""
logger.info("Clearing stored tokens")
Expand All @@ -359,5 +394,4 @@ def clear_tokens(self):
except keyring.backend.errors.KeyringError as e:
logger.warning(f"Failed to delete refresh_token from keyring: {e}")

self.token_timestamp = 0
logger.info("Token cleanup completed")
3 changes: 1 addition & 2 deletions src/okta_mcp_server/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
async def get_okta_client(manager: OktaAuthManager) -> OktaClient:
"""Initialize and return an Okta client"""
logger.debug("Initializing Okta client")
api_token = keyring.get_password(SERVICE_NAME, "api_token")
if not await manager.is_valid_token():
logger.warning("Token is invalid or expired, re-authenticating")
await manager.authenticate()
api_token = keyring.get_password(SERVICE_NAME, "api_token")
api_token = keyring.get_password(SERVICE_NAME, "api_token")
config = {
"orgUrl": manager.org_url,
"token": api_token,
Expand Down
Loading