Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion shapeshifter_uftp/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..exceptions import ClientTransportException
from ..logging import logger
from ..uftp import PayloadMessage, PayloadMessageResponse, SignedMessage
from shapeshifter_uftp.token_manager import AuthTokenManager


class ShapeshifterClient:
Expand All @@ -34,6 +35,7 @@ def __init__(
recipient_domain: str,
recipient_endpoint: str = None,
recipient_signing_key: str = None,
oauth_token_manager: AuthTokenManager = None,
):
"""
Shapeshifter client class that allows you to initiate messages to a different party.
Expand All @@ -55,6 +57,7 @@ def __init__(
self.recipient_domain = recipient_domain
self.recipient_endpoint = recipient_endpoint
self.recipient_signing_key = recipient_signing_key
self.oauth_token_manager = oauth_token_manager

# The outgoing queue and scheduler are used when queueing
# messages for delivery later. This allows the Shapeshifter
Expand Down Expand Up @@ -114,11 +117,23 @@ def _send_message(self, message: PayloadMessage) -> PayloadMessageResponse:
logger.debug(f"Sending message to {self.recipient_endpoint}:")
logger.debug(serialized_message)

# Find the right headers to use for the request. If we have
# an OAuth2 token manager, we will use that to get the
# request headers. If not, we will use the basic Content-Type
try:
if self.oauth_token_manager:
header = self.oauth_token_manager.get_request_headers()
else:
header = {"Content-Type": "text/xml; charset=utf-8"}
except Exception as e:
logger.warning(f"Failed to get OAuth2 headers, falling back to basic headers: {e}")
header = {"Content-Type": "text/xml; charset=utf-8"}

# Send the request to the relevant endpoint
response = requests.post(
self.recipient_endpoint,
data=serialized_message,
headers={"Content-Type": "text/xml; charset=utf-8"},
headers=header,
timeout=self.request_timeout,
)
if response.status_code != 200:
Expand Down
25 changes: 23 additions & 2 deletions shapeshifter_uftp/service/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PayloadMessageResponse,
SignedMessage,
)
from ..token_manager import AuthTokenManager


class ShapeshifterService():
Expand All @@ -46,11 +47,15 @@ def __init__(
self,
sender_domain,
signing_key,
oauth_token_endpoint: str = None,
oauth_client_id: str = None,
oauth_client_secret: str = None,
token_refresh_buffer: int = 30,
key_lookup_function=None,
endpoint_lookup_function=None,
host: str = "0.0.0.0",
port: int = 8080,
path="/shapeshifter/api/v3/message",
path="/shapeshifter/api/v3/message"
):
"""
:param sender_domain: our sender domain (FQDN) that the recipient uses to look us up.
Expand All @@ -64,6 +69,9 @@ def __init__(
:param host: the host to bind the server to (usually 127.0.0.1 or 0.0.0.0)
:param port: the port to bind the server to (default: 8080)
:param path: the URL path that the server listens on (default: /shapeshifter/api/v3/message)
:param oauth_token_endpoint: the OAuth2 token endpoint to use for obtaining access tokens
:param oauth_client_id: the OAuth2 client ID to use for obtaining access tokens
:param oauth_client_secret: the OAuth2 client secret to use for obtaining access tokens
"""

# Set the sender domain, which is used
Expand All @@ -87,6 +95,18 @@ def __init__(
# The FastAPI web app takes care of routing messages to the
# (one) endpoint, and by virtue of FastAPI-XML convert the
# python-friendly objects into XML and vice versa.

# Create Auth Manager for OAuth2 Client Credentials flow (if configured)
if oauth_token_endpoint and oauth_client_id and oauth_client_secret:
self.auth_token_manager = AuthTokenManager(
oauth_token_endpoint=oauth_token_endpoint,
oauth_client_id=oauth_client_id,
oauth_client_secret=oauth_client_secret,
token_refresh_buffer=token_refresh_buffer
)
else:
self.auth_token_manager = None

self.app = FastAPI(default_response_class=XmlAppResponse)
self.app.router.route_class = XmlRoute
self.app.router.add_api_route(
Expand Down Expand Up @@ -249,7 +269,8 @@ def _get_client(self, recipient_domain, recipient_role):
signing_key = self.signing_key,
recipient_domain = recipient_domain,
recipient_endpoint = recipient_endpoint,
recipient_signing_key = recipient_signing_key
recipient_signing_key = recipient_signing_key,
oauth_token_manager = self.auth_token_manager
)

def __enter__(self):
Expand Down
122 changes: 122 additions & 0 deletions shapeshifter_uftp/token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from datetime import datetime, timezone, timedelta

import requests

from .logging import logger

from typing import Optional
from threading import Lock

class AuthTokenManager:
"""
A token manager that can be used to manage tokens for the Shapeshifter client.
It handles OAuth2 Client Credentials flow to obtain and refresh tokens.
This class is thread-safe and ensures that tokens are refreshed only when necessary.
It provides a method to get request headers with the Bearer token included.
"""
request_timeout: int = 30

def __init__(self,
oauth_token_endpoint: str,
oauth_client_id: str,
oauth_client_secret: str,
token_refresh_buffer: int = 30):
self.oauth_token_endpoint = oauth_token_endpoint
self.oauth_client_id = oauth_client_id
self.oauth_client_secret = oauth_client_secret
self.token_refresh_buffer = token_refresh_buffer
self._access_token: Optional[str] = None
self._token_expires_at: Optional[datetime] = None
self._token_lock = Lock()

def _is_oauth_configured(self) -> bool:
"""Check if OAuth2 is properly configured."""
return all([
self.oauth_token_endpoint,
self.oauth_client_id,
self.oauth_client_secret
])

def _is_token_valid(self) -> bool:
"""Check if the current token is valid and not close to expiring."""
if not self._access_token or not self._token_expires_at:
return False

buffer_time = datetime.now(timezone.utc) + timedelta(seconds=self.token_refresh_buffer)
return self._token_expires_at > buffer_time

def _obtain_bearer_token(self) -> str:
"""
Obtain a Bearer token using OAuth2 Client Credentials flow.

:return: Access token string
:raises: Exception if token acquisition fails
"""
if not self._is_oauth_configured():
raise ValueError("OAuth2 not configured. Please provide oauth_token_endpoint, oauth_client_id, and oauth_client_secret.")

token_data = {
'grant_type': 'client_credentials',
'client_id': self.oauth_client_id,
'client_secret': self.oauth_client_secret
}

headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}

try:
response = requests.post(
self.oauth_token_endpoint,
data=token_data,
headers=headers,
timeout=self.request_timeout
)
response.raise_for_status()

token_response = response.json()
access_token = token_response['access_token']
expires_in = token_response.get('expires_in', 300)

self._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)

logger.info(f"Successfully obtained OAuth2 token, expires at {self._token_expires_at}")
return access_token

except requests.exceptions.RequestException as e:
logger.error(f"Failed to obtain OAuth2 token: {e}")
raise
except KeyError as e:
logger.error(f"Invalid token response format, missing key: {e}")
raise

def _get_valid_token(self) -> Optional[str]:
"""
Get a valid Bearer token, refreshing if necessary.
Thread-safe implementation.

:return: Valid access token or None if OAuth2 not configured
"""
if not self._is_oauth_configured():
return None

with self._token_lock:
if not self._is_token_valid():
logger.debug("Token invalid or expired, obtaining new token")
self._access_token = self._obtain_bearer_token()

return self._access_token

def get_request_headers(self) -> dict:
"""
Get headers for HTTP requests, including Bearer token if configured.

:return: Dictionary of headers
"""
headers = {"Content-Type": "text/xml; charset=utf-8"}

token = self._get_valid_token()
if token:
headers["Authorization"] = f"Bearer {token}"

return headers
2 changes: 1 addition & 1 deletion test/helpers/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def key_lookup_function(domain, role):
return CRO_PUBLIC_KEY
elif domain == "dso.dev":
return DSO_PUBLIC_KEY


class DummyAgrService(ShapeshifterAgrService):

Expand Down
Loading