diff --git a/demo/auth.py b/demo/auth.py index 065d222b..f3633717 100644 --- a/demo/auth.py +++ b/demo/auth.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, Request from fastui import AnyComponent, FastUI from fastui import components as c -from fastui.auth import AuthRedirect, GitHubAuthProvider +from fastui.auth import AuthRedirect, GitHubAuthProvider, GoogleAuthProvider from fastui.events import AuthEvent, GoToEvent, PageEvent from fastui.forms import fastui_form from httpx import AsyncClient @@ -27,6 +27,11 @@ GITHUB_REDIRECT = os.getenv('GITHUB_REDIRECT') +GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID', 'yourkey.apps.googleusercontent.com') +GOOGLE_CLIENT_SECRET = SecretStr(os.getenv('GOOGLE_CLIENT_SECRET', 'yoursecret')) +GOOGLE_REDIRECT_URI = os.getenv('GOOGLE_REDIRECT_URI', 'http://localhost:8000/auth/login/google/redirect') + + async def get_github_auth(request: Request) -> GitHubAuthProvider: client: AsyncClient = request.app.state.httpx_client return GitHubAuthProvider( @@ -38,7 +43,7 @@ async def get_github_auth(request: Request) -> GitHubAuthProvider: ) -LoginKind: TypeAlias = Literal['password', 'github'] +LoginKind: TypeAlias = Literal['password', 'github', 'google'] @router.get('/login/{kind}', response_model=FastUI, response_model_exclude_none=True) @@ -63,6 +68,11 @@ def auth_login( on_click=PageEvent(name='tab', push_path='/auth/login/github', context={'kind': 'github'}), active='/auth/login/github', ), + c.Link( + components=[c.Text(text='Google Login')], + on_click=PageEvent(name='tab', push_path='/auth/login/google', context={'kind': 'google'}), + active='/auth/login/google', + ), ], mode='tabs', class_name='+ mb-4', @@ -98,6 +108,13 @@ def auth_login_content(kind: LoginKind) -> list[AnyComponent]: c.Paragraph(text='(Credentials are stored in the browser via a JWT only)'), c.Button(text='Login with GitHub', on_click=GoToEvent(url='/auth/login/github/gen')), ] + case 'google': + return [ + c.Heading(text='Google Login', level=3), + c.Paragraph(text='Demo of Google authentication.'), + c.Paragraph(text='(Credentials are stored in the browser via a JWT only)'), + c.Button(text='Login with Google', on_click=GoToEvent(url='/auth/login/google/gen')), + ] case _: raise ValueError(f'Invalid kind {kind!r}') @@ -167,3 +184,50 @@ async def github_redirect( ) token = user.encode_token() return [c.FireEvent(event=AuthEvent(token=token, url='/auth/profile'))] + + +async def get_google_auth(request: Request) -> GoogleAuthProvider: + client: AsyncClient = request.app.state.httpx_client + return GoogleAuthProvider( + httpx_client=client, + google_client_id=GOOGLE_CLIENT_ID, + google_client_secret=GOOGLE_CLIENT_SECRET, + redirect_uri=GOOGLE_REDIRECT_URI, + scopes=['https://www.googleapis.com/auth/userinfo.email', 'https://www.googleapis.com/auth/userinfo.profile'], + ) + + +@router.get('/login/google/gen', response_model=FastUI, response_model_exclude_none=True) +async def auth_google_gen(request: Request) -> list[AnyComponent]: + google_auth = await get_google_auth(request) + try: + # here we should use the refresh token to get a new access token but for the demo we don't store it + refresh_token = 'fake_refresh_token' + exchange = await google_auth.refresh_access_token(refresh_token) + google_user = await google_auth.get_google_user(exchange) + user = User( + email=google_user.email, + extra={'google_user_info': google_user.dict()}, + ) + token = user.encode_token() + return [c.FireEvent(event=AuthEvent(token=token, url='/auth/profile'))] + except Exception: + auth_url = await google_auth.authorization_url() + return [c.FireEvent(event=GoToEvent(url=auth_url))] + + +@router.get('/login/google/redirect', response_model=FastUI, response_model_exclude_none=True) +async def google_redirect( + request: Request, + code: str, +) -> list[AnyComponent]: + google_auth = await get_google_auth(request) + exchange = await google_auth.exchange_code(code) + google_user = await google_auth.get_google_user(exchange) + user = User( + email=google_user.email, + extra={'google_user_info': google_user.dict()}, + ) + # here should store the refresh token somewhere but for the demo we don't store it + token = user.encode_token() + return [c.FireEvent(event=AuthEvent(token=token, url='/auth/profile'))] diff --git a/src/python-fastui/fastui/auth/__init__.py b/src/python-fastui/fastui/auth/__init__.py index 89377824..0d5452a0 100644 --- a/src/python-fastui/fastui/auth/__init__.py +++ b/src/python-fastui/fastui/auth/__init__.py @@ -1,4 +1,5 @@ from .github import GitHubAuthProvider, GitHubEmail, GitHubExchange, GithubUser +from .google import GoogleAuthProvider, GoogleExchange, GoogleExchangeError, GoogleUser from .shared import AuthError, AuthRedirect, fastapi_auth_exception_handling __all__ = ( @@ -6,6 +7,10 @@ 'GitHubExchange', 'GithubUser', 'GitHubEmail', + 'GoogleAuthProvider', + 'GoogleExchange', + 'GoogleUser', + 'GoogleExchangeError', 'AuthError', 'AuthRedirect', 'fastapi_auth_exception_handling', diff --git a/src/python-fastui/fastui/auth/github.py b/src/python-fastui/fastui/auth/github.py index e32d41dd..5bbcf26a 100644 --- a/src/python-fastui/fastui/auth/github.py +++ b/src/python-fastui/fastui/auth/github.py @@ -1,12 +1,12 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Tuple, Union, cast +from typing import TYPE_CHECKING, AsyncIterator, Dict, List, Union, cast from urllib.parse import urlencode from pydantic import BaseModel, SecretStr, TypeAdapter, field_validator -from .shared import AuthError +from .shared import AuthError, ExchangeCache, ExchangeData if TYPE_CHECKING: import httpx @@ -22,7 +22,7 @@ class GitHubExchangeError: @dataclass -class GitHubExchange: +class GitHubExchange(ExchangeData): access_token: str token_type: str scope: List[str] @@ -219,34 +219,6 @@ def _auth_headers(exchange: GitHubExchange) -> Dict[str, str]: } -class ExchangeCache: - def __init__(self): - self._data: Dict[str, Tuple[datetime, GitHubExchange]] = {} - - def get(self, key: str, max_age: timedelta) -> Union[GitHubExchange, None]: - self._purge(max_age) - if v := self._data.get(key): - return v[1] - - def set(self, key: str, value: GitHubExchange) -> None: - self._data[key] = (datetime.now(), value) - - def _purge(self, max_age: timedelta) -> None: - """ - Remove old items from the exchange cache - """ - min_timestamp = datetime.now() - max_age - to_remove = [k for k, (ts, _) in self._data.items() if ts < min_timestamp] - for k in to_remove: - del self._data[k] - - def __len__(self) -> int: - return len(self._data) - - def clear(self) -> None: - self._data.clear() - - # exchange cache is a singleton so instantiating a new GitHubAuthProvider reuse the same cache EXCHANGE_CACHE = ExchangeCache() diff --git a/src/python-fastui/fastui/auth/google.py b/src/python-fastui/fastui/auth/google.py new file mode 100644 index 00000000..1b9265a2 --- /dev/null +++ b/src/python-fastui/fastui/auth/google.py @@ -0,0 +1,153 @@ +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import AsyncIterator, List, Optional, Union, cast +from urllib.parse import urlencode + +import httpx +from pydantic import BaseModel, SecretStr, TypeAdapter + +from .shared import AuthError, ExchangeCache, ExchangeData + + +@dataclass +class GoogleExchangeError: + error: str + error_description: Union[str, None] = None + + +@dataclass +class GoogleExchange(ExchangeData): + access_token: str + token_type: str + scope: str + expires_in: int + refresh_token: Union[str, None] = None + + +google_exchange_type = TypeAdapter(Union[GoogleExchange, GoogleExchangeError]) + + +class GoogleUser(BaseModel): + id: str + email: Optional[str] = None + verified_email: Optional[bool] = None + name: Optional[str] = None + given_name: Optional[str] = None + family_name: Optional[str] = None + picture: Optional[str] = None + locale: Optional[str] = None + + +class GoogleAuthProvider: + def __init__( + self, + httpx_client: 'httpx.AsyncClient', + google_client_id: str, + google_client_secret: SecretStr, + redirect_uri: Union[str, None] = None, + scopes: Union[List[str], None] = None, + exchange_cache_age: Union[timedelta, None] = timedelta(seconds=30), + ): + self._httpx_client = httpx_client + self._google_client_id = google_client_id + self._google_client_secret = google_client_secret + self._redirect_uri = redirect_uri + self._scopes = scopes or [ + 'https://www.googleapis.com/auth/userinfo.email', + 'https://www.googleapis.com/auth/userinfo.profile', + ] + self._exchange_cache_age = exchange_cache_age + + @classmethod + @asynccontextmanager + async def create( + cls, + client_id: str, + client_secret: SecretStr, + redirect_uri: Union[str, None] = None, + exchange_cache_age: Union[timedelta, None] = timedelta(seconds=10), + ) -> AsyncIterator['GoogleAuthProvider']: + async with httpx.AsyncClient() as client: + yield cls( + client, + client_id, + client_secret, + redirect_uri=redirect_uri, + exchange_cache_age=exchange_cache_age, + ) + + async def authorization_url(self) -> str: + params = { + 'client_id': self._google_client_id, + 'response_type': 'code', + 'scope': ' '.join(self._scopes), + 'redirect_uri': self._redirect_uri, + 'access_type': 'offline', + 'prompt': 'consent', + } + return f'https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}' + + async def exchange_code(self, code: str) -> GoogleExchange: + if self._exchange_cache_age: + cache_key = f'{code}' + if exchange := EXCHANGE_CACHE.get(cache_key, self._exchange_cache_age): + return exchange + else: + exchange = await self._exchange_code(code) + EXCHANGE_CACHE.set(key=cache_key, value=exchange) + return exchange + else: + return await self._exchange_code(code) + + async def _exchange_code(self, code: str) -> GoogleExchange: + params = { + 'client_id': self._google_client_id, + 'client_secret': self._google_client_secret.get_secret_value(), + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': self._redirect_uri, + } + r = await self._httpx_client.post( + 'https://oauth2.googleapis.com/token', + data=params, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + ) + r.raise_for_status() + exchange_response = google_exchange_type.validate_json(r.content) + if isinstance(exchange_response, GoogleExchangeError): + raise AuthError('Google OAuth error', code=exchange_response.error) + else: + return cast(GoogleExchange, exchange_response) + + async def refresh_access_token(self, refresh_token: str) -> GoogleExchange: + params = { + 'client_id': self._google_client_id, + 'client_secret': self._google_client_secret.get_secret_value(), + 'refresh_token': refresh_token, + 'grant_type': 'refresh_token', + } + response = await self._httpx_client.post( + 'https://oauth2.googleapis.com/token', + data=params, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + ) + response.raise_for_status() + exchange_response = google_exchange_type.validate_json(response.content) + if isinstance(exchange_response, GoogleExchangeError): + raise AuthError('Google OAuth error', code=exchange_response.error) + else: + new_access_token = cast(GoogleExchange, exchange_response) + return new_access_token + + async def get_google_user(self, exchange: GoogleExchange) -> GoogleUser: + headers = { + 'Authorization': f'Bearer {exchange.access_token}', + 'Accept': 'application/json', + } + user_response = await self._httpx_client.get('https://www.googleapis.com/oauth2/v1/userinfo', headers=headers) + user_response.raise_for_status() + return GoogleUser.model_validate_json(user_response.content) + + +EXCHANGE_CACHE = ExchangeCache() diff --git a/src/python-fastui/fastui/auth/shared.py b/src/python-fastui/fastui/auth/shared.py index 37abedc5..cf1e35e3 100644 --- a/src/python-fastui/fastui/auth/shared.py +++ b/src/python-fastui/fastui/auth/shared.py @@ -1,6 +1,7 @@ import json from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Tuple, Union +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Dict, Generic, List, Tuple, TypeVar, Union from .. import AnyComponent, FastUI, events from .. import components as c @@ -8,7 +9,8 @@ if TYPE_CHECKING: from fastapi import FastAPI -__all__ = 'AuthError', 'AuthRedirect', 'fastapi_auth_exception_handling' + +__all__ = 'AuthError', 'AuthRedirect', 'fastapi_auth_exception_handling', 'ExchangeCache', 'ExchangeData' class AuthException(ABC, Exception): @@ -56,3 +58,38 @@ def fastapi_auth_exception_handling(app: 'FastAPI') -> None: def auth_exception_handler(_request: Request, e: AuthException) -> Response: status_code, body = e.response_data() return Response(body, media_type='application/json', status_code=status_code) + + +class ExchangeData: + pass + + +T = TypeVar('T', bound='ExchangeData') + + +class ExchangeCache(Generic[T]): + def __init__(self): + self._data: Dict[str, Tuple[datetime, T]] = {} + + def get(self, key: str, max_age: timedelta) -> Union[T, None]: + self._purge(max_age) + if v := self._data.get(key): + return v[1] + + def set(self, key: str, value: T) -> None: + self._data[key] = (datetime.now(), value) + + def _purge(self, max_age: timedelta) -> None: + """ + Remove old items from the exchange cache + """ + min_timestamp = datetime.now() - max_age + to_remove = [k for k, (ts, _) in self._data.items() if ts < min_timestamp] + for k in to_remove: + del self._data[k] + + def __len__(self) -> int: + return len(self._data) + + def clear(self) -> None: + self._data.clear() diff --git a/src/python-fastui/tests/test_auth_google.py b/src/python-fastui/tests/test_auth_google.py new file mode 100644 index 00000000..53e2e1ee --- /dev/null +++ b/src/python-fastui/tests/test_auth_google.py @@ -0,0 +1,138 @@ +from datetime import timedelta + +import httpx +import pytest +from fastui.auth.google import ( + EXCHANGE_CACHE, + AuthError, + GoogleAuthProvider, + GoogleExchange, + GoogleUser, +) +from httpx import Request, Response +from pydantic import SecretStr + + +class MockTransport(httpx.AsyncBaseTransport): + async def handle_async_request(self, request: Request) -> Response: + url = str(request.url) + method = request.method + + if url == 'https://oauth2.googleapis.com/token' and method == 'POST': + print(request.read()) + if b'code=bad_code' in request.read(): + return Response(200, json={'error': 'bad code'}) + + json_data = { + 'access_token': 'test_access_token', + 'token_type': 'Bearer', + 'expires_in': 3600, + 'refresh_token': 'test_refresh_token', + 'scope': 'email profile', + } + return Response(200, json=json_data) + + elif url == 'https://www.googleapis.com/oauth2/v1/userinfo' and method == 'GET': + json_data = { + 'id': '12345', + 'email': 'user@example.com', + 'verified_email': True, + 'name': 'Test User', + 'given_name': 'Test', + 'family_name': 'User', + 'picture': 'https://example.com/avatar.png', + 'locale': 'en', + } + return Response(200, json=json_data) + + return Response(404, json={'error': 'not found'}) + + +@pytest.fixture +async def mock_httpx_client() -> httpx.AsyncClient: + client = httpx.AsyncClient(transport=MockTransport()) + yield client + await client.aclose() + + +@pytest.fixture +async def google_auth_provider(mock_httpx_client: httpx.AsyncClient): + return GoogleAuthProvider( + httpx_client=mock_httpx_client, + google_client_id='google_client_id', + google_client_secret=SecretStr('google_client_secret'), + redirect_uri='https://example.com/callback', + scopes=['email', 'profile'], + exchange_cache_age=timedelta(minutes=5), + ) + + +async def test_create(): + async with GoogleAuthProvider.create('foo', SecretStr('bar')) as provider: + assert isinstance(provider._httpx_client, httpx.AsyncClient) + + +async def test_authorization_url(google_auth_provider: GoogleAuthProvider): + url = await google_auth_provider.authorization_url() + assert url.startswith('https://accounts.google.com/o/oauth2/v2/auth?') + + +async def test_exchange_code_success(google_auth_provider: GoogleAuthProvider): + exchange = await google_auth_provider.exchange_code('good_code') + assert isinstance(exchange, GoogleExchange) + assert exchange.access_token == 'test_access_token' + assert exchange.token_type == 'Bearer' + assert exchange.scope == 'email profile' + assert exchange.refresh_token == 'test_refresh_token' + + +async def test_exchange_code_error(google_auth_provider: GoogleAuthProvider): + with pytest.raises(AuthError): + await google_auth_provider.exchange_code('bad_code') + + +async def test_refresh_access_token(google_auth_provider: GoogleAuthProvider): + new_token = await google_auth_provider.refresh_access_token('good_refresh_token') + assert isinstance(new_token, GoogleExchange) + assert new_token.access_token == 'test_access_token' + + +async def test_get_google_user(google_auth_provider: GoogleAuthProvider): + exchange = GoogleExchange( + access_token='good_access_token', + token_type='Bearer', + scope='email profile', + expires_in=3600, + refresh_token='good_refresh_token', + ) + user = await google_auth_provider.get_google_user(exchange) + assert isinstance(user, GoogleUser) + assert user.id == '12345' + assert user.email == 'user@example.com' + + +async def test_exchange_cache( + google_auth_provider: GoogleAuthProvider, +): + EXCHANGE_CACHE.clear() + assert len(EXCHANGE_CACHE) == 0 + await google_auth_provider.exchange_code('good_code') + assert len(EXCHANGE_CACHE) == 1 + await google_auth_provider.exchange_code('good_code') + assert len(EXCHANGE_CACHE) == 1 + + +async def test_exchange_no_cache(mock_httpx_client): + EXCHANGE_CACHE.clear() + provider = GoogleAuthProvider( + httpx_client=mock_httpx_client, + google_client_id='google_client_id', + google_client_secret=SecretStr('google_client_secret'), + redirect_uri='https://example.com/callback', + scopes=['email', 'profile'], + exchange_cache_age=None, + ) + await provider.exchange_code('good_code') + assert len(EXCHANGE_CACHE) == 0 + await provider.exchange_code('good_code') + assert len(EXCHANGE_CACHE) == 0