diff --git a/src/api/common.py b/src/api/common.py index 5b8d326..593fe3a 100644 --- a/src/api/common.py +++ b/src/api/common.py @@ -373,7 +373,7 @@ def get_access_token() -> str: logger.debug(f"Getting access token, is_remote: {settings.is_remote}") - access_token: str + access_token: str = "" if isinstance(settings, config.RemoteSettings) and settings.auth_provider: request = get_session_request() access_token = request.headers.get("Authorization", "").replace("Bearer ", "") @@ -382,6 +382,8 @@ def get_access_token() -> str: ) if real_token: access_token = real_token.token + else: + access_token = "" # Clear access_token if auth provider returns None logger.debug( f"Remote access token retrieved (length: {len(access_token) if access_token else 0})" ) diff --git a/src/auth/provider.py b/src/auth/provider.py index 7c10fa9..6d39818 100644 --- a/src/auth/provider.py +++ b/src/auth/provider.py @@ -239,7 +239,7 @@ async def exchange_authorization_code( ) raise HTTPException(400, "Failed to exchange code for token") - data = response.json() + data = await response.json() if "error" in data: raise HTTPException(400, data.get("error_description", data["error"])) @@ -349,7 +349,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: token=token, client_id=client_id, scopes=scopes, - expires_at=expires_at, + expires_at=int(expires_at) if expires_at else None, ) async def load_refresh_token( diff --git a/tests/unit/auth/README.md b/tests/unit/auth/README.md new file mode 100644 index 0000000..eff0b94 --- /dev/null +++ b/tests/unit/auth/README.md @@ -0,0 +1,312 @@ +# Remote Mode Authentication Test Suite + +This test suite provides comprehensive testing for the remote mode authentication flow in the SingleStore MCP server. The tests cover the complete OAuth 2.0 + PKCE authentication workflow, token management, and API request authentication. + +## Test Structure + +### ๐Ÿ“ Test Files + +#### `test_remote_auth_flow.py` +- **Purpose**: Tests the core OAuth provider implementation (`SingleStoreOAuthProvider`) +- **Coverage**: + - OAuth provider initialization and configuration + - Authorization code generation and exchange + - Token storage and retrieval from database + - Token validation and expiration handling + - Complete end-to-end authentication flow + - Error handling scenarios + +#### `test_oauth_proxy_integration.py` +- **Purpose**: Tests the OAuth proxy integration with FastMCP (`SingleStoreOAuthProxy`) +- **Coverage**: + - OpenID Connect discovery and configuration + - JWT token verification with JWKS + - Proxy provider initialization + - Integration with RemoteSettings + - Error handling for proxy scenarios + +#### `test_remote_api_auth.py` +- **Purpose**: Tests how authentication is used in API requests +- **Coverage**: + - Token retrieval from auth provider + - API request building with authentication headers + - Error handling for invalid/expired tokens + - Session context management + - Concurrent request handling + +#### `conftest.py` +- **Purpose**: Shared test fixtures and utilities +- **Provides**: + - Mock objects for settings, clients, tokens + - Database connection mocks + - HTTP request/response mocks + - Common test data and utilities + +## ๐Ÿš€ Running Tests + +### Run All Tests +```bash +# From the test directory +python test_runner.py + +# Or using pytest directly +pytest tests/unit/auth/ -v +``` + +### Run Specific Test File +```bash +python test_runner.py --test test_remote_auth_flow +``` + +### Run Specific Test Method +```bash +python test_runner.py --test "test_provider_initialization" +``` + +### Run with Coverage +```bash +python test_runner.py --coverage +``` + +### Development Testing +```bash +# Watch mode (if pytest-watch is installed) +pytest tests/unit/auth/ --watch + +# Run only failed tests +pytest tests/unit/auth/ --lf + +# Run with debug output +pytest tests/unit/auth/ -v -s --tb=long +``` + +## ๐Ÿ” Test Scenarios Covered + +### Authentication Flow Tests + +1. **OAuth Provider Initialization** + - Database schema creation + - Settings validation + - PKCE code generation + +2. **Authorization Flow** + - Client registration + - Authorization URL generation + - State management + - Callback handling + +3. **Token Exchange** + - Authorization code validation + - Token exchange with SingleStore + - Token storage in database + - Error handling for invalid codes + +4. **Token Management** + - Token retrieval and validation + - Expiration handling + - Token revocation + - Database cleanup + +### Proxy Integration Tests + +1. **OpenID Connect Discovery** + - Endpoint discovery + - Configuration validation + - Error handling for invalid configs + +2. **JWT Token Verification** + - JWKS retrieval and caching + - Token signature validation + - Expiration checking + - Error handling for malformed tokens + +3. **FastMCP Integration** + - Auth provider setup + - Route registration + - Request handling + +### API Authentication Tests + +1. **Token Retrieval** + - Remote mode token lookup + - Local mode fallbacks + - Error handling for missing tokens + +2. **Request Authentication** + - Authorization header handling + - Token validation in requests + - Error responses for unauthorized requests + +3. **Concurrent Access** + - Multiple simultaneous requests + - Token sharing across requests + - Race condition handling + +## ๐ŸŽฏ Test Coverage Areas + +### Components Tested +- `src.auth.provider.SingleStoreOAuthProvider` +- `src.auth.proxy_provider.SingleStoreOAuthProxy` +- `src.api.common.get_access_token` +- `src.api.common.build_request` +- `src.config.config.RemoteSettings` + +### Authentication Scenarios +- โœ… Valid token authentication +- โœ… Expired token handling +- โœ… Invalid token rejection +- โœ… Missing token handling +- โœ… Authorization code flow +- โœ… Token refresh scenarios +- โœ… Database error handling +- โœ… Network error handling +- โœ… Malformed request handling + +### Error Conditions +- โœ… Database connection failures +- โœ… SingleStore API errors +- โœ… Invalid OAuth configurations +- โœ… Network timeouts +- โœ… Malformed JWT tokens +- โœ… JWKS retrieval failures +- โœ… Concurrent access conflicts + +## ๐Ÿ”ง Mock Strategy + +### Database Mocking +```python +# Database connections and cursors are mocked +mock_conn, mock_cursor = mock_database_connection +mock_cursor.fetchone.return_value = [token_data] +``` + +### HTTP Request Mocking +```python +# External HTTP requests are mocked +@patch('src.auth.provider.create_mcp_http_client') +async def test_token_exchange(mock_http_client): + mock_client.post.return_value = mock_token_response +``` + +### Settings Mocking +```python +# Configuration is mocked for isolated testing +mock_settings = mock_remote_settings() +mock_settings.client_id = "test-client-id" +``` + +## ๐Ÿ“Š Test Metrics + +### Coverage Goals +- **Unit Test Coverage**: >90% for authentication components +- **Integration Coverage**: >80% for auth flow scenarios +- **Error Handling**: 100% for critical error paths + +### Performance Benchmarks +- Token validation: <10ms per request +- Authorization flow: <5 seconds end-to-end +- Database operations: <100ms per query + +## ๐Ÿ› Debugging Tests + +### Common Issues + +1. **Import Errors** + ```bash + # Make sure you're in the project root + cd /path/to/mcp-server-singlestore + python -m pytest tests/unit/auth/ + ``` + +2. **Mock Configuration** + ```python + # Check that mocks are properly configured + assert mock_function.called + assert mock_function.call_count == 1 + ``` + +3. **Async Test Issues** + ```python + # Make sure async tests are properly marked + @pytest.mark.asyncio + async def test_async_function(): + result = await async_function() + ``` + +### Test Debugging Commands +```bash +# Run with maximum verbosity +pytest tests/unit/auth/ -vvv -s + +# Run single test with debugging +pytest tests/unit/auth/test_remote_auth_flow.py::test_specific_function -vvv -s --tb=long + +# Run with pdb debugger +pytest tests/unit/auth/ --pdb +``` + +## ๐Ÿ”„ Continuous Integration + +### CI Pipeline Tests +```yaml +# Example GitHub Actions configuration +- name: Run Remote Auth Tests + run: | + python -m pytest tests/unit/auth/ \ + --cov=src.auth \ + --cov=src.api.common \ + --cov-report=xml \ + --junit-xml=test-results.xml +``` + +### Local Pre-commit Testing +```bash +# Run before committing changes +./tests/unit/auth/test_runner.py --coverage +``` + +## ๐Ÿ“ˆ Extending Tests + +### Adding New Test Cases + +1. **Create test method in appropriate file**: + ```python + async def test_new_scenario(self, oauth_provider, sample_client): + # Test implementation + result = await oauth_provider.new_method() + assert result.is_valid + ``` + +2. **Add fixtures if needed**: + ```python + @pytest.fixture + def new_test_fixture(): + return MockObject() + ``` + +3. **Update test runner if new file created**: + ```python + test_files = [ + "test_remote_auth_flow.py", + "test_oauth_proxy_integration.py", + "test_remote_api_auth.py", + "test_new_feature.py" # Add new file + ] + ``` + +### Test Guidelines + +1. **Use descriptive test names** +2. **Test one scenario per test method** +3. **Use fixtures for common setup** +4. **Mock external dependencies** +5. **Assert both positive and negative cases** +6. **Include error handling tests** + +## ๐Ÿ“š Related Documentation + +- [OAuth 2.0 RFC 6749](https://tools.ietf.org/html/rfc6749) +- [PKCE RFC 7636](https://tools.ietf.org/html/rfc7636) +- [OpenID Connect Core](https://openid.net/specs/openid-connect-core-1_0.html) +- [FastMCP Documentation](https://github.com/modelcontextprotocol/python-sdk) +- [pytest Documentation](https://docs.pytest.org/) diff --git a/tests/unit/auth/conftest.py b/tests/unit/auth/conftest.py new file mode 100644 index 0000000..20a92d5 --- /dev/null +++ b/tests/unit/auth/conftest.py @@ -0,0 +1,339 @@ +""" +Shared test fixtures and configuration for remote authentication tests. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +import time + +from mcp.shared.auth import OAuthClientInformationFull +from mcp.server.auth.provider import AccessToken +from src.config.config import RemoteSettings + + +@pytest.fixture +def sample_oauth_client(): + """Create a sample OAuth client for testing.""" + return OAuthClientInformationFull( + client_id="test-client-12345", + client_name="Test OAuth Client", + redirect_uris=["http://localhost:3000/callback", "http://localhost:8080/auth"], + ) + + +@pytest.fixture +def sample_access_token(): + """Create a sample access token for testing.""" + return AccessToken( + token="sample-access-token-123", + client_id="test-client-12345", + scopes=["openid", "profile", "email"], + expires_at=int(time.time()) + 3600, # Expires in 1 hour + ) + + +@pytest.fixture +def expired_access_token(): + """Create an expired access token for testing.""" + return AccessToken( + token="expired-access-token-123", + client_id="test-client-12345", + scopes=["openid", "profile"], + expires_at=int(time.time()) - 3600, # Expired 1 hour ago + ) + + +@pytest.fixture +def mock_database_connection(): + """Create a mock database connection and cursor.""" + mock_conn = Mock() + mock_cursor = Mock() + + # Setup connection context manager + mock_conn.__enter__ = Mock(return_value=mock_conn) + mock_conn.__exit__ = Mock(return_value=None) + mock_conn.cursor.return_value = mock_cursor + mock_conn.commit = Mock() + + # Setup cursor methods + mock_cursor.execute = Mock() + mock_cursor.fetchone = Mock() + mock_cursor.fetchall = Mock() + + return mock_conn, mock_cursor + + +@pytest.fixture +def mock_remote_settings(): + """Create comprehensive mock RemoteSettings for testing.""" + settings = Mock(spec=RemoteSettings) + settings.is_remote = True + settings.auth_provider = None # Add missing auth_provider attribute + settings.client_id = "test-mcp-client-id" + settings.org_id = "test-organization-id" + settings.oauth_db_url = "mysql://test:test@localhost/test_oauth_db" + settings.callback_path = "http://localhost:8010/oauth/callback" + settings.required_scopes = ["openid", "profile", "email"] + settings.singlestore_auth_url = "https://authsvc.singlestore.com/authorize" + settings.singlestore_token_url = "https://authsvc.singlestore.com/token" + settings.s2_api_base_url = "https://api.singlestore.com" + settings.server_url = "http://localhost:8010" + settings.issuer_url = "https://authsvc.singlestore.com" + settings.segment_write_key = "test-segment-key" + settings.jwt_signing_key = "test-jwt-signing-key-for-remote" + + # Mock auth provider + settings.auth_provider = AsyncMock() + + # Mock analytics manager + settings.analytics_manager = Mock() + settings.analytics_manager.identify = Mock() + settings.analytics_manager.track_event = Mock() + + return settings + + +@pytest.fixture +def mock_openid_configuration(): + """Create a mock OpenID Connect configuration response.""" + return { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + "userinfo_endpoint": "https://authsvc.singlestore.com/userinfo", + "scopes_supported": ["openid", "profile", "email", "phone", "address"], + "response_types_supported": ["code", "token", "id_token"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["plain", "S256"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + ], + } + + +@pytest.fixture +def mock_singlestore_token_response(): + """Create a mock token response from SingleStore OAuth.""" + return { + "access_token": "singlestore_access_token_12345", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "singlestore_refresh_token_67890", + "scope": "openid profile email", + } + + +@pytest.fixture +def mock_singlestore_error_response(): + """Create a mock error response from SingleStore OAuth.""" + return { + "error": "invalid_grant", + "error_description": "The provided authorization grant is invalid, expired, revoked, or does not match the redirection URI.", + } + + +@pytest.fixture +def mock_user_info_response(): + """Create a mock user info response from SingleStore API.""" + return [ + { + "userID": "test-user-12345", + "email": "test@example.com", + "firstName": "Test", + "lastName": "User", + "organizationID": "test-organization-id", + } + ] + + +@pytest.fixture +def mock_http_requests(): + """Mock all HTTP requests made during testing.""" + with ( + patch("requests.get") as mock_get, + patch("requests.post") as mock_post, + patch("requests.put") as mock_put, + patch("requests.patch") as mock_patch, + patch("requests.delete") as mock_delete, + ): + yield { + "get": mock_get, + "post": mock_post, + "put": mock_put, + "patch": mock_patch, + "delete": mock_delete, + } + + +@pytest.fixture +def mock_fastmcp_context(): + """Create a mock FastMCP context for testing.""" + context = Mock() + context.info = AsyncMock() + context.error = AsyncMock() + context.warning = AsyncMock() + context.request_context = Mock() + context.request_context.request = Mock() + context.request_context.request.headers = {"Authorization": "Bearer test-token"} + context.request_context.lifespan_context = {"org_id": "test-org-123"} + return context + + +@pytest.fixture(autouse=True) +def reset_context_vars(): + """Reset context variables before each test to ensure isolation.""" + from src.config.config import _settings_ctx, _user_id_ctx, _app_ctx + + # Store original values + original_settings = _settings_ctx.get(None) + original_user_id = _user_id_ctx.get(None) + original_app = _app_ctx.get(None) + + # Reset to None + _settings_ctx.set(None) + _user_id_ctx.set(None) + _app_ctx.set(None) + + yield + + # Restore original values + if original_settings is not None: + _settings_ctx.set(original_settings) + if original_user_id is not None: + _user_id_ctx.set(original_user_id) + if original_app is not None: + _app_ctx.set(original_app) + + +class MockAsyncHttpClient: + """Mock async HTTP client for testing.""" + + def __init__(self, response_data=None, status_code=200, raise_exception=None): + self.response_data = response_data or {} + self.status_code = status_code + self.raise_exception = raise_exception + self.request_history = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def post(self, url, **kwargs): + self.request_history.append({"method": "POST", "url": url, "kwargs": kwargs}) + + if self.raise_exception: + raise self.raise_exception + + response = AsyncMock() + response.status_code = self.status_code + response.json.return_value = self.response_data + response.text = str(self.response_data) + return response + + async def get(self, url, **kwargs): + self.request_history.append({"method": "GET", "url": url, "kwargs": kwargs}) + + if self.raise_exception: + raise self.raise_exception + + response = AsyncMock() + response.status_code = self.status_code + response.json.return_value = self.response_data + response.text = str(self.response_data) + return response + + +@pytest.fixture +def mock_async_http_client(): + """Create a mock async HTTP client factory.""" + + def create_client(response_data=None, status_code=200, raise_exception=None): + return MockAsyncHttpClient(response_data, status_code, raise_exception) + + return create_client + + +# Common test data +VALID_JWT_PAYLOAD = { + "client_id": "test-client-12345", + "exp": int(time.time()) + 3600, # Expires in 1 hour + "iat": int(time.time()), + "iss": "https://authsvc.singlestore.com", + "aud": ["test-client-12345"], + "sub": "test-user-12345", + "scope": "openid profile email", +} + +EXPIRED_JWT_PAYLOAD = { + "client_id": "test-client-12345", + "exp": int(time.time()) - 3600, # Expired 1 hour ago + "iat": int(time.time()) - 7200, # Issued 2 hours ago + "iss": "https://authsvc.singlestore.com", + "aud": ["test-client-12345"], + "sub": "test-user-12345", + "scope": "openid profile email", +} + + +@pytest.fixture +def valid_jwt_payload(): + """Provide a valid JWT payload for testing.""" + return VALID_JWT_PAYLOAD.copy() + + +@pytest.fixture +def expired_jwt_payload(): + """Provide an expired JWT payload for testing.""" + return EXPIRED_JWT_PAYLOAD.copy() + + +@pytest.fixture +def mock_jwks_response(): + """Create a mock JWKS (JSON Web Key Set) response.""" + return { + "keys": [ + { + "kty": "EC", + "use": "sig", + "crv": "P-521", + "kid": "test-key-id", + "x": "test-x-coordinate", + "y": "test-y-coordinate", + "alg": "ES512", + } + ] + } + + +# Test utilities +def create_mock_auth_code(client_id="test-client", code="test-code", expires_in=300): + """Utility function to create mock authorization code data.""" + import json + + return [ + client_id, # client_id + "http://localhost:3000/callback", # redirect_uri + True, # redirect_uri_provided_explicitly + time.time() + expires_in, # expires_at + json.dumps(["openid", "profile", "email"]), # scopes + "test-code-challenge", # code_challenge + ] + + +def create_mock_token_data(client_id="test-client", expires_in=3600): + """Utility function to create mock token data.""" + import json + + return [ + client_id, # client_id + json.dumps(["openid", "profile", "email"]), # scopes + time.time() + expires_in, # expires_at + ] + + +# Mark all tests as asyncio by default for async test functions +pytest_plugins = ["pytest_asyncio"] diff --git a/tests/unit/auth/test_oauth_proxy_integration.py b/tests/unit/auth/test_oauth_proxy_integration.py new file mode 100644 index 0000000..d30f4e2 --- /dev/null +++ b/tests/unit/auth/test_oauth_proxy_integration.py @@ -0,0 +1,522 @@ +""" +Test suite for OAuth proxy integration with FastMCP. + +This module tests the integration between the SingleStoreOAuthProxy +and FastMCP server, including: +- Proxy provider initialization +- Token verification with JWT +- FastMCP auth middleware integration +- Error handling in proxy scenarios +""" + +import json +import jwt +import pytest +from unittest.mock import Mock, AsyncMock, patch +from datetime import datetime + +import requests +from mcp.server.auth.provider import AccessToken +from fastmcp.server.auth.oauth_proxy import OAuthProxy + +from src.auth.proxy_provider import SingleStoreOAuthProxy +from src.config.config import RemoteSettings + + +class TestSingleStoreOAuthProxy: + """Test cases for SingleStoreOAuthProxy initialization and configuration.""" + + @pytest.fixture + def mock_openid_config(self): + """Mock OpenID Connect configuration response.""" + return { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + "scopes_supported": ["openid", "profile", "email"], + } + + @pytest.fixture + def mock_requests_get(self, mock_openid_config): + """Mock requests.get for OpenID configuration discovery.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = mock_openid_config + mock_get.return_value = mock_response + yield mock_get + + def test_proxy_initialization_success(self, mock_requests_get): + """Test successful OAuth proxy initialization with OpenID discovery.""" + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + client_secret="test-client-secret", + base_url="http://localhost:8010", + jwt_signing_key="test-jwt-key", + ) + + assert proxy.issuer_url == "https://authsvc.singlestore.com" + assert proxy.client_id == "test-client-id" + assert proxy.client_secret == "test-client-secret" + assert proxy.base_url == "http://localhost:8010" + assert proxy.jwt_signing_key == "test-jwt-key" + + # Should have fetched OpenID config + mock_requests_get.assert_called_once() + assert proxy._config is not None + + # Should have created verifier and provider + assert proxy._verifier is not None + assert proxy.provider is not None + + def test_proxy_initialization_discovery_failure(self): + """Test proxy initialization when OpenID discovery fails.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_get.side_effect = requests.RequestException("Connection failed") + + with pytest.raises( + RuntimeError, match="Failed to fetch OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://invalid.example.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_proxy_initialization_missing_endpoints(self, mock_requests_get): + """Test proxy initialization with incomplete OpenID config.""" + # Mock incomplete config + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + # Missing authorization_endpoint and token_endpoint + } + mock_get.return_value = mock_response + + with pytest.raises( + RuntimeError, match="Missing required fields in OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_proxy_initialization_missing_jwt_key(self, mock_requests_get): + """Test proxy initialization without JWT signing key.""" + with pytest.raises(RuntimeError, match="JWT signing key is not set"): + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key=None, + ) + + def test_get_provider_returns_oauth_proxy(self, mock_requests_get): + """Test that get_provider returns a properly configured OAuthProxy.""" + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + provider = proxy.get_provider() + + assert isinstance(provider, OAuthProxy) + + +class TestCustomJWTVerifier: + """Test cases for the custom JWT token verifier.""" + + @pytest.fixture + def mock_jwks_client(self): + """Mock PyJWKClient for JWT verification.""" + with patch("src.auth.proxy_provider.PyJWKClient") as mock_client_class: + mock_client = Mock() + mock_signing_key = Mock() + mock_signing_key.key = "test-signing-key" + mock_client.get_signing_key_from_jwt.return_value = mock_signing_key + mock_client_class.return_value = mock_client + yield mock_client + + @pytest.fixture + def jwt_verifier(self, mock_jwks_client): + """Create a JWT verifier instance for testing.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + return proxy._verifier + + @patch("src.auth.proxy_provider.jwt.decode") + @pytest.mark.asyncio + async def test_verify_token_success( + self, mock_jwt_decode, jwt_verifier, mock_jwks_client + ): + """Test successful JWT token verification.""" + # Mock decoded token payload + mock_jwt_decode.return_value = { + "client_id": "test-client-id", + "exp": int(datetime.now().timestamp()) + 3600, + "aud": ["test-client-id"], + "iss": "https://authsvc.singlestore.com", + "sub": "user-123", + } + + access_token = await jwt_verifier.verify_token("valid.jwt.token") + + assert access_token is not None + assert isinstance(access_token, AccessToken) + assert access_token.client_id == "test-client-id" + assert access_token.scopes == ["openid"] + assert access_token.resource == "test-client-id" + + # Should have called JWT decode with proper parameters + mock_jwt_decode.assert_called_once() + args, kwargs = mock_jwt_decode.call_args + assert kwargs["audience"] == "test-client-id" + assert kwargs["algorithms"] == ["ES512"] + + @patch("src.auth.proxy_provider.jwt.decode") + @pytest.mark.asyncio + async def test_verify_token_expired( + self, mock_jwt_decode, jwt_verifier, mock_jwks_client + ): + """Test JWT token verification with expired token.""" + mock_jwt_decode.side_effect = jwt.ExpiredSignatureError("Token has expired") + + access_token = await jwt_verifier.verify_token("expired.jwt.token") + + assert access_token is None + + @patch("src.auth.proxy_provider.jwt.decode") + @pytest.mark.asyncio + async def test_verify_token_invalid_signature( + self, mock_jwt_decode, jwt_verifier, mock_jwks_client + ): + """Test JWT token verification with invalid signature.""" + mock_jwt_decode.side_effect = jwt.InvalidSignatureError("Invalid signature") + + access_token = await jwt_verifier.verify_token("invalid.jwt.token") + + assert access_token is None + + @patch("src.auth.proxy_provider.jwt.decode") + @pytest.mark.asyncio + async def test_verify_token_malformed( + self, mock_jwt_decode, jwt_verifier, mock_jwks_client + ): + """Test JWT token verification with malformed token.""" + mock_jwt_decode.side_effect = jwt.DecodeError("Invalid token format") + + access_token = await jwt_verifier.verify_token("malformed.token") + + assert access_token is None + + @pytest.mark.asyncio + async def test_verify_token_jwks_error(self, jwt_verifier): + """Test JWT token verification when JWKS retrieval fails.""" + with patch.object(jwt_verifier, "jwks_client") as mock_client: + mock_client.get_signing_key_from_jwt.side_effect = Exception("JWKS error") + + # Should raise the JWKS error + with pytest.raises(Exception, match="JWKS error"): + await jwt_verifier.verify_token("test.jwt.token") + + +class TestOAuthProxyIntegration: + """Test integration between OAuth proxy and MCP components.""" + + @pytest.fixture + def mock_storage(self): + """Mock client storage for testing.""" + storage = AsyncMock() + storage.get = AsyncMock(return_value=None) + storage.set = AsyncMock() + storage.delete = AsyncMock() + return storage + + @pytest.fixture + def oauth_proxy_with_storage(self, mock_storage): + """Create OAuth proxy with mocked storage.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + client_storage=mock_storage, + ) + + return proxy + + def test_proxy_with_encrypted_storage(self, mock_storage): + """Test OAuth proxy with encrypted client storage.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + # Create proxy with encryption enabled (default) + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + client_storage=mock_storage, + encrypt_db=True, + ) + + assert proxy.encrypt_db is True + assert proxy.provider is not None + + def test_proxy_with_custom_scopes(self): + """Test OAuth proxy with custom valid scopes.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + custom_scopes = ["openid", "profile", "email", "custom_scope"] + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + valid_scopes=custom_scopes, + ) + + assert proxy.valid_scopes == custom_scopes + + def test_proxy_with_custom_redirect_path(self): + """Test OAuth proxy with custom redirect path.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + custom_redirect = "/custom/oauth/callback" + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + redirect_path=custom_redirect, + ) + + assert proxy.redirect_path == custom_redirect + + +class TestRemoteSettingsIntegration: + """Test integration with RemoteSettings configuration.""" + + @pytest.fixture + def sample_remote_settings(self): + """Create sample remote settings for testing.""" + return { + "transport": "sse", + "is_remote": True, + "issuer_url": "https://authsvc.singlestore.com", + "required_scopes": ["openid", "profile", "email"], + "server_url": "http://localhost:8010", + "client_id": "test-client-id-uuid", + "callback_path": "/oauth/callback", + "oauth_db_url": "mysql://test:test@localhost/oauth_test", + "segment_write_key": "test-segment-key", + "jwt_signing_key": "test-jwt-signing-key", + } + + @patch("src.config.config.AnalyticsManager") + @patch("src.config.config.SingleStoreKV") + def test_remote_settings_creates_auth_provider( + self, mock_kv, mock_analytics, sample_remote_settings + ): + """Test that RemoteSettings automatically creates auth provider.""" + mock_kv_instance = Mock() + mock_kv.return_value = mock_kv_instance + + mock_analytics_instance = Mock() + mock_analytics.return_value = mock_analytics_instance + + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + settings = RemoteSettings(**sample_remote_settings) + + assert settings.auth_provider is not None + assert isinstance(settings.auth_provider, OAuthProxy) + assert settings.singlestore_kv == mock_kv_instance + assert settings.analytics_manager == mock_analytics_instance + + @patch("src.config.config.AnalyticsManager") + def test_remote_settings_without_oauth_db_url( + self, mock_analytics, sample_remote_settings + ): + """Test RemoteSettings behavior without oauth_db_url.""" + mock_analytics_instance = Mock() + mock_analytics.return_value = mock_analytics_instance + + # Remove oauth_db_url from settings + settings_without_db = sample_remote_settings.copy() + settings_without_db["oauth_db_url"] = None + + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + settings = RemoteSettings(**settings_without_db) + + assert settings.singlestore_kv is None + assert settings.auth_provider is not None + + +class TestErrorScenarios: + """Test various error scenarios in OAuth proxy operation.""" + + def test_openid_discovery_timeout(self): + """Test OpenID discovery with timeout.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_get.side_effect = requests.Timeout("Request timeout") + + with pytest.raises( + RuntimeError, match="Failed to fetch OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://slow.example.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_openid_discovery_http_error(self): + """Test OpenID discovery with HTTP error.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError( + "404 Not Found" + ) + mock_get.return_value = mock_response + + with pytest.raises( + RuntimeError, match="Failed to fetch OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://notfound.example.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_invalid_json_in_openid_config(self): + """Test OpenID discovery with invalid JSON response.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + mock_get.return_value = mock_response + + with pytest.raises( + RuntimeError, match="Failed to fetch OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://badjson.example.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_missing_issuer_in_openid_config(self): + """Test OpenID config missing required issuer field.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + # Missing "issuer" field + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + mock_get.return_value = mock_response + + with pytest.raises( + RuntimeError, match="Missing required fields in OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) + + def test_missing_jwks_uri_in_openid_config(self): + """Test OpenID config missing JWKS URI.""" + with patch("src.auth.proxy_provider.requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + # Missing "jwks_uri" field + } + mock_get.return_value = mock_response + + with pytest.raises( + RuntimeError, match="Missing required fields in OpenID configuration" + ): + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com", + client_id="test-client-id", + jwt_signing_key="test-jwt-key", + ) diff --git a/tests/unit/auth/test_remote_api_auth.py b/tests/unit/auth/test_remote_api_auth.py new file mode 100644 index 0000000..db37e18 --- /dev/null +++ b/tests/unit/auth/test_remote_api_auth.py @@ -0,0 +1,556 @@ +""" +Test suite for remote mode API authentication integration. + +This module tests how authentication tokens are used in API requests +in remote mode, including: +- Token retrieval from auth provider +- API request authentication +- Error handling for invalid/expired tokens +- Session context management +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from starlette.exceptions import HTTPException +from starlette.requests import Request +from mcp.server.auth.provider import AccessToken + +from src.api.common import get_access_token, build_request +from src.config.config import RemoteSettings, LocalSettings + + +class TestGetAccessTokenRemoteMode: + """Test access token retrieval in remote mode.""" + + @pytest.fixture + def mock_remote_settings(self): + """Create mock RemoteSettings with auth provider.""" + settings = Mock(spec=RemoteSettings) + settings.is_remote = True + settings.auth_provider = AsyncMock() + return settings + + @pytest.fixture + def mock_local_settings(self): + """Create mock LocalSettings.""" + settings = Mock(spec=LocalSettings) + settings.is_remote = False + settings.api_key = None + settings.jwt_token = "local-jwt-token" + return settings + + @pytest.fixture + def mock_request(self): + """Create mock HTTP request with Authorization header.""" + request = Mock(spec=Request) + request.headers = {"Authorization": "Bearer client-token-123"} + return request + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + @patch("src.api.common.async_to_sync") + def test_get_access_token_remote_success( + self, + mock_async_to_sync, + mock_get_settings, + mock_get_session_request, + mock_remote_settings, + mock_request, + ): + """Test successful token retrieval in remote mode.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_session_request.return_value = mock_request + + # Mock auth provider returning access token + mock_access_token = AccessToken( + token="real-singlestore-token", + client_id="test-client", + scopes=["openid", "profile"], + expires_at=9999999999, + ) + + # Mock async_to_sync to return a sync function that returns the access token + mock_sync_function = Mock(return_value=mock_access_token) + mock_async_to_sync.return_value = mock_sync_function + + # Call function + result = get_access_token() + + # Assertions + assert result == "real-singlestore-token" + mock_get_session_request.assert_called_once() + mock_sync_function.assert_called_once_with("client-token-123") + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + @patch("src.api.common.async_to_sync") + def test_get_access_token_remote_no_token_in_provider( + self, + mock_async_to_sync, + mock_get_settings, + mock_get_session_request, + mock_remote_settings, + mock_request, + ): + """Test remote mode when auth provider returns None (token not found/expired).""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_session_request.return_value = mock_request + + # Mock auth provider returning None (token invalid/expired) + mock_sync_function = Mock(return_value=None) + mock_async_to_sync.return_value = mock_sync_function + + # Should raise HTTPException for unauthorized + with pytest.raises(HTTPException) as exc_info: + get_access_token() + + assert exc_info.value.status_code == 401 + assert "Unauthorized: No access token provided" in str(exc_info.value.detail) + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + @patch("src.api.common.async_to_sync") + def test_get_access_token_remote_no_auth_header( + self, + mock_async_to_sync, + mock_get_settings, + mock_get_session_request, + mock_remote_settings, + ): + """Test remote mode with missing Authorization header.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + + # Mock request without Authorization header + mock_request = Mock(spec=Request) + mock_request.headers = {} + mock_get_session_request.return_value = mock_request + + # Mock auth provider returning None for empty token + mock_sync_function = Mock(return_value=None) + mock_async_to_sync.return_value = mock_sync_function + + # Should raise HTTPException for unauthorized + with pytest.raises(HTTPException) as exc_info: + get_access_token() + + assert exc_info.value.status_code == 401 + assert "Unauthorized: No access token provided" in str(exc_info.value.detail) + + @patch("src.api.common.get_settings") + def test_get_access_token_local_api_key( + self, mock_get_settings, mock_local_settings + ): + """Test local mode with API key.""" + mock_local_settings.api_key = "api-key-123" + mock_local_settings.jwt_token = None + mock_get_settings.return_value = mock_local_settings + + result = get_access_token() + + assert result == "api-key-123" + + @patch("src.api.common.get_settings") + def test_get_access_token_local_jwt_token( + self, mock_get_settings, mock_local_settings + ): + """Test local mode with JWT token.""" + mock_local_settings.api_key = None + mock_local_settings.jwt_token = "jwt-token-123" + mock_get_settings.return_value = mock_local_settings + + result = get_access_token() + + assert result == "jwt-token-123" + + @patch("src.api.common.get_settings") + def test_get_access_token_local_no_tokens( + self, mock_get_settings, mock_local_settings + ): + """Test local mode without any tokens.""" + mock_local_settings.api_key = None + mock_local_settings.jwt_token = None + mock_get_settings.return_value = mock_local_settings + + with pytest.raises(HTTPException) as exc_info: + get_access_token() + + assert exc_info.value.status_code == 401 + + +class TestBuildRequestRemoteMode: + """Test API request building with remote mode authentication.""" + + @pytest.fixture + def mock_remote_settings(self): + """Create mock RemoteSettings.""" + settings = Mock(spec=RemoteSettings) + settings.s2_api_base_url = "https://api.singlestore.com" + settings.is_remote = True + return settings + + @pytest.fixture + def mock_successful_response(self): + """Create mock successful API response.""" + response = Mock() + response.status_code = 200 + response.json.return_value = {"success": True, "data": "test-data"} + return response + + @pytest.fixture + def mock_error_response(self): + """Create mock error API response.""" + response = Mock() + response.status_code = 401 + response.text = "Unauthorized" + return response + + @patch("src.api.common.requests.get") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_build_request_get_success( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_get, + mock_remote_settings, + mock_successful_response, + ): + """Test successful GET request with authentication.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_org_id.return_value = "org-123" + mock_get_access_token.return_value = "bearer-token-123" + mock_requests_get.return_value = mock_successful_response + + # Make request + result = build_request("GET", "test/endpoint", params={"param1": "value1"}) + + # Assertions + assert result == {"success": True, "data": "test-data"} + + # Verify request was made with proper authentication + mock_requests_get.assert_called_once() + args, kwargs = mock_requests_get.call_args + + # Check URL + expected_url = "https://api.singlestore.com/v1/test/endpoint?param1=value1&organizationID=org-123" + assert args[0] == expected_url + + # Check headers + assert "Authorization" in kwargs["headers"] + assert kwargs["headers"]["Authorization"] == "Bearer bearer-token-123" + assert kwargs["headers"]["Content-Type"] == "application/json" + + @patch("src.api.common.requests.post") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_build_request_post_with_data( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_post, + mock_remote_settings, + mock_successful_response, + ): + """Test POST request with JSON data and authentication.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_org_id.return_value = "org-123" + mock_get_access_token.return_value = "bearer-token-123" + mock_requests_post.return_value = mock_successful_response + + test_data = {"key": "value", "number": 123} + + # Make request + result = build_request("POST", "test/endpoint", data=test_data) + + # Assertions + assert result == {"success": True, "data": "test-data"} + + # Verify request + mock_requests_post.assert_called_once() + args, kwargs = mock_requests_post.call_args + + # Check URL + expected_url = ( + "https://api.singlestore.com/v1/test/endpoint?organizationID=org-123" + ) + assert args[0] == expected_url + + # Check headers and data + assert kwargs["headers"]["Authorization"] == "Bearer bearer-token-123" + import json + + assert json.loads(kwargs["data"]) == test_data + + @patch("src.api.common.requests.get") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_build_request_api_error( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_get, + mock_remote_settings, + mock_error_response, + ): + """Test API request that returns an error status.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_org_id.return_value = "org-123" + mock_get_access_token.return_value = "invalid-token" + mock_requests_get.return_value = mock_error_response + + # Should raise HTTPException + with pytest.raises(HTTPException) as exc_info: + build_request("GET", "test/endpoint") + + assert exc_info.value.status_code == 401 + assert "Unauthorized" in str(exc_info.value.detail) + + @patch("src.api.common.get_settings") + @patch("src.api.common.get_access_token") + def test_build_request_no_access_token( + self, mock_get_access_token, mock_get_settings, mock_remote_settings + ): + """Test build_request when access token retrieval fails.""" + mock_get_settings.return_value = mock_remote_settings + mock_get_access_token.side_effect = HTTPException(401, "No token") + + with pytest.raises(HTTPException) as exc_info: + build_request("GET", "test/endpoint") + + assert exc_info.value.status_code == 401 + + @patch("src.api.common.requests.put") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_build_request_put_method( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_put, + mock_remote_settings, + mock_successful_response, + ): + """Test PUT request method.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_org_id.return_value = "org-123" + mock_get_access_token.return_value = "bearer-token-123" + mock_requests_put.return_value = mock_successful_response + + # Make request + build_request("PUT", "test/endpoint", data={"update": "data"}) + + # Verify PUT was called + mock_requests_put.assert_called_once() + + @patch("src.api.common.requests.delete") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_build_request_delete_method( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_delete, + mock_remote_settings, + mock_successful_response, + ): + """Test DELETE request method.""" + # Setup mocks + mock_get_settings.return_value = mock_remote_settings + mock_get_org_id.return_value = "org-123" + mock_get_access_token.return_value = "bearer-token-123" + mock_requests_delete.return_value = mock_successful_response + + # Make request + build_request("DELETE", "test/endpoint") + + # Verify DELETE was called + mock_requests_delete.assert_called_once() + + @patch("src.api.common.get_org_id") # Mock to avoid session issues + @patch("src.api.common.get_access_token") # Mock to avoid auth issues + def test_build_request_unsupported_method( + self, mock_get_access_token, mock_get_org_id, mock_remote_settings + ): + """Test build_request with unsupported HTTP method.""" + # Mock access token to avoid auth issues and reach method validation + mock_get_access_token.return_value = "test-token" + mock_get_org_id.return_value = "test-org-123" + + with patch("src.api.common.get_settings") as mock_get_settings: + mock_get_settings.return_value = mock_remote_settings + + with pytest.raises(ValueError, match="Unsupported request type: INVALID"): + build_request("INVALID", "test/endpoint") + + +class TestAuthenticationIntegrationScenarios: + """Test various authentication integration scenarios.""" + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + @patch("src.api.common.async_to_sync") + def test_token_refresh_scenario( + self, mock_async_to_sync, mock_get_settings, mock_get_session_request + ): + """Test scenario where token needs to be refreshed.""" + # Setup settings + settings = Mock(spec=RemoteSettings) + settings.is_remote = True + settings.auth_provider = AsyncMock() + mock_get_settings.return_value = settings + + # Setup request + request = Mock(spec=Request) + request.headers = {"Authorization": "Bearer expired-token-123"} + mock_get_session_request.return_value = request + + # Mock auth provider returning None first (token expired) + mock_sync_function = Mock(return_value=None) + mock_async_to_sync.return_value = mock_sync_function + + # Should raise unauthorized exception, triggering client to refresh + with pytest.raises(HTTPException) as exc_info: + get_access_token() + + assert exc_info.value.status_code == 401 + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + @patch("src.api.common.async_to_sync") + def test_malformed_authorization_header( + self, mock_async_to_sync, mock_get_settings, mock_get_session_request + ): + """Test scenario with malformed Authorization header.""" + # Setup settings + settings = Mock(spec=RemoteSettings) + settings.is_remote = True + settings.auth_provider = AsyncMock() + mock_get_settings.return_value = settings + + # Setup request with malformed header + request = Mock(spec=Request) + request.headers = {"Authorization": "NotBearer token-123"} # Missing "Bearer " + mock_get_session_request.return_value = request + + # Mock auth provider - should be called with the malformed header value + mock_access_token = AccessToken( + token="real-token", + client_id="test-client", + scopes=["openid"], + expires_at=9999999999, + ) + mock_sync_function = Mock(return_value=mock_access_token) + mock_async_to_sync.return_value = mock_sync_function + + result = get_access_token() + + # Should handle malformed header gracefully - "NotBearer token-123" becomes "Nottoken-123" after replace + mock_sync_function.assert_called_once_with("Nottoken-123") + assert result == "real-token" + + @patch("src.api.common.requests.get") + @patch("src.api.common.get_access_token") + @patch("src.api.common.get_org_id") + @patch("src.api.common.get_settings") + def test_concurrent_request_handling( + self, + mock_get_settings, + mock_get_org_id, + mock_get_access_token, + mock_requests_get, + ): + """Test handling of concurrent API requests with authentication.""" + import threading + import time + + # Setup mocks + settings = Mock(spec=RemoteSettings) + settings.s2_api_base_url = "https://api.singlestore.com" + settings.is_remote = True + mock_get_settings.return_value = settings + mock_get_org_id.return_value = "org-123" + + # Mock token retrieval with slight delay + def mock_token_retrieval(): + time.sleep(0.1) + return "concurrent-token" + + mock_get_access_token.side_effect = mock_token_retrieval + + # Mock successful response + response = Mock() + response.status_code = 200 + response.json.return_value = {"concurrent": True} + mock_requests_get.return_value = response + + results = [] + exceptions = [] + + def make_request(endpoint): + try: + result = build_request("GET", endpoint) + results.append(result) + except Exception as e: + exceptions.append(e) + + # Start multiple concurrent requests + threads = [] + for i in range(5): + thread = threading.Thread(target=make_request, args=(f"test/endpoint{i}",)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=5.0) + + # All requests should succeed + assert len(results) == 5 + assert len(exceptions) == 0 + assert all(result["concurrent"] for result in results) + + @patch("src.api.common.get_session_request") + @patch("src.api.common.get_settings") + def test_auth_provider_exception_handling( + self, mock_get_settings, mock_get_session_request + ): + """Test handling of exceptions from auth provider.""" + # Setup settings + settings = Mock(spec=RemoteSettings) + settings.is_remote = True + settings.auth_provider = AsyncMock() + mock_get_settings.return_value = settings + + # Setup request + request = Mock(spec=Request) + request.headers = {"Authorization": "Bearer test-token"} + mock_get_session_request.return_value = request + + # Mock auth provider raising exception + with patch("src.api.common.async_to_sync") as mock_async_to_sync: + mock_sync_function = Mock(side_effect=Exception("Auth provider error")) + mock_async_to_sync.return_value = mock_sync_function + + # Should handle exception and raise 401 + with pytest.raises(Exception) as exc_info: + get_access_token() + + assert "Auth provider error" in str(exc_info.value) diff --git a/tests/unit/auth/test_remote_auth_flow.py b/tests/unit/auth/test_remote_auth_flow.py new file mode 100644 index 0000000..f569418 --- /dev/null +++ b/tests/unit/auth/test_remote_auth_flow.py @@ -0,0 +1,690 @@ +""" +Test suite for remote mode OAuth authentication flow. + +This module tests the complete remote authentication flow including: +- OAuth provider initialization +- Authorization code generation and exchange +- Token storage and retrieval +- Token validation and expiration +- Error handling scenarios +""" + +import json +import time +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +from mcp.server.auth.provider import AuthorizationCode, AuthorizationParams, AccessToken +from pydantic import AnyHttpUrl +from starlette.exceptions import HTTPException + +from src.auth.provider import SingleStoreOAuthProvider +from src.config.config import RemoteSettings + + +class TestSingleStoreOAuthProvider: + """Test cases for the SingleStoreOAuthProvider class.""" + + @pytest.fixture + def mock_settings(self): + """Create mock RemoteSettings for testing.""" + settings = Mock(spec=RemoteSettings) + settings.client_id = "test-client-id" + settings.org_id = "test-org-id" + settings.oauth_db_url = "mysql://test:test@localhost/test_oauth" + settings.callback_path = "http://localhost:8010/callback" + settings.required_scopes = ["openid", "profile"] + settings.singlestore_auth_url = "https://auth.singlestore.com/authorize" + settings.singlestore_token_url = "https://auth.singlestore.com/token" + settings.s2_api_base_url = "https://api.singlestore.com" + return settings + + @pytest.fixture + def mock_db_connection(self): + """Mock database connection and cursor.""" + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = Mock(return_value=mock_conn) + mock_conn.__exit__ = Mock(return_value=None) + return mock_conn, mock_cursor + + @pytest.fixture + def oauth_provider(self, mock_settings, mock_db_connection): + """Create OAuth provider instance with mocked dependencies.""" + mock_conn, mock_cursor = mock_db_connection + + with patch.object(SingleStoreOAuthProvider, "_ensure_tables"): + provider = SingleStoreOAuthProvider(mock_settings) + provider._mock_conn = mock_conn + provider._mock_cursor = mock_cursor + + # Mock the _get_conn method to return our mock connection + provider._get_conn = lambda: mock_conn + + return provider + + @pytest.fixture + def sample_client(self): + """Create a sample OAuth client for testing.""" + return OAuthClientInformationFull( + client_id="test-client-123", + client_name="Test Client", + redirect_uris=["http://localhost:3000/callback"], + ) + + def test_provider_initialization(self, mock_settings): + """Test that the OAuth provider initializes correctly.""" + with patch.object(SingleStoreOAuthProvider, "_ensure_tables") as mock_ensure: + provider = SingleStoreOAuthProvider(mock_settings) + + assert provider.settings == mock_settings + assert isinstance(provider.state_mapping, dict) + mock_ensure.assert_called_once() + + def test_ensure_tables_creates_schema(self, oauth_provider): + """Test that database tables are created properly.""" + with patch("builtins.open") as mock_open: + mock_open.return_value.__enter__.return_value.read.return_value = """ + CREATE TABLE oauth_clients (client_id VARCHAR(255) PRIMARY KEY); + CREATE TABLE oauth_tokens (token VARCHAR(255) PRIMARY KEY); + """ + with patch.object(oauth_provider, "_get_conn") as mock_get_conn: + mock_get_conn.return_value = oauth_provider._mock_conn + + oauth_provider._ensure_tables() + + # Should execute SQL statements + oauth_provider._mock_cursor.execute.assert_called() + + @pytest.mark.asyncio + async def test_get_client_existing(self, oauth_provider, sample_client): + """Test retrieving an existing client from database.""" + # Mock database response + oauth_provider._mock_cursor.fetchone.return_value = [ + sample_client.model_dump_json() + ] + + result = await oauth_provider.get_client("test-client-123") + + assert result is not None + assert result.client_id == "test-client-123" + oauth_provider._mock_cursor.execute.assert_called_with( + "SELECT client_info FROM oauth_clients WHERE client_id=%s", + ("test-client-123",), + ) + + @pytest.mark.asyncio + async def test_get_client_not_found(self, oauth_provider): + """Test retrieving a non-existent client.""" + oauth_provider._mock_cursor.fetchone.return_value = None + + result = await oauth_provider.get_client("nonexistent-client") + + assert result is None + + @pytest.mark.asyncio + async def test_register_client(self, oauth_provider, sample_client): + """Test registering a new client.""" + await oauth_provider.register_client(sample_client) + + oauth_provider._mock_cursor.execute.assert_called_with( + "REPLACE INTO oauth_clients (client_id, client_info) VALUES (%s, %s)", + (sample_client.client_id, sample_client.model_dump_json()), + ) + oauth_provider._mock_conn.commit.assert_called_once() + + def test_generate_code_verifier(self, oauth_provider): + """Test PKCE code verifier generation.""" + code_verifier = oauth_provider._generate_code_verifier() + + assert isinstance(code_verifier, str) + assert 43 <= len(code_verifier) <= 128 + assert hasattr(oauth_provider, "singlestore_code_verifier") + assert oauth_provider.singlestore_code_verifier == code_verifier + + def test_generate_code_challenge(self, oauth_provider): + """Test PKCE code challenge generation.""" + code_verifier = "test_code_verifier" + code_challenge = oauth_provider._generate_code_challenge(code_verifier) + + assert isinstance(code_challenge, str) + assert len(code_challenge) > 0 + # Base64 URL-safe encoded SHA256 hash should not have padding + assert "=" not in code_challenge + + @pytest.mark.asyncio + async def test_authorize_creates_auth_url(self, oauth_provider, sample_client): + """Test that authorization creates proper SingleStore OAuth URL.""" + params = AuthorizationParams( + state="test-state", + scopes=["openid", "profile"], + code_challenge="test-challenge", + redirect_uri=AnyHttpUrl("http://localhost:3000/callback"), + redirect_uri_provided_explicitly=True, + ) + + auth_url = await oauth_provider.authorize(sample_client, params) + + assert auth_url.startswith(oauth_provider.settings.singlestore_auth_url) + assert "client_id=" + oauth_provider.settings.client_id in auth_url + assert "redirect_uri=" in auth_url + assert "code_challenge=" in auth_url + assert "state=" in auth_url + + # Check that state mapping was created + state_key = list(oauth_provider.state_mapping.keys())[0] + assert ( + oauth_provider.state_mapping[state_key]["client_id"] + == sample_client.client_id + ) + + @pytest.mark.asyncio + async def test_handle_singlestore_callback(self, oauth_provider, sample_client): + """Test handling callback from SingleStore OAuth.""" + # Set up state mapping + test_state = "test-state-123" + oauth_provider.state_mapping[test_state] = { + "code": "test-code", + "state": test_state, + "redirect_uri": "http://localhost:3000/callback", + "code_challenge": "test-challenge", + "redirect_uri_provided_explicitly": "True", + "client_id": sample_client.client_id, + } + + redirect_url = await oauth_provider.handle_singlestore_callback( + "auth-code-123", test_state + ) + + # Should store authorization code in database + oauth_provider._mock_cursor.execute.assert_called() + oauth_provider._mock_conn.commit.assert_called() + + # Should return redirect URL with code and state + assert "code=auth-code-123" in redirect_url + assert "state=" + test_state in redirect_url + + # Should clean up state mapping + assert test_state not in oauth_provider.state_mapping + + @pytest.mark.asyncio + async def test_handle_callback_invalid_state(self, oauth_provider): + """Test handling callback with invalid state.""" + with pytest.raises(HTTPException) as exc_info: + await oauth_provider.handle_singlestore_callback( + "auth-code", "invalid-state" + ) + + assert exc_info.value.status_code == 400 + assert "Invalid state parameter" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_load_authorization_code(self, oauth_provider, sample_client): + """Test loading authorization code from database.""" + # Mock database response + oauth_provider._mock_cursor.fetchone.return_value = [ + sample_client.client_id, # client_id + "http://localhost:3000/callback", # redirect_uri + True, # redirect_uri_provided_explicitly + time.time() + 300, # expires_at + json.dumps(["openid", "profile"]), # scopes + "test-challenge", # code_challenge + ] + + auth_code = await oauth_provider.load_authorization_code( + sample_client, "test-code" + ) + + assert auth_code is not None + assert auth_code.code == "test-code" + assert auth_code.client_id == sample_client.client_id + assert auth_code.scopes == ["openid", "profile"] + + oauth_provider._mock_cursor.execute.assert_called_with( + "SELECT client_id, redirect_uri, redirect_uri_provided_explicitly, expires_at, scopes, code_challenge FROM oauth_auth_codes WHERE code=%s", + ("test-code",), + ) + + @pytest.mark.asyncio + async def test_load_authorization_code_not_found( + self, oauth_provider, sample_client + ): + """Test loading non-existent authorization code.""" + oauth_provider._mock_cursor.fetchone.return_value = None + + auth_code = await oauth_provider.load_authorization_code( + sample_client, "nonexistent-code" + ) + + assert auth_code is None + + @patch("src.auth.provider.create_mcp_http_client") + @pytest.mark.asyncio + async def test_exchange_authorization_code_success( + self, mock_http_client, oauth_provider, sample_client + ): + """Test successful authorization code exchange.""" + # Mock database check for code existence + oauth_provider._mock_cursor.fetchone.return_value = ["test-code"] + + # Mock HTTP client response + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "access-token-123", + "token_type": "Bearer", + "expires_in": 3600, + } + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_http_client.return_value.__aenter__.return_value = mock_client + + # Create authorization code + auth_code = AuthorizationCode( + code="test-code", + client_id=sample_client.client_id, + redirect_uri=AnyHttpUrl("http://localhost:3000/callback"), + redirect_uri_provided_explicitly=True, + expires_at=time.time() + 300, + scopes=["openid", "profile"], + code_challenge="test-challenge", + ) + + oauth_provider.singlestore_code_verifier = "test-verifier" + + token = await oauth_provider.exchange_authorization_code( + sample_client, auth_code + ) + + assert isinstance(token, OAuthToken) + assert token.access_token == "access-token-123" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Should store token in database + assert oauth_provider._mock_cursor.execute.call_count >= 2 + oauth_provider._mock_conn.commit.assert_called() + + @patch("src.auth.provider.create_mcp_http_client") + @pytest.mark.asyncio + async def test_exchange_authorization_code_invalid_code( + self, mock_http_client, oauth_provider, sample_client + ): + """Test exchange with invalid authorization code.""" + # Mock database check - code not found + oauth_provider._mock_cursor.fetchone.return_value = None + + auth_code = AuthorizationCode( + code="invalid-code", + client_id=sample_client.client_id, + redirect_uri=AnyHttpUrl("http://localhost:3000/callback"), + redirect_uri_provided_explicitly=True, + expires_at=time.time() + 300, + scopes=["openid", "profile"], + code_challenge="test-challenge", + ) + + with pytest.raises(ValueError, match="Invalid authorization code"): + await oauth_provider.exchange_authorization_code(sample_client, auth_code) + + @patch("src.auth.provider.create_mcp_http_client") + @pytest.mark.asyncio + async def test_exchange_authorization_code_singlestore_error( + self, mock_http_client, oauth_provider, sample_client + ): + """Test exchange when SingleStore returns an error.""" + # Mock database check for code existence + oauth_provider._mock_cursor.fetchone.return_value = ["test-code"] + + # Mock HTTP client error response + mock_response = AsyncMock() + mock_response.status_code = 400 + mock_response.text = "Invalid request" + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_http_client.return_value.__aenter__.return_value = mock_client + + auth_code = AuthorizationCode( + code="test-code", + client_id=sample_client.client_id, + redirect_uri=AnyHttpUrl("http://localhost:3000/callback"), + redirect_uri_provided_explicitly=True, + expires_at=time.time() + 300, + scopes=["openid", "profile"], + code_challenge="test-challenge", + ) + + oauth_provider.singlestore_code_verifier = "test-verifier" + + with pytest.raises(HTTPException) as exc_info: + await oauth_provider.exchange_authorization_code(sample_client, auth_code) + + assert exc_info.value.status_code == 400 + + @pytest.mark.asyncio + async def test_load_access_token_valid(self, oauth_provider): + """Test loading a valid access token.""" + future_time = int(time.time() + 3600) # Convert to int + oauth_provider._mock_cursor.fetchone.return_value = [ + "test-client-id", # client_id + json.dumps(["openid", "profile"]), # scopes + future_time, # expires_at + ] + + with patch.object(oauth_provider, "get_client") as mock_get_client: + mock_client = Mock() + mock_client.client_name = "Test Client" + mock_get_client.return_value = mock_client + + with patch.object(oauth_provider, "get_user_id", return_value="user-123"): + with patch("src.auth.provider.get_settings") as mock_get_settings: + mock_settings = Mock() + mock_settings.analytics_manager.identify = Mock() + mock_get_settings.return_value = mock_settings + + access_token = await oauth_provider.load_access_token("valid-token") + + assert access_token is not None + assert isinstance(access_token, AccessToken) + assert access_token.token == "valid-token" + assert access_token.client_id == "test-client-id" + assert access_token.scopes == ["openid", "profile"] + assert access_token.expires_at == future_time + + @pytest.mark.asyncio + async def test_load_access_token_not_found(self, oauth_provider): + """Test loading non-existent access token.""" + oauth_provider._mock_cursor.fetchone.return_value = None + + access_token = await oauth_provider.load_access_token("nonexistent-token") + + assert access_token is None + + @pytest.mark.asyncio + async def test_load_access_token_expired(self, oauth_provider): + """Test loading expired access token.""" + past_time = time.time() - 3600 + oauth_provider._mock_cursor.fetchone.return_value = [ + "test-client-id", + json.dumps(["openid", "profile"]), + past_time, + ] + + access_token = await oauth_provider.load_access_token("expired-token") + + assert access_token is None + # Should delete expired token + oauth_provider._mock_cursor.execute.assert_called_with( + "DELETE FROM oauth_tokens WHERE token=%s", ("expired-token",) + ) + + @pytest.mark.asyncio + async def test_revoke_token(self, oauth_provider): + """Test token revocation.""" + await oauth_provider.revoke_token("token-to-revoke") + + oauth_provider._mock_cursor.execute.assert_called_with( + "DELETE FROM oauth_tokens WHERE token=%s", ("token-to-revoke",) + ) + oauth_provider._mock_conn.commit.assert_called_once() + + @patch("src.auth.provider.requests.get") + def test_get_user_id_success(self, mock_get, oauth_provider): + """Test successful user ID extraction.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"userID": "user-123", "email": "test@example.com"} + ] + mock_get.return_value = mock_response + + with patch("src.auth.provider.get_settings") as mock_get_settings: + mock_settings = Mock() + mock_settings.s2_api_base_url = "https://api.singlestore.com" + mock_settings.org_id = "org-123" + mock_get_settings.return_value = mock_settings + + user_id = oauth_provider.get_user_id("valid-token") + + assert user_id == "user-123" + mock_get.assert_called_once() + + @patch("src.auth.provider.requests.get") + def test_get_user_id_api_error(self, mock_get, oauth_provider): + """Test user ID extraction with API error.""" + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_get.return_value = mock_response + + with patch("src.auth.provider.get_settings") as mock_get_settings: + mock_settings = Mock() + mock_settings.s2_api_base_url = "https://api.singlestore.com" + mock_settings.org_id = "org-123" + mock_get_settings.return_value = mock_settings + + with pytest.raises(HTTPException) as exc_info: + oauth_provider.get_user_id("invalid-token") + + assert exc_info.value.status_code == 401 + + +class TestRemoteAuthFlowIntegration: + """Integration tests for the complete remote auth flow.""" + + @pytest.fixture + def mock_settings(self): + """Create mock settings for integration tests.""" + settings = Mock(spec=RemoteSettings) + settings.client_id = "integration-client-id" + settings.org_id = "integration-org-id" + settings.oauth_db_url = "mysql://test:test@localhost/integration_test" + settings.callback_path = "http://localhost:8010/callback" + settings.required_scopes = ["openid", "profile", "email"] + settings.singlestore_auth_url = "https://auth.singlestore.com/authorize" + settings.singlestore_token_url = "https://auth.singlestore.com/token" + settings.s2_api_base_url = "https://api.singlestore.com" + return settings + + @pytest.fixture + def sample_client(self): + """Create sample client for integration tests.""" + return OAuthClientInformationFull( + client_id="integration-client-123", + client_name="Integration Test Client", + redirect_uris=["http://localhost:3000/callback"], + ) + + @patch("src.auth.provider.s2.connect") + @patch("src.auth.provider.create_mcp_http_client") + @patch("src.auth.provider.requests.get") + @pytest.mark.asyncio + async def test_complete_auth_flow_success( + self, + mock_requests_get, + mock_http_client, + mock_s2_connect, + mock_settings, + sample_client, + ): + """Test the complete authentication flow from start to finish.""" + # Setup mocks + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = Mock(return_value=mock_conn) + mock_conn.__exit__ = Mock(return_value=None) + mock_s2_connect.return_value = mock_conn + + # Mock SingleStore token exchange + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "singlestore-access-token", + "token_type": "Bearer", + "expires_in": 3600, + } + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_http_client.return_value.__aenter__.return_value = mock_client + + # Mock user API response + mock_user_response = Mock() + mock_user_response.status_code = 200 + mock_user_response.json.return_value = [{"userID": "test-user-123"}] + mock_requests_get.return_value = mock_user_response + + # Initialize provider + with patch.object(SingleStoreOAuthProvider, "_ensure_tables"): + provider = SingleStoreOAuthProvider(mock_settings) + + # Step 1: Register client + await provider.register_client(sample_client) + + # Step 2: Generate authorization URL + params = AuthorizationParams( + state="integration-test-state", + scopes=["openid", "profile", "email"], + code_challenge="test-challenge-integration", + redirect_uri=AnyHttpUrl("http://localhost:3000/callback"), + redirect_uri_provided_explicitly=True, + ) + + auth_url = await provider.authorize(sample_client, params) + assert "https://auth.singlestore.com/authorize" in auth_url + + # Step 3: Handle callback from SingleStore + test_state = list(provider.state_mapping.keys())[0] + redirect_url = await provider.handle_singlestore_callback( + "singlestore-auth-code", test_state + ) + assert "code=singlestore-auth-code" in redirect_url + + # Step 4: Exchange authorization code for token + # Mock code exists in DB - reset mock for this step + mock_cursor.fetchone.side_effect = None + mock_cursor.fetchone.return_value = [ + sample_client.client_id, + "http://localhost:3000/callback", + True, + time.time() + 300, + json.dumps(["openid", "profile", "email"]), + "test-challenge-integration", + ] + + auth_code = await provider.load_authorization_code( + sample_client, "singlestore-auth-code" + ) + assert auth_code is not None + + provider.singlestore_code_verifier = "integration-verifier" + token = await provider.exchange_authorization_code(sample_client, auth_code) + + assert token.access_token == "singlestore-access-token" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Step 5: Load and validate access token + # Mock token lookup + future_time = time.time() + 3600 + mock_cursor.fetchone.side_effect = [ + [ + sample_client.client_id, + json.dumps(["openid", "profile", "email"]), + future_time, + ], # Token lookup + [sample_client.model_dump_json()], # Client lookup + ] + + with patch("src.auth.provider.get_settings") as mock_get_settings: + mock_analytics = Mock() + mock_analytics.analytics_manager.identify = Mock() + mock_get_settings.return_value = mock_analytics + + with patch("src.auth.provider.set_user_id") as mock_set_user_id: + access_token = await provider.load_access_token( + "singlestore-access-token" + ) + + assert access_token is not None + assert access_token.token == "singlestore-access-token" + assert access_token.client_id == sample_client.client_id + + # Verify user ID was extracted and set + mock_set_user_id.assert_called_with("test-user-123") + + @patch("src.auth.provider.s2.connect") + @pytest.mark.asyncio + async def test_auth_flow_with_expired_tokens(self, mock_s2_connect, mock_settings): + """Test auth flow behavior with expired tokens.""" + # Setup mocks + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = Mock(return_value=mock_conn) + mock_conn.__exit__ = Mock(return_value=None) + mock_s2_connect.return_value = mock_conn + + with patch.object(SingleStoreOAuthProvider, "_ensure_tables"): + provider = SingleStoreOAuthProvider(mock_settings) + + # Mock expired token + past_time = time.time() - 3600 + mock_cursor.fetchone.return_value = [ + "test-client", + json.dumps(["openid"]), + past_time, + ] + + access_token = await provider.load_access_token("expired-token") + + # Should return None and delete expired token + assert access_token is None + mock_cursor.execute.assert_any_call( + "DELETE FROM oauth_tokens WHERE token=%s", ("expired-token",) + ) + + @patch("src.auth.provider.s2.connect") + @pytest.mark.asyncio + async def test_error_handling_scenarios(self, mock_s2_connect, mock_settings): + """Test various error scenarios in the auth flow.""" + # Setup mocks + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.__enter__ = Mock(return_value=mock_conn) + mock_conn.__exit__ = Mock(return_value=None) + mock_s2_connect.return_value = mock_conn + + with patch.object(SingleStoreOAuthProvider, "_ensure_tables"): + provider = SingleStoreOAuthProvider(mock_settings) + + # Test 1: Invalid state in callback + with pytest.raises(HTTPException, match="Invalid state parameter"): + await provider.handle_singlestore_callback("code", "invalid-state") + + # Test 2: Non-existent authorization code + mock_cursor.fetchone.return_value = None + with pytest.raises(ValueError, match="Invalid authorization code"): + auth_code = AuthorizationCode( + code="nonexistent", + client_id="test", + redirect_uri=AnyHttpUrl("http://localhost/callback"), + redirect_uri_provided_explicitly=True, + expires_at=time.time() + 300, + scopes=["openid"], + code_challenge="challenge", + ) + await provider.exchange_authorization_code( + OAuthClientInformationFull( + client_id="test", + client_name="test", + redirect_uris=[AnyHttpUrl("http://localhost/callback")], + ), + auth_code, + ) diff --git a/tests/unit/auth/test_runner.py b/tests/unit/auth/test_runner.py new file mode 100755 index 0000000..6b6c3d3 --- /dev/null +++ b/tests/unit/auth/test_runner.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +""" +Test runner for remote mode authentication tests. + +This script runs the complete test suite for remote authentication, +including unit tests and integration tests. +""" + +import subprocess +import sys +from pathlib import Path + + +def run_tests(): + """Run all remote authentication tests.""" + # Get the project root directory + project_root = Path(__file__).parent.parent.parent.parent + + # Test directories + auth_test_dir = project_root / "tests" / "unit" / "auth" + + # Test files to run + test_files = [ + "test_remote_auth_flow.py", + "test_oauth_proxy_integration.py", + "test_remote_api_auth.py", + ] + + print("๐Ÿงช Running Remote Mode Authentication Tests") + print("=" * 50) + + # Check if test files exist + missing_files = [] + for test_file in test_files: + if not (auth_test_dir / test_file).exists(): + missing_files.append(test_file) + + if missing_files: + print(f"โŒ Missing test files: {', '.join(missing_files)}") + return 1 + + # Run pytest for each test file + failed_tests = [] + + for test_file in test_files: + test_path = auth_test_dir / test_file + print(f"\n๐Ÿ”„ Running {test_file}...") + + try: + # Run pytest with verbose output + result = subprocess.run( + [ + sys.executable, + "-m", + "pytest", + str(test_path), + "-v", + "--tb=short", + "--color=yes", + ], + cwd=project_root, + capture_output=True, + text=True, + ) + + if result.returncode == 0: + print(f"โœ… {test_file} passed") + # Show summary of passed tests + lines = result.stdout.split("\n") + for line in lines: + if "passed" in line and ("failed" in line or "error" in line): + print(f" {line.strip()}") + else: + print(f"โŒ {test_file} failed") + failed_tests.append(test_file) + # Show error details + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + + except Exception as e: + print(f"๐Ÿ’ฅ Error running {test_file}: {e}") + failed_tests.append(test_file) + + # Summary + print("\n" + "=" * 50) + print("๐Ÿ“Š Test Summary") + print("=" * 50) + + total_tests = len(test_files) + passed_tests = total_tests - len(failed_tests) + + print(f"Total test files: {total_tests}") + print(f"Passed: {passed_tests}") + print(f"Failed: {len(failed_tests)}") + + if failed_tests: + print(f"\nโŒ Failed tests: {', '.join(failed_tests)}") + return 1 + else: + print("\n๐ŸŽ‰ All tests passed!") + return 0 + + +def run_specific_test(test_name): + """Run a specific test or test method.""" + project_root = Path(__file__).parent.parent.parent.parent + auth_test_dir = project_root / "tests" / "unit" / "auth" + + print(f"๐Ÿงช Running specific test: {test_name}") + print("=" * 50) + + try: + result = subprocess.run( + [ + sys.executable, + "-m", + "pytest", + str(auth_test_dir), + "-k", + test_name, + "-v", + "--tb=long", + "--color=yes", + ], + cwd=project_root, + ) + + return result.returncode + + except Exception as e: + print(f"๐Ÿ’ฅ Error running test {test_name}: {e}") + return 1 + + +def run_coverage(): + """Run tests with coverage reporting.""" + project_root = Path(__file__).parent.parent.parent.parent + auth_test_dir = project_root / "tests" / "unit" / "auth" + + print("๐Ÿงช Running tests with coverage...") + print("=" * 50) + + try: + # Run pytest with coverage + result = subprocess.run( + [ + sys.executable, + "-m", + "pytest", + str(auth_test_dir), + "--cov=src.auth", + "--cov=src.api.common", + "--cov=src.config.config", + "--cov-report=html", + "--cov-report=term-missing", + "-v", + ], + cwd=project_root, + ) + + if result.returncode == 0: + print("\n๐Ÿ“ˆ Coverage report generated in htmlcov/") + + return result.returncode + + except Exception as e: + print(f"๐Ÿ’ฅ Error running coverage: {e}") + return 1 + + +def main(): + """Main entry point for test runner.""" + if len(sys.argv) > 1: + command = sys.argv[1] + + if command == "--coverage": + return run_coverage() + elif command == "--test": + if len(sys.argv) > 2: + return run_specific_test(sys.argv[2]) + else: + print("โŒ Please provide a test name: --test ") + return 1 + elif command == "--help": + print("Remote Auth Test Runner") + print("Usage:") + print(" python test_runner.py # Run all tests") + print(" python test_runner.py --coverage # Run with coverage") + print(" python test_runner.py --test # Run specific test") + print(" python test_runner.py --help # Show this help") + return 0 + else: + print(f"โŒ Unknown command: {command}") + print("Use --help for usage information") + return 1 + else: + return run_tests() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/auth/test_single_store_oauth_proxy.py b/tests/unit/auth/test_single_store_oauth_proxy.py new file mode 100644 index 0000000..109626c --- /dev/null +++ b/tests/unit/auth/test_single_store_oauth_proxy.py @@ -0,0 +1,965 @@ +"""Unit tests for SingleStoreOAuthProxy class.""" + +import pytest +import json +from unittest.mock import Mock, patch +from fastmcp.server.auth.oauth_proxy import OAuthProxy +from mcp.server.auth.provider import AccessToken +import jwt + +from src.auth.proxy_provider import SingleStoreOAuthProxy + + +class TestSingleStoreOAuthProxy: + """Test cases for SingleStoreOAuthProxy class.""" + + def setup_method(self): + """Set up test fixtures before each test method.""" + self.test_issuer_url = "https://authsvc.singlestore.com/" + self.test_client_id = "b7dbf19e-d140-4334-bae4-e8cd03614485" + self.test_client_secret = "test-secret" + self.test_base_url = "http://localhost:8010/" + self.test_redirect_path = "/callback" + self.test_valid_scopes = ["openid", "profile"] + self.test_jwt_signing_key = "test-jwt-key" + + # Mock OpenID configuration response + self.mock_openid_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["ES512"], + } + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_init_success(self, mock_oauth_proxy, mock_jwks_client, mock_requests_get): + """Test successful initialization of SingleStoreOAuthProxy.""" + # Arrange + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = self.mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + # Act + proxy = SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + client_secret=self.test_client_secret, + base_url=self.test_base_url, + redirect_path=self.test_redirect_path, + valid_scopes=self.test_valid_scopes, + jwt_signing_key=self.test_jwt_signing_key, + ) + + # Assert + assert proxy.issuer_url == self.test_issuer_url + assert proxy.client_id == self.test_client_id + assert proxy.client_secret == self.test_client_secret + assert proxy.base_url == self.test_base_url + assert proxy.redirect_path == self.test_redirect_path + assert proxy.valid_scopes == self.test_valid_scopes + assert proxy.jwt_signing_key == self.test_jwt_signing_key + + # Verify OpenID config was fetched + expected_config_url = ( + "https://authsvc.singlestore.com/.well-known/openid-configuration" + ) + mock_requests_get.assert_called_once_with(expected_config_url, timeout=10.0) + + # Verify JWKS client was created + mock_jwks_client.assert_called_once_with(self.mock_openid_config["jwks_uri"]) + + # Verify OAuth proxy was created with correct parameters + mock_oauth_proxy.assert_called_once() + call_kwargs = mock_oauth_proxy.call_args[1] + assert ( + call_kwargs["upstream_authorization_endpoint"] + == self.mock_openid_config["authorization_endpoint"] + ) + assert ( + call_kwargs["upstream_token_endpoint"] + == self.mock_openid_config["token_endpoint"] + ) + assert call_kwargs["upstream_client_id"] == self.test_client_id + assert call_kwargs["upstream_client_secret"] == self.test_client_secret + assert call_kwargs["base_url"] == self.test_base_url + assert call_kwargs["redirect_path"] == self.test_redirect_path + assert call_kwargs["valid_scopes"] == self.test_valid_scopes + assert call_kwargs["jwt_signing_key"] == self.test_jwt_signing_key + + # Verify provider property works + assert proxy.provider == mock_oauth_proxy_instance + + def test_init_with_default_values(self): + """Test initialization with default values.""" + with ( + patch("requests.get") as mock_requests_get, + patch("src.auth.proxy_provider.PyJWKClient"), + patch("src.auth.proxy_provider.OAuthProxy"), + ): + # Arrange + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = self.mock_openid_config + mock_requests_get.return_value = mock_response + + # Act + proxy = SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + # Assert default values + assert proxy.client_secret == "-" + assert proxy.base_url == "http://localhost:8010/" + assert proxy.redirect_path == "/callback" + assert proxy.valid_scopes == ["openid"] + + @patch("requests.get") + def test_fetch_openid_config_network_error(self, mock_requests_get): + """Test OpenID configuration fetch with network error.""" + # Arrange + mock_requests_get.side_effect = ConnectionError("Network error") + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Failed to fetch OpenID configuration" in str(exc_info.value) + assert "Network error" in str(exc_info.value) + + @patch("requests.get") + def test_fetch_openid_config_http_error(self, mock_requests_get): + """Test OpenID configuration fetch with HTTP error.""" + # Arrange + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception("HTTP 404 Not Found") + mock_requests_get.return_value = mock_response + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Failed to fetch OpenID configuration" in str(exc_info.value) + assert "HTTP 404 Not Found" in str(exc_info.value) + + @patch("requests.get") + def test_fetch_openid_config_json_error(self, mock_requests_get): + """Test OpenID configuration fetch with JSON parsing error.""" + # Arrange + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + mock_requests_get.return_value = mock_response + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Failed to fetch OpenID configuration" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + def test_create_verifier_missing_jwks_uri( + self, mock_jwks_client, mock_requests_get + ): + """Test verifier creation with missing jwks_uri in config.""" + # Arrange + incomplete_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + # Missing jwks_uri + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = incomplete_config + mock_requests_get.return_value = mock_response + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Missing required fields in OpenID configuration" in str(exc_info.value) + assert "jwks_uri=None" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + def test_create_verifier_missing_issuer(self, mock_jwks_client, mock_requests_get): + """Test verifier creation with missing issuer in config.""" + # Arrange + incomplete_config = { + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + # Missing issuer + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = incomplete_config + mock_requests_get.return_value = mock_response + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Missing required fields in OpenID configuration" in str(exc_info.value) + assert "issuer=None" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_create_oauth_proxy_missing_authorization_endpoint( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test OAuth proxy creation with missing authorization_endpoint.""" + # Arrange + incomplete_config = { + "issuer": "https://authsvc.singlestore.com", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + # Missing authorization_endpoint + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = incomplete_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Missing required fields in OpenID configuration" in str(exc_info.value) + assert "authorization_endpoint=None" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_create_oauth_proxy_missing_token_endpoint( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test OAuth proxy creation with missing token_endpoint.""" + # Arrange + incomplete_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + # Missing token_endpoint + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = incomplete_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + assert "Missing required fields in OpenID configuration" in str(exc_info.value) + assert "token_endpoint=None" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_create_oauth_proxy_missing_jwt_signing_key( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test OAuth proxy creation with missing JWT signing key.""" + # Arrange + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = self.mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + # Missing jwt_signing_key + ) + + assert "JWT signing key is not set" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_get_provider(self, mock_oauth_proxy, mock_jwks_client, mock_requests_get): + """Test get_provider method returns the OAuth proxy instance.""" + # Arrange + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = self.mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + proxy = SingleStoreOAuthProxy( + issuer_url=self.test_issuer_url, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + # Act + provider = proxy.get_provider() + + # Assert + assert provider == mock_oauth_proxy_instance + assert provider == proxy.provider + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_issuer_url_normalization( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test that issuer URL is properly normalized for OpenID config URL.""" + # Arrange + test_issuer_without_trailing_slash = "https://authsvc.singlestore.com" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = self.mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + # Act + proxy = SingleStoreOAuthProxy( + issuer_url=test_issuer_without_trailing_slash, + client_id=self.test_client_id, + jwt_signing_key=self.test_jwt_signing_key, + ) + + # Assert + expected_config_url = ( + "https://authsvc.singlestore.com/.well-known/openid-configuration" + ) + mock_requests_get.assert_called_once_with(expected_config_url, timeout=10.0) + assert proxy.openid_config_url == expected_config_url + + +class TestCustomJWTVerifier: + """Test cases for the CustomJWTVerifier inner class.""" + + def setup_method(self): + """Set up test fixtures before each test method.""" + self.test_jwks_uri = "https://authsvc.singlestore.com/.well-known/jwks.json" + self.test_issuer = "https://authsvc.singlestore.com" + self.test_audience = "b7dbf19e-d140-4334-bae4-e8cd03614485" + self.test_base_url = "http://localhost:8010/" + self.test_required_scopes = ["openid"] + + # Mock token data + self.mock_decoded_token = { + "iss": self.test_issuer, + "aud": [self.test_audience], + "client_id": self.test_audience, + "exp": 1734567890, + "iat": 1734564290, + "sub": "user123", + } + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_verify_token_success( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test successful token verification.""" + # Arrange + mock_openid_config = { + "issuer": self.test_issuer, + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": self.test_jwks_uri, + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_signing_key = Mock() + mock_jwks_client_instance = Mock() + mock_jwks_client_instance.get_signing_key_from_jwt.return_value = ( + mock_signing_key + ) + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + test_token = "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9..." + + with patch("jwt.decode") as mock_jwt_decode: + mock_jwt_decode.return_value = self.mock_decoded_token + + # Create proxy which will create the verifier + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id=self.test_audience, + jwt_signing_key="test-key", + ) + + # Act + import asyncio + + result = asyncio.run(proxy._verifier.verify_token(test_token)) + + # Assert + assert result is not None + assert isinstance(result, AccessToken) + assert result.token == test_token + assert result.client_id == self.test_audience + assert result.scopes == ["openid"] + assert result.expires_at == 1734567890 + assert result.resource == self.test_audience + + # Verify JWT validation was called with correct parameters + mock_jwt_decode.assert_called_once_with( + test_token, + mock_signing_key, + audience=self.test_audience, + options={"verify_exp": True}, + algorithms=["ES512"], + ) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_verify_token_audience_list( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test token verification with audience as a list.""" + # Arrange + mock_openid_config = { + "issuer": self.test_issuer, + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": self.test_jwks_uri, + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_signing_key = Mock() + mock_jwks_client_instance = Mock() + mock_jwks_client_instance.get_signing_key_from_jwt.return_value = ( + mock_signing_key + ) + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + test_token = "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9..." + + # Mock decoded token with audience as list + mock_decoded_with_list_aud = self.mock_decoded_token.copy() + mock_decoded_with_list_aud["aud"] = [self.test_audience, "other-audience"] + + with patch("jwt.decode") as mock_jwt_decode: + mock_jwt_decode.return_value = mock_decoded_with_list_aud + + # Create proxy which will create the verifier + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id=self.test_audience, + jwt_signing_key="test-key", + ) + + # Act + import asyncio + + result = asyncio.run(proxy._verifier.verify_token(test_token)) + + # Assert + assert result is not None + assert isinstance(result, AccessToken) + assert result.resource == self.test_audience # First item in list + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_verify_token_audience_string( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test token verification with audience as a string.""" + # Arrange + mock_openid_config = { + "issuer": self.test_issuer, + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": self.test_jwks_uri, + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_signing_key = Mock() + mock_jwks_client_instance = Mock() + mock_jwks_client_instance.get_signing_key_from_jwt.return_value = ( + mock_signing_key + ) + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + test_token = "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9..." + + # Mock decoded token with audience as string + mock_decoded_with_str_aud = self.mock_decoded_token.copy() + mock_decoded_with_str_aud["aud"] = self.test_audience + + with patch("jwt.decode") as mock_jwt_decode: + mock_jwt_decode.return_value = mock_decoded_with_str_aud + + # Create proxy which will create the verifier + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id=self.test_audience, + jwt_signing_key="test-key", + ) + + # Act + import asyncio + + result = asyncio.run(proxy._verifier.verify_token(test_token)) + + # Assert + assert result is not None + assert isinstance(result, AccessToken) + assert result.resource == self.test_audience # String value + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_verify_token_jwt_error( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test token verification with JWT validation error.""" + # Arrange + mock_openid_config = { + "issuer": self.test_issuer, + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": self.test_jwks_uri, + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_signing_key = Mock() + mock_jwks_client_instance = Mock() + mock_jwks_client_instance.get_signing_key_from_jwt.return_value = ( + mock_signing_key + ) + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + test_token = "invalid-token" + + with ( + patch("jwt.decode") as mock_jwt_decode, + patch("builtins.print") as mock_print, + ): + mock_jwt_decode.side_effect = jwt.PyJWTError("Invalid token") + + # Create proxy which will create the verifier + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id=self.test_audience, + jwt_signing_key="test-key", + ) + + # Act + import asyncio + + result = asyncio.run(proxy._verifier.verify_token(test_token)) + + # Assert + assert result is None + mock_print.assert_called_with( + "Token validation error:", mock_jwt_decode.side_effect + ) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_verify_token_jwks_client_error( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test token verification with JWKS client error.""" + # Arrange + mock_openid_config = { + "issuer": self.test_issuer, + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": self.test_jwks_uri, + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + test_token = "test-token" + + # Create proxy which will create the verifier + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id=self.test_audience, + jwt_signing_key="test-key", + ) + + # Now set up the error to occur during verification + mock_jwks_client_instance.get_signing_key_from_jwt.side_effect = Exception( + "JWKS client error" + ) + + # Act & Assert + import asyncio + + with pytest.raises(Exception) as exc_info: + asyncio.run(proxy._verifier.verify_token(test_token)) + + assert "JWKS client error" in str(exc_info.value) + + +class TestSingleStoreOAuthProxyIntegration: + """Integration test cases for SingleStoreOAuthProxy.""" + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_complete_initialization_workflow( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test the complete initialization workflow with all components.""" + # Arrange + mock_openid_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + "response_types_supported": ["code"], + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy_instance = Mock(spec=OAuthProxy) + mock_oauth_proxy.return_value = mock_oauth_proxy_instance + + # Act + proxy = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="test-client-id", + client_secret="test-secret", + base_url="https://myapp.example.com", + redirect_path="/oauth/callback", + valid_scopes=["openid", "profile", "email"], + jwt_signing_key="my-jwt-secret", + ) + + # Assert configuration was loaded + assert proxy._config == mock_openid_config + + # Assert verifier was created with correct parameters + assert proxy._verifier is not None + + # Assert OAuth proxy was created and is accessible + provider = proxy.get_provider() + assert provider == mock_oauth_proxy_instance + + # Verify the OAuth proxy was called with the correct configuration + call_kwargs = mock_oauth_proxy.call_args[1] + assert ( + call_kwargs["upstream_authorization_endpoint"] + == "https://authsvc.singlestore.com/authorize" + ) + assert ( + call_kwargs["upstream_token_endpoint"] + == "https://authsvc.singlestore.com/token" + ) + assert call_kwargs["upstream_client_id"] == "test-client-id" + assert call_kwargs["upstream_client_secret"] == "test-secret" + assert call_kwargs["base_url"] == "https://myapp.example.com" + assert call_kwargs["redirect_path"] == "/oauth/callback" + assert call_kwargs["valid_scopes"] == ["openid", "profile", "email"] + assert call_kwargs["jwt_signing_key"] == "my-jwt-secret" + + def test_multiple_proxy_instances(self): + """Test creating multiple proxy instances with different configurations.""" + with ( + patch("requests.get") as mock_requests_get, + patch("src.auth.proxy_provider.PyJWKClient"), + patch("src.auth.proxy_provider.OAuthProxy") as mock_oauth_proxy, + ): + # Arrange + mock_openid_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_oauth_proxy.side_effect = [ + Mock(spec=OAuthProxy), + Mock(spec=OAuthProxy), + ] + + # Act - Create two different proxy instances + proxy1 = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="client-1", + jwt_signing_key="key-1", + ) + + proxy2 = SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="client-2", + jwt_signing_key="key-2", + base_url="https://different.example.com", + valid_scopes=["custom", "scopes"], + ) + + # Assert + assert proxy1.client_id == "client-1" + assert proxy2.client_id == "client-2" + assert proxy1.base_url == "http://localhost:8010/" + assert proxy2.base_url == "https://different.example.com" + assert proxy1.valid_scopes == ["openid"] + assert proxy2.valid_scopes == ["custom", "scopes"] + assert proxy1.provider != proxy2.provider + + +class TestSingleStoreOAuthProxyErrorHandling: + """Test cases for error handling scenarios in SingleStoreOAuthProxy.""" + + def test_invalid_issuer_url_format(self): + """Test initialization with invalid issuer URL format.""" + with patch("requests.get") as mock_requests_get: + mock_requests_get.side_effect = Exception("Invalid URL") + + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url="not-a-valid-url", + client_id="test-client", + jwt_signing_key="test-key", + ) + + assert "Failed to fetch OpenID configuration" in str(exc_info.value) + + @patch("requests.get") + def test_timeout_during_config_fetch(self, mock_requests_get): + """Test timeout during OpenID configuration fetch.""" + # Arrange + import requests + + mock_requests_get.side_effect = requests.Timeout("Request timed out") + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="test-client", + jwt_signing_key="test-key", + ) + + assert "Failed to fetch OpenID configuration" in str(exc_info.value) + assert "Request timed out" in str(exc_info.value) + + @patch("requests.get") + def test_malformed_openid_config(self, mock_requests_get): + """Test handling of malformed OpenID configuration.""" + # Arrange + malformed_config = {"invalid": "config", "missing_required_fields": True} + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = malformed_config + mock_requests_get.return_value = mock_response + + # Act & Assert + with pytest.raises(RuntimeError) as exc_info: + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="test-client", + jwt_signing_key="test-key", + ) + + assert "Missing required fields in OpenID configuration" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + def test_jwks_client_initialization_failure( + self, mock_jwks_client, mock_requests_get + ): + """Test handling of JWKS client initialization failure.""" + # Arrange + mock_openid_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client.side_effect = Exception("JWKS client initialization failed") + + # Act & Assert + with pytest.raises(Exception) as exc_info: + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="test-client", + jwt_signing_key="test-key", + ) + + assert "JWKS client initialization failed" in str(exc_info.value) + + @patch("requests.get") + @patch("src.auth.proxy_provider.PyJWKClient") + @patch("src.auth.proxy_provider.OAuthProxy") + def test_oauth_proxy_initialization_failure( + self, mock_oauth_proxy, mock_jwks_client, mock_requests_get + ): + """Test handling of OAuth proxy initialization failure.""" + # Arrange + mock_openid_config = { + "issuer": "https://authsvc.singlestore.com", + "authorization_endpoint": "https://authsvc.singlestore.com/authorize", + "token_endpoint": "https://authsvc.singlestore.com/token", + "jwks_uri": "https://authsvc.singlestore.com/.well-known/jwks.json", + } + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = mock_openid_config + mock_requests_get.return_value = mock_response + + mock_jwks_client_instance = Mock() + mock_jwks_client.return_value = mock_jwks_client_instance + + mock_oauth_proxy.side_effect = Exception("OAuth proxy initialization failed") + + # Act & Assert + with pytest.raises(Exception) as exc_info: + SingleStoreOAuthProxy( + issuer_url="https://authsvc.singlestore.com/", + client_id="test-client", + jwt_signing_key="test-key", + ) + + assert "OAuth proxy initialization failed" in str(exc_info.value)