Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auth/credential_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, base_dir: Optional[str] = None):
Args:
base_dir: Base directory for credential files. If None, uses the directory
configured by the GOOGLE_MCP_CREDENTIALS_DIR environment variable,
or defaults to ~/.google_workspace_mcp/credentials if the environment
or defaults to ~/.config/google-workspace-mcp/credentials if the environment
variable is not set.
"""
if base_dir is None:
Expand All @@ -90,7 +90,7 @@ def __init__(self, base_dir: Optional[str] = None):
home_dir = os.path.expanduser("~")
if home_dir and home_dir != "~":
base_dir = os.path.join(
home_dir, ".google_workspace_mcp", "credentials"
home_dir, ".config", "google-workspace-mcp", "credentials"
)
else:
base_dir = os.path.join(os.getcwd(), ".credentials")
Expand Down
12 changes: 3 additions & 9 deletions auth/google_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_default_credentials_dir():
# Use user home directory for credentials storage
home_dir = os.path.expanduser("~")
if home_dir and home_dir != "~": # Valid home directory found
return os.path.join(home_dir, ".google_workspace_mcp", "credentials")
return os.path.join(home_dir, ".config", "google-workspace-mcp", "credentials")

# Fallback to current working directory if home directory is not accessible
return os.path.join(os.getcwd(), ".credentials")
Expand Down Expand Up @@ -871,21 +871,15 @@ async def get_authenticated_google_service(
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
)

# Ensure OAuth callback is available
from auth.oauth_callback_server import ensure_oauth_callback_available
from auth.oauth_callback_server import start_oauth_callback_server

redirect_uri = get_oauth_redirect_uri()
config = get_oauth_config()
success, error_msg = ensure_oauth_callback_available(
get_transport_mode(), config.port, config.base_uri
)
success, error_msg, redirect_uri = start_oauth_callback_server()
if not success:
error_detail = f" ({error_msg})" if error_msg else ""
raise GoogleAuthenticationError(
f"Cannot initiate OAuth flow - callback server unavailable{error_detail}"
)

# Generate auth URL and raise exception with it
auth_response = await start_auth_flow(
user_google_email=user_google_email,
service_name=f"Google {service_name.title()}",
Expand Down
115 changes: 114 additions & 1 deletion auth/oauth21_session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"""

import contextvars
import json
import logging
import os
from typing import Dict, Optional, Any, Tuple
from threading import RLock
from datetime import datetime, timedelta, timezone
Expand All @@ -19,6 +21,26 @@
logger = logging.getLogger(__name__)


def _get_oauth_states_file_path() -> str:
"""Get the file path for persisting OAuth states."""
# Use the same directory as credentials
env_dir = os.getenv("GOOGLE_MCP_CREDENTIALS_DIR")
if env_dir:
base_dir = env_dir
else:
home_dir = os.path.expanduser("~")
if home_dir and home_dir != "~":
base_dir = os.path.join(home_dir, ".config", "google-workspace-mcp")
else:
base_dir = os.path.join(os.getcwd(), ".config", "google-workspace-mcp")

# Ensure directory exists
if not os.path.exists(base_dir):
os.makedirs(base_dir, exist_ok=True)

return os.path.join(base_dir, "oauth_states.json")


def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
"""
Convert expiry values to timezone-naive UTC datetimes for google-auth compatibility.
Expand Down Expand Up @@ -187,6 +209,9 @@ class OAuth21SessionStore:

Security: Sessions are bound to specific users and can only access
their own credentials.

OAuth states are persisted to disk to survive server restarts during
the OAuth flow.
"""

def __init__(self):
Expand All @@ -199,6 +224,10 @@ def __init__(self):
] = {} # Maps session ID -> authenticated user email (immutable)
self._oauth_states: Dict[str, Dict[str, Any]] = {}
self._lock = RLock()
self._states_file_path = _get_oauth_states_file_path()

# Load persisted OAuth states on initialization
self._load_oauth_states_from_disk()

def _cleanup_expired_oauth_states_locked(self):
"""Remove expired OAuth state entries. Caller must hold lock."""
Expand All @@ -215,13 +244,91 @@ def _cleanup_expired_oauth_states_locked(self):
state[:8] if len(state) > 8 else state,
)

def _load_oauth_states_from_disk(self):
"""Load persisted OAuth states from disk on initialization."""
try:
if not os.path.exists(self._states_file_path):
logger.debug(
"No persisted OAuth states file found at %s", self._states_file_path
)
return

with open(self._states_file_path, "r") as f:
persisted_data = json.load(f)

if not isinstance(persisted_data, dict):
logger.warning("Invalid OAuth states file format, ignoring")
return

# Convert ISO format strings back to datetime objects
loaded_count = 0
for state, data in persisted_data.items():
try:
if "expires_at" in data and data["expires_at"]:
data["expires_at"] = datetime.fromisoformat(data["expires_at"])
if "created_at" in data and data["created_at"]:
data["created_at"] = datetime.fromisoformat(data["created_at"])
self._oauth_states[state] = data
loaded_count += 1
except (ValueError, TypeError) as e:
logger.warning(
"Failed to parse OAuth state %s: %s",
state[:8] if len(state) > 8 else state,
e,
)

# Clean up expired states after loading
self._cleanup_expired_oauth_states_locked()

logger.info(
"Loaded %d OAuth states from disk (%d after cleanup)",
loaded_count,
len(self._oauth_states),
)

except json.JSONDecodeError as e:
logger.warning("Failed to parse OAuth states file: %s", e)
except IOError as e:
logger.warning("Failed to read OAuth states file: %s", e)
except Exception as e:
logger.error("Unexpected error loading OAuth states: %s", e)

def _save_oauth_states_to_disk(self):
"""Persist OAuth states to disk. Caller must hold lock."""
try:
# Convert datetime objects to ISO format strings for JSON serialization
serializable_data = {}
for state, data in self._oauth_states.items():
serializable_data[state] = {
"session_id": data.get("session_id"),
"expires_at": data["expires_at"].isoformat()
if data.get("expires_at")
else None,
"created_at": data["created_at"].isoformat()
if data.get("created_at")
else None,
}

with open(self._states_file_path, "w") as f:
json.dump(serializable_data, f, indent=2)

logger.debug("Persisted %d OAuth states to disk", len(serializable_data))

except IOError as e:
logger.error("Failed to persist OAuth states to disk: %s", e)
except Exception as e:
logger.error("Unexpected error persisting OAuth states: %s", e)

def store_oauth_state(
self,
state: str,
session_id: Optional[str] = None,
expires_in_seconds: int = 600,
) -> None:
"""Persist an OAuth state value for later validation."""
"""Persist an OAuth state value for later validation.

States are stored both in memory and on disk to survive server restarts.
"""
if not state:
raise ValueError("OAuth state must be provided")
if expires_in_seconds < 0:
Expand All @@ -236,6 +343,10 @@ def store_oauth_state(
"expires_at": expiry,
"created_at": now,
}

# Persist to disk to survive server restarts
self._save_oauth_states_to_disk()

logger.debug(
"Stored OAuth state %s (expires at %s)",
state[:8] if len(state) > 8 else state,
Expand Down Expand Up @@ -277,6 +388,7 @@ def validate_and_consume_oauth_state(
if bound_session and session_id and bound_session != session_id:
# Consume the state to prevent replay attempts
del self._oauth_states[state]
self._save_oauth_states_to_disk()
logger.error(
"SECURITY: OAuth state session mismatch (expected %s, got %s)",
bound_session,
Expand All @@ -286,6 +398,7 @@ def validate_and_consume_oauth_state(

# State is valid – consume it to prevent reuse
del self._oauth_states[state]
self._save_oauth_states_to_disk()
logger.debug(
"Validated OAuth state %s",
state[:8] if len(state) > 8 else state,
Expand Down
Loading