diff --git a/flama/ddd/repositories.py b/flama/ddd/repositories.py index 7cb6e00d..3f0c5e52 100644 --- a/flama/ddd/repositories.py +++ b/flama/ddd/repositories.py @@ -1,5 +1,6 @@ import abc import typing as t +import uuid from flama import types from flama.ddd import exceptions @@ -17,7 +18,7 @@ except Exception: # pragma: no cover ... - from httpx import AsyncClient + from flama.client import Client __all__ = [ @@ -25,7 +26,7 @@ "SQLAlchemyRepository", "SQLAlchemyTableRepository", "SQLAlchemyTableManager", - "RestRepository", + "HTTPRepository", ] @@ -347,10 +348,84 @@ async def drop(self, *clauses, **filters) -> int: return await self._table_manager.drop(*clauses, **filters) -class RestRepository(AbstractRepository): - def __init__(self, client: "AsyncClient", *args, **kwargs): +class HTTPRepository(AbstractRepository): + def __init__(self, client: "Client", *args, **kwargs): super().__init__(*args, **kwargs) self._client = client def __eq__(self, other): - return isinstance(other, RestRepository) and self._client == other._client + return isinstance(other, HTTPRepository) and self._client == other._client + + +class HTTPResourceManager: + def __init__(self, resource: str, client: "Client"): + self._client = client + self.resource = resource + + @property + def path(self): + return self.resource + "/" + + async def create(self, data: dict[str, t.Any]): + """Create a new element in the resource. + + :param data: The data to create the element. + :return: The response of the POST method. + """ + return await self._client.post(f"{self.resource}/", json=data) + + async def retrieve(self, id: t.Union[str, uuid.UUID]): + """Retrieve an element from the resource. + + :param id: The id of the element. + :return: The response of the GET method. + """ + return await self._client.get(f"{self.resource}/{id}/") + + async def update(self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.Any], types.Schema]): + """Update an element in the resource. + + :param id: The id of the element. + :param data: The data to update the element. + :return: The response of the PUT method. + """ + return await self._client.put(f"{self.resource}/{id}/", json=data) + + async def delete(self, id: t.Union[str, uuid.UUID]): + """Delete an element from the resource. + + :param id: The id of the element. + :return: The response of the DELETE method. + """ + return await self._client.delete(f"{self.resource}/{id}/") + + async def list(self, *, page_size: int): + """List all the elements in the resource. + + :param page_size: The number of elements to retrieve per page. + :return: The response of the GET method paginated. + """ + return await self._client.get(f"{self.resource}/", params={"page_size": page_size}) + + +class HTTPResourceRepository(HTTPRepository): + _resource: str + + def __init__(self, client: "Client"): + super().__init__(client) + self._resource_manager = HTTPResourceManager(self._resource, client) + + async def create(self, data: dict[str, t.Any]): + return await self._resource_manager.create(data) + + async def retrieve(self, id: uuid.UUID): + return await self._resource_manager.retrieve(id) + + async def update(self, id: uuid.UUID, data: dict[str, t.Any]): + return await self._resource_manager.update(id, data) + + async def delete(self, id: uuid.UUID): + return await self._resource_manager.delete(id) + + async def list(self, *, page_size: int = 100): + return await self._resource_manager.list(page_size=page_size) diff --git a/tests/ddd/test_workers.py b/tests/ddd/test_workers.py index e9a458e1..af75b6db 100644 --- a/tests/ddd/test_workers.py +++ b/tests/ddd/test_workers.py @@ -3,7 +3,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncConnection -from flama.ddd.repositories import RestRepository, SQLAlchemyRepository +from flama.ddd.repositories import HTTPRepository, SQLAlchemyRepository from flama.ddd.workers import HTTPWorker, SQLAlchemyWorker from flama.exceptions import ApplicationError @@ -147,7 +147,7 @@ class TestCaseHTTPWorker: @pytest.fixture(scope="function") def worker(self, client): class FooWorker(HTTPWorker): - bar: RestRepository + bar: HTTPRepository return FooWorker(client.app) @@ -196,7 +196,7 @@ async def test_begin(self, worker): assert worker.begin_transaction.await_args_list == [call()] assert hasattr(worker, "bar") - assert isinstance(worker.bar, RestRepository) + assert isinstance(worker.bar, HTTPRepository) @pytest.mark.parametrize( ["rollback"],