diff --git a/apiclient/__init__.py b/apiclient/__init__.py index 05d711e..8a71389 100644 --- a/apiclient/__init__.py +++ b/apiclient/__init__.py @@ -5,7 +5,7 @@ NoAuthentication, QueryParameterAuthentication, ) -from apiclient.client import APIClient +from apiclient.client import AbstractClient, APIClient, AsyncAPIClient from apiclient.decorates import endpoint from apiclient.paginators import paginated from apiclient.request_formatters import JsonRequestFormatter diff --git a/apiclient/authentication_methods.py b/apiclient/authentication_methods.py index f7a9159..345ff5e 100644 --- a/apiclient/authentication_methods.py +++ b/apiclient/authentication_methods.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: # pragma: no cover # Stupid way of getting around cyclic imports when # using typehinting. - from apiclient import APIClient + from apiclient.client import AbstractClient class BaseAuthenticationMethod: @@ -19,7 +19,7 @@ def get_query_params(self) -> dict: def get_username_password_authentication(self) -> Optional[BasicAuthType]: return None - def perform_initial_auth(self, client: "APIClient"): + def perform_initial_auth(self, client: "AbstractClient"): pass @@ -91,7 +91,7 @@ def __init__( self._auth_url = auth_url self._authentication = authentication - def perform_initial_auth(self, client: "APIClient"): + def perform_initial_auth(self, client: "AbstractClient"): client.get( self._auth_url, headers=self._authentication.get_headers(), diff --git a/apiclient/client.py b/apiclient/client.py index 1313b20..af018b5 100644 --- a/apiclient/client.py +++ b/apiclient/client.py @@ -5,7 +5,7 @@ from apiclient.authentication_methods import BaseAuthenticationMethod, NoAuthentication from apiclient.error_handlers import BaseErrorHandler, ErrorHandler from apiclient.request_formatters import BaseRequestFormatter, NoOpRequestFormatter -from apiclient.request_strategies import BaseRequestStrategy, RequestStrategy +from apiclient.request_strategies import AsyncRequestStrategy, BaseRequestStrategy, RequestStrategy from apiclient.response_handlers import BaseResponseHandler, RequestsResponseHandler from apiclient.utils.typing import OptionalDict @@ -15,7 +15,7 @@ DEFAULT_TIMEOUT = 10.0 -class APIClient: +class AbstractClient: def __init__( self, authentication_method: Optional[BaseAuthenticationMethod] = None, @@ -37,11 +37,14 @@ def __init__( self.set_response_handler(response_handler) self.set_error_handler(error_handler) self.set_request_formatter(request_formatter) - self.set_request_strategy(request_strategy or RequestStrategy()) + self.set_request_strategy(request_strategy or self.get_default_request_strategy()) # Perform any one time authentication required by api self._authentication_method.perform_initial_auth(self) + def get_default_request_strategy(self): # pragma: no cover + raise NotImplementedError + def get_session(self) -> Any: return self._session @@ -135,3 +138,44 @@ def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): """Remove resource with DELETE endpoint.""" LOG.debug("DELETE %s", endpoint) return self.get_request_strategy().delete(endpoint, params=params, **kwargs) + + +class APIClient(AbstractClient): + def get_default_request_strategy(self): + return RequestStrategy() + + +class AsyncAPIClient(AbstractClient): + async def __aenter__(self): + session = await self._request_strategy.create_session() + self.set_session(session) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + session = self.get_session() + if session: + await session.close() + self.set_session(None) + + def get_default_request_strategy(self): + return AsyncRequestStrategy() + + async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + """Send data and return response data from POST endpoint.""" + return await self.get_request_strategy().post(endpoint, data=data, params=params, **kwargs) + + async def get(self, endpoint: str, params: OptionalDict = None, **kwargs): + """Return response data from GET endpoint.""" + return await self.get_request_strategy().get(endpoint, params=params, **kwargs) + + async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + """Send data to overwrite resource and return response data from PUT endpoint.""" + return await self.get_request_strategy().put(endpoint, data=data, params=params, **kwargs) + + async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + """Send data to update resource and return response data from PATCH endpoint.""" + return await self.get_request_strategy().patch(endpoint, data=data, params=params, **kwargs) + + async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): + """Remove resource with DELETE endpoint.""" + return await self.get_request_strategy().delete(endpoint, params=params, **kwargs) diff --git a/apiclient/request_strategies.py b/apiclient/request_strategies.py index ce53498..dc57146 100644 --- a/apiclient/request_strategies.py +++ b/apiclient/request_strategies.py @@ -1,55 +1,99 @@ from copy import deepcopy -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable +import aiohttp import requests from apiclient.exceptions import UnexpectedError -from apiclient.response import RequestsResponse, Response +from apiclient.response import AioHttpResponse, RequestsResponse, Response from apiclient.utils.typing import OptionalDict if TYPE_CHECKING: # pragma: no cover # Stupid way of getting around cyclic imports when # using typehinting. - from apiclient import APIClient + from apiclient.client import AbstractClient class BaseRequestStrategy: - def set_client(self, client: "APIClient"): + def set_client(self, client: "AbstractClient"): self._client = client - def get_client(self) -> "APIClient": + def get_session(self): + return self.get_client().get_session() + + def set_session(self, session: Any): + self.get_client().set_session(session) + + def create_session(self): # pragma: no cover + """Abstract method that will create a session object.""" + raise NotImplementedError + + def get_client(self) -> "AbstractClient": return self._client - def post(self, *args, **kwargs): # pragma: no cover + def _get_request_params(self, params: OptionalDict) -> dict: + """Return dictionary with any additional authentication query parameters.""" + if params is None: + params = {} + params.update(self.get_client().get_default_query_params()) + return params + + def _get_request_headers(self, headers: OptionalDict) -> dict: + """Return dictionary with any additional authentication headers.""" + if headers is None: + headers = {} + headers.update(self.get_client().get_default_headers()) + return headers + + def _get_username_password_authentication(self): + return self.get_client().get_default_username_password_authentication() + + def _get_formatted_data(self, data: OptionalDict): + return self.get_client().get_request_formatter().format(data) + + def _get_request_timeout(self) -> float: + """Return the number of seconds before the request times out.""" + return self.get_client().get_request_timeout() + + def _check_response(self, response: Response): + """Raise a custom exception if the response is not OK.""" + status_code = response.get_status_code() + if status_code < 200 or status_code >= 300: + self._handle_bad_response(response) + + def _decode_response_data(self, response: Response): + return self.get_client().get_response_handler().get_request_data(response) + + def _handle_bad_response(self, response: Response): + """Convert the error into an understandable client exception.""" + raise self.get_client().get_error_handler().get_exception(response) + + def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover raise NotImplementedError - def get(self, *args, **kwargs): # pragma: no cover + def get(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover raise NotImplementedError - def put(self, *args, **kwargs): # pragma: no cover + def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover raise NotImplementedError - def patch(self, *args, **kwargs): # pragma: no cover + def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): # pragma: no cover raise NotImplementedError - def delete(self, *args, **kwargs): # pragma: no cover + def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): # pragma: no cover raise NotImplementedError class RequestStrategy(BaseRequestStrategy): """Requests strategy that uses the `requests` lib with a `requests.session`.""" - def set_client(self, client: "APIClient"): + def set_client(self, client: "AbstractClient"): super().set_client(client) - # Set a global `requests.session` on the parent client instance. if self.get_session() is None: - self.set_session(requests.session()) + self.set_session(self.create_session()) - def get_session(self): - return self.get_client().get_session() - - def set_session(self, session: requests.Session): - self.get_client().set_session(session) + def create_session(self) -> requests.Session: + return requests.session() def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): """Send data and return response data from POST endpoint.""" @@ -102,43 +146,6 @@ def _make_request( self._check_response(response) return self._decode_response_data(response) - def _get_request_params(self, params: OptionalDict) -> dict: - """Return dictionary with any additional authentication query parameters.""" - if params is None: - params = {} - params.update(self.get_client().get_default_query_params()) - return params - - def _get_request_headers(self, headers: OptionalDict) -> dict: - """Return dictionary with any additional authentication headers.""" - if headers is None: - headers = {} - headers.update(self.get_client().get_default_headers()) - return headers - - def _get_username_password_authentication(self): - return self.get_client().get_default_username_password_authentication() - - def _get_formatted_data(self, data: OptionalDict): - return self.get_client().get_request_formatter().format(data) - - def _get_request_timeout(self) -> float: - """Return the number of seconds before the request times out.""" - return self.get_client().get_request_timeout() - - def _check_response(self, response: Response): - """Raise a custom exception if the response is not OK.""" - status_code = response.get_status_code() - if status_code < 200 or status_code >= 300: - self._handle_bad_response(response) - - def _decode_response_data(self, response: Response): - return self.get_client().get_response_handler().get_request_data(response) - - def _handle_bad_response(self, response: Response): - """Convert the error into an understandable client exception.""" - raise self.get_client().get_error_handler().get_exception(response) - class QueryParamPaginatedRequestStrategy(RequestStrategy): """Strategy for GET requests where pages are defined in query params.""" @@ -192,3 +199,56 @@ def get(self, endpoint: str, params: OptionalDict = None, **kwargs): def get_next_page_url(self, response, previous_page_url: str) -> OptionalDict: return self._next_page(response, previous_page_url) + + +class AsyncRequestStrategy(BaseRequestStrategy): + async def create_session(self) -> aiohttp.ClientSession: + return aiohttp.ClientSession() + + def get_session(self) -> aiohttp.ClientSession: + return self.get_client().get_session() + + async def post(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + return await self._make_request( + self.get_session().post, endpoint, data=data, params=params, **kwargs + ) + + async def get(self, endpoint: str, params: OptionalDict = None, **kwargs): + return await self._make_request(self.get_session().get, endpoint, params=params, **kwargs) + + async def put(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + return await self._make_request(self.get_session().put, endpoint, data=data, params=params, **kwargs) + + async def patch(self, endpoint: str, data: dict, params: OptionalDict = None, **kwargs): + return await self._make_request( + self.get_session().patch, endpoint, data=data, params=params, **kwargs + ) + + async def delete(self, endpoint: str, params: OptionalDict = None, **kwargs): + return await self._make_request(self.get_session().delete, endpoint, params=params, **kwargs) + + async def _make_request( + self, + request_method: Callable, + endpoint: str, + params: OptionalDict = None, + headers: OptionalDict = None, + data: OptionalDict = None, + **kwargs, + ) -> Response: + try: + async with request_method( + endpoint, + params=self._get_request_params(params), + headers=self._get_request_headers(headers), + auth=self._get_username_password_authentication(), + data=self._get_formatted_data(data), + timeout=self._get_request_timeout(), + **kwargs, + ) as raw_response: + response = AioHttpResponse(raw_response, content=await raw_response.read()) + except Exception as error: + raise UnexpectedError(f"Error when contacting '{endpoint}'") from error + else: + self._check_response(response) + return self._decode_response_data(response) diff --git a/apiclient/response.py b/apiclient/response.py index f690f25..2948bac 100644 --- a/apiclient/response.py +++ b/apiclient/response.py @@ -1,5 +1,7 @@ +import json from typing import Any +import aiohttp import requests from apiclient.utils.typing import JsonType @@ -62,3 +64,26 @@ def get_status_reason(self) -> str: def get_requested_url(self) -> str: return self._response.url + + +class AioHttpResponse(RequestsResponse): + """Implementation of the response for a requests.response type.""" + + def __init__(self, response: aiohttp.ClientResponse, content: bytes): + self._response = response + self._content = content + self._text = "" + + def get_status_code(self) -> int: + return self._response.status + + def get_raw_data(self) -> str: + if not self._text: + self._text = self._content.decode(self._response.get_encoding(), errors="strict") + return self._text + + def get_json(self) -> JsonType: + return json.loads(self._text) + + def get_requested_url(self) -> str: + return str(self._response.url) diff --git a/setup.cfg b/setup.cfg index 114a27e..e839de9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [tool:pytest] -addopts = --cov=apiclient/ --cov-fail-under=100 --cov-report html +addopts = --asyncio-mode=auto --cov=apiclient/ --cov-fail-under=100 --cov-report html env = ENDPOINT_BASE_URL=http://environment.com diff --git a/setup.py b/setup.py index c4d46e3..2718533 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ import setuptools # Pinning tenacity as the api has changed slightly which breaks all tests. -application_dependencies = ["requests>=2.16", "tenacity>=5.1.0"] +application_dependencies = ["requests>=2.16", "aiohttp>=3.8", "tenacity>=5.1.0"] prod_dependencies = [] -test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock"] +test_dependencies = ["pytest", "pytest-env", "pytest-cov", "vcrpy", "requests-mock", "pytest-asyncio", "aioresponses"] lint_dependencies = ["flake8", "flake8-docstrings", "black", "isort"] docs_dependencies = [] dev_dependencies = test_dependencies + lint_dependencies + docs_dependencies + ["ipdb"] diff --git a/tests/conftest.py b/tests/conftest.py index 05774eb..9baa7be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import requests import requests_mock import vcr +from aioresponses import aioresponses from apiclient import APIClient from apiclient.request_formatters import BaseRequestFormatter @@ -70,3 +71,9 @@ def mock_client(): return MockClient( client=_mock_client, request_formatter=mock_request_formatter, response_handler=mock_response_handler ) + + +@pytest.fixture() +def mock_aioresponse(): + with aioresponses() as m: + yield m diff --git a/tests/integration_tests/client.py b/tests/integration_tests/client.py index 28bfbe6..2c5f8c6 100644 --- a/tests/integration_tests/client.py +++ b/tests/integration_tests/client.py @@ -1,7 +1,7 @@ from enum import IntEnum, unique from json import JSONDecodeError -from apiclient import APIClient, endpoint, paginated, retry_request +from apiclient import APIClient, AsyncAPIClient, endpoint, paginated, retry_request from apiclient.error_handlers import ErrorHandler from apiclient.exceptions import APIRequestError from apiclient.response import Response @@ -104,3 +104,33 @@ def delete_user(self, user_id): @paginated(by_query_params=by_query_params_callable) def list_user_accounts_paginated(self, user_id): return self.get(Urls.accounts, params={"userId": user_id}) + + +class AsyncClient(AsyncAPIClient): + def get_request_timeout(self): + return 0.1 + + async def list_users(self): + return await self.get(Urls.users) + + async def create_user(self, first_name, last_name): + data = {"firstName": first_name, "lastName": last_name} + return await self.post(Urls.users, data=data) + + async def overwrite_user(self, user_id, first_name, last_name): + data = {"firstName": first_name, "lastName": last_name} + url = Urls.user.format(id=user_id) + return await self.put(url, data=data) + + async def update_user(self, user_id, first_name=None, last_name=None): + data = {} + if first_name: + data["firstName"] = first_name + if last_name: + data["lastName"] = last_name + url = Urls.user.format(id=user_id) + return await self.patch(url, data=data) + + async def delete_user(self, user_id): + url = Urls.user.format(id=user_id) + return await self.delete(url) diff --git a/tests/integration_tests/test_async_client_integration.py b/tests/integration_tests/test_async_client_integration.py new file mode 100644 index 0000000..d11d302 --- /dev/null +++ b/tests/integration_tests/test_async_client_integration.py @@ -0,0 +1,69 @@ +import asyncio + +import pytest + +from apiclient import JsonRequestFormatter, JsonResponseHandler, NoAuthentication +from apiclient.exceptions import UnexpectedError +from tests.integration_tests.client import AsyncClient, Urls + + +@pytest.mark.asyncio +async def test_client_response(mock_aioresponse): + mock_aioresponse.get( + Urls.users, + status=200, + payload=[ + {"userId": 1, "firstName": "Mike", "lastName": "Foo"}, + {"userId": 2, "firstName": "Sarah", "lastName": "Bar"}, + {"userId": 3, "firstName": "Barry", "lastName": "Baz"}, + ], + ) + mock_aioresponse.post( + Urls.users, status=201, payload={"userId": 4, "firstName": "Lucy", "lastName": "Qux"} + ) + mock_aioresponse.put( + Urls.user.format(id=4), status=200, payload={"userId": 4, "firstName": "Lucy", "lastName": "Foo"} + ) + mock_aioresponse.patch( + Urls.user.format(id=4), status=200, payload={"userId": 4, "firstName": "Lucy", "lastName": "Qux"} + ) + mock_aioresponse.delete(Urls.user.format(id=4), status=204, payload=None) + + async with AsyncClient( + authentication_method=NoAuthentication(), + response_handler=JsonResponseHandler, + request_formatter=JsonRequestFormatter, + ) as client: + responses = await asyncio.gather( + client.list_users(), + client.create_user(first_name="Lucy", last_name="Qux"), + client.overwrite_user(user_id=4, first_name="Lucy", last_name="Foo"), + client.update_user(user_id=4, first_name="Lucy", last_name="Qux"), + client.delete_user(user_id=4), + client.get("mock://testserver"), + return_exceptions=True, + ) + # users = await client.list_users() + # new_user = await client.create_user(first_name="Lucy", last_name="Qux") + # overwritten_user = await client.overwrite_user(user_id=4, first_name="Lucy", last_name="Foo") + # updated_user = await client.update_user(user_id=4, first_name="Lucy", last_name="Qux") + # deleted_user = await client.delete_user(user_id=4) + users, new_user, overwritten_user, updated_user, deleted_user, error = responses + + assert len(users) == 3 + assert users == [ + {"userId": 1, "firstName": "Mike", "lastName": "Foo"}, + {"userId": 2, "firstName": "Sarah", "lastName": "Bar"}, + {"userId": 3, "firstName": "Barry", "lastName": "Baz"}, + ] + + assert new_user == {"userId": 4, "firstName": "Lucy", "lastName": "Qux"} + assert overwritten_user == {"userId": 4, "firstName": "Lucy", "lastName": "Foo"} + assert updated_user == {"userId": 4, "firstName": "Lucy", "lastName": "Qux"} + assert deleted_user is None + + assert isinstance(error, UnexpectedError) + with pytest.raises(UnexpectedError) as exc_info: + async with AsyncClient() as client: + await client.get("mock://testserver") + assert str(exc_info.value) == "Error when contacting 'mock://testserver'" diff --git a/tests/test_response.py b/tests/test_response.py index b511a58..4109bc3 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -2,7 +2,7 @@ import pytest -from apiclient.response import RequestsResponse, Response +from apiclient.response import AioHttpResponse, RequestsResponse, Response class TestResponse: @@ -31,3 +31,10 @@ def test_get_status_reason_returns_empty_string_when_none(self): requests_response = Mock(reason=None) response = RequestsResponse(requests_response) assert response.get_status_reason() == "" + + +class TestAiRequestsResponse: + def test_get_url(self): + requests_response = Mock(url=1) + response = AioHttpResponse(requests_response, b"") + assert response.get_requested_url() == "1" diff --git a/tox.ini b/tox.ini index bf0e780..42b904e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - requests{216,217,218,219,220,221,222,223,224,225,226}-tenacity{51,60,61,62,63,70,80} + requests{216,217,218,219,220,221,222,223,224,225,226,227}-tenacity{51,60,61,62,63,70,80} lint [testenv] @@ -16,6 +16,7 @@ deps = requests224: requests>=2.24,<2.25 requests225: requests>=2.25,<2.26 requests226: requests>=2.26,<2.27 + requests227: requests>=2.27,<2.28 tenacity51: tenacity>=5.1.0,<5.2.0 tenacity60: tenacity>=6.0.0,<6.1.0 tenacity61: tenacity>=6.1.0,<6.2.0