Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
27 changes: 27 additions & 0 deletions .github/workflows/flake8-lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Lint Python code with Flake8

on:
push:
paths:
- 'src/**/*.py'
pull_request:
paths:
- 'src/**/*.py'

jobs:
flake8-lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install Flake8
run: pip install flake8

- name: Run Flake8
run: flake8 src/
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ RUN pip install hatchling \
# Expose the port the MCP server runs on
EXPOSE 8000

# Expose the port of the authentication callback server
EXPOSE 8001

CMD ["python", "src/server.py", "start", "--protocol", "http"]
123 changes: 35 additions & 88 deletions src/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import http.server
import socketserver
import urllib.parse
from fastapi import HTTPException
import requests
from pathlib import Path
from typing import Optional, Dict, Any, Tuple, TYPE_CHECKING
Expand All @@ -16,6 +17,7 @@
CLIENT_ID,
OAUTH_HOST,
AUTH_TIMEOUT_SECONDS,
REDIRECT_URI,
ROOT_DIR,
)
from src.config.app_config import app_config, AuthMethod
Expand Down Expand Up @@ -266,48 +268,48 @@ def authenticate(
# Use provided client_id or the default one
client_id = client_id or CLIENT_ID

# In HTTP mode, the server must not perform the OAuth flow
if app_config.server_mode == "http":
# Throw a 401 error if authentication fails
raise HTTPException(status_code=401, detail="Authentication failed")

# STDIO mode: keep the local browser-based flow
try:
# Discover OAuth server endpoints
print("Discovering OAuth server...")
oauth_config = discover_oauth_server(OAUTH_HOST)
authorization_endpoint = oauth_config.get("authorization_endpoint")
token_endpoint = oauth_config.get("token_endpoint")

if not authorization_endpoint or not token_endpoint:
print("Invalid OAuth server configuration")
return False, None

# Generate PKCE code verifier and challenge
code_verifier = generate_code_verifier()
code_challenge = generate_code_challenge(code_verifier)

# Generate state for security
state = generate_state()

# Find an available port for the redirect server
with socketserver.TCPServer(("127.0.0.1", 0), None) as s:
redirect_uri = REDIRECT_URI
parsed_uri = urllib.parse.urlparse(redirect_uri)
host = "127.0.0.1"
port = parsed_uri.port or 0
print(f"Using redirect URI: {redirect_uri}")
print(f"Using host: {host}, port: {port}")
with socketserver.TCPServer((host, port), None) as s:
port = s.server_address[1]

# Redirect URI
redirect_uri = f"http://127.0.0.1:{port}/callback"

# Create server class with additional attributes
class CallbackServer(socketserver.TCPServer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.received_callback = False
self.callback_params = None

# Create a custom handler factory
def handler(*args, **kwargs):
return AuthCallbackHandler(*args, **kwargs)

# Start a temporary web server to capture the callback
with CallbackServer(("127.0.0.1", port), handler) as httpd:
print(f"Starting authentication server on port {port}...")
print("Opening browser for authentication...")

# Prepare authorization URL
scopes = " ".join(ALWAYS_PRESENT_SCOPES)
auth_params = {
"client_id": client_id,
Expand All @@ -318,77 +320,50 @@ def handler(*args, **kwargs):
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

auth_url = f"{authorization_endpoint}?{urllib.parse.urlencode(auth_params)}"
print(f"Authenticating with client ID: {client_id}")
print(f"Auth URL: {auth_url}")

# Open browser to auth URL
webbrowser.open(auth_url)

# Set timeout
httpd.timeout = 1

# Serve until callback is received or timeout
start_time = time.time()
while not httpd.received_callback:
httpd.handle_request()
if time.time() - start_time > AUTH_TIMEOUT_SECONDS:
print("Authentication timed out")
return False, None

# Process callback parameters
if not httpd.callback_params:
print("No callback parameters received")
return False, None

# Check state parameter
if httpd.callback_params.get("state") != state:
print("State parameter mismatch, possible CSRF attack")
return False, None

# Extract authorization code
code = httpd.callback_params.get("code")
if not code:
print("No authorization code received")
return False, None

# Exchange code for tokens
token_data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
"code_verifier": code_verifier,
}

# Send token request
response = requests.post(
token_endpoint,
data=token_data,
headers={"Content-Type": ("application/x-www-form-urlencoded")},
timeout=10,
)
response.raise_for_status()

# Parse token response
token_response = response.json()

# Add expires_at if we got expires_in
if "expires_in" in token_response and "expires_at" not in token_response:
token_response["expires_at"] = (
datetime.now().timestamp() + token_response["expires_in"]
)

# Create token set
token_set = TokenSet(token_response)

# Only save credentials to file in stdio mode
if app_config.server_mode == "stdio":
save_credentials(token_set)

return True, token_set

except Exception as e:
print(f"Authentication failed: {e}")
return False, None
Expand All @@ -412,68 +387,40 @@ def get_authentication_token(
"""
server_mode = app_config.server_mode # "stdio" or "http"

print(f"Server mode: {server_mode}")
print(f"Client ID: {client_id}")
print(f"HTTP Authorization header: {http_auth_header}")
print(f"App config auth token: {app_config.get_auth_token()}")

# For HTTP mode, first check the Authorization header
if server_mode == "http" and http_auth_header:
print("Using token from HTTP Authorization header")
if http_auth_header.startswith("Bearer "):
if server_mode == "http":
# In HTTP mode, always require Authorization header
if http_auth_header and http_auth_header.startswith("Bearer "):
token = http_auth_header[7:]
print("Using token from HTTP Authorization header")
app_config.set_auth_token(token, AuthMethod.JWT_TOKEN)
# Optionally: validate the token here (signature, expiry, etc.)
return token
# No valid token provided: server must not attempt to authenticate, just return None
return None

# Next, check for existing token in app_config
# STDIO mode: local/desktop flow
api_key = app_config.get_auth_token()
auth_method = app_config.get_auth_method()

if api_key:
print(f"Using existing authentication token (type: {auth_method.name})")
return api_key

# For stdio mode, check saved credentials file
if server_mode == "stdio":
credentials = load_credentials()
if credentials and "token_set" in credentials:
token_set = TokenSet(credentials["token_set"])

# If token is expired, try to refresh it
if token_set.is_expired() and token_set.refresh_token:
print("Access token expired, refreshing...")
refreshed_token_set = refresh_token(token_set, client_id or CLIENT_ID)
if refreshed_token_set:
token_set = refreshed_token_set
# Update app config with the refreshed token
app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH)
else:
print("Token refresh failed, proceeding to re-authentication")

# If we have a valid token, use it
if not token_set.is_expired() and token_set.access_token:
print("Using saved OAuth token.")
credentials = load_credentials()
if credentials and "token_set" in credentials:
token_set = TokenSet(credentials["token_set"])
if token_set.is_expired() and token_set.refresh_token:
refreshed_token_set = refresh_token(token_set, client_id or CLIENT_ID)
if refreshed_token_set:
token_set = refreshed_token_set
app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH)
return token_set.access_token
if not token_set.is_expired() and token_set.access_token:
app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH)
return token_set.access_token

# If no valid credentials found, launch browser authentication
print("No API key or valid authentication token found.")
# If no valid credentials found, launch browser authentication (stdio only)
success, token_set = authenticate(client_id)

if success and token_set and token_set.access_token:
print("Authentication successful!")
app_config.set_auth_token(token_set.access_token, AuthMethod.OAUTH)

# Only save to credentials file in stdio mode
# In HTTP mode, we just keep it in memory (app_config)
if server_mode == "stdio" and token_set:
save_credentials(token_set)

save_credentials(token_set)
return token_set.access_token
else:
print("Authentication failed. Please try again or provide an API key.")
return None
return None


def get_oauth_provider() -> Optional["SingleStoreOAuthProvider"]:
Expand Down
6 changes: 5 additions & 1 deletion src/config/app_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from enum import Enum

from src.config.config import SINGLESTORE_ORG_ID, SINGLESTORE_ORG_NAME
from src.config.config import SERVER_MODE, SINGLESTORE_ORG_ID, SINGLESTORE_ORG_NAME


class AuthMethod(Enum):
Expand Down Expand Up @@ -410,3 +410,7 @@ def server_mode(self, mode: str):
)
app_config.set_organization(SINGLESTORE_ORG_ID, SINGLESTORE_ORG_NAME)
print(app_config.get_organization())

if SERVER_MODE:
app_config.server_mode = SERVER_MODE
print(f"Server mode set to {SERVER_MODE}")
3 changes: 3 additions & 0 deletions src/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
SINGLESTORE_API_BASE_URL = "https://api.singlestore.com"
SINGLESTORE_GRAPHQL_PUBLIC_ENDPOINT = "https://backend.singlestore.com/public"

SERVER_MODE = os.getenv("SERVER_MODE", "stdio")


# The root directory of the project
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
Expand All @@ -18,3 +20,4 @@
OAUTH_HOST = os.getenv("OAUTH_HOST", "https://authsvc.singlestore.com/")
AUTH_TIMEOUT_SECONDS = 60 # In seconds
CLIENT_URI = os.getenv("CLIENT_URI", "http://localhost:8000")
REDIRECT_URI = os.getenv("REDIRECT_URI", "http://localhost:8001/callback")
10 changes: 2 additions & 8 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ def main():
nargs="?",
help="SingleStore API key (optional, will use web auth if not provided)",
)
start_parser.add_argument(
"--protocol",
default="stdio",
choices=["stdio", "sse", "http"],
help="Protocol to run the server on (default: stdio)",
)
start_parser.add_argument(
"--port",
default=8000,
Expand Down Expand Up @@ -129,7 +123,7 @@ def main():
sys.exit(init_command(api_key, auth_token, args.client))
elif args.command == "start":
# Ensure protocol is set for the start command
protocol = getattr(args, "protocol", "stdio")
protocol = os.getenv("SERVER_MODE", "stdio")

if getattr(args, "api_key", None):
print(
Expand All @@ -156,7 +150,7 @@ def main():
f"Running Streamable HTTP server with protocol {protocol.upper()} on port {args.port}"
)
app_config.set_server_port(args.port)
app_config.server_mode = "stdio"
app_config.server_mode = "http"
else:
print(f"Running server with protocol {protocol.upper()}")
app_config.server_mode = "stdio"
Expand Down
1 change: 0 additions & 1 deletion src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def select_organization():
Dictionary with the selected organization ID and name
"""

print("select_org: ", app_config.organization_id)
# If organization is already selected, return it
if app_config.is_organization_selected():
return {
Expand Down
Loading
Loading