From 2bed497a9a7feccfc357942d97c7217d3c7d1a95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Tue, 17 Dec 2024 13:15:09 +0100 Subject: [PATCH] :bug: Enhance ddd workers --- flama/ddd/workers/__init__.py | 2 +- flama/ddd/workers/base.py | 40 ++++++++++++ flama/ddd/workers/http.py | 46 +++++++------- flama/ddd/workers/noop.py | 27 -------- flama/ddd/workers/sqlalchemy.py | 62 +++++++++++-------- flama/ddd/workers/worker.py | 38 ++++++++++++ flama/resources/workers.py | 11 ++-- tests/ddd/test_components.py | 7 ++- tests/ddd/workers/test_base.py | 9 +++ tests/ddd/workers/test_http.py | 36 +++++------ tests/ddd/workers/test_sqlalchemy.py | 16 ++--- .../workers/{test_noop.py => test_worker.py} | 8 +-- 12 files changed, 189 insertions(+), 113 deletions(-) delete mode 100644 flama/ddd/workers/noop.py create mode 100644 flama/ddd/workers/worker.py rename tests/ddd/workers/{test_noop.py => test_worker.py} (59%) diff --git a/flama/ddd/workers/__init__.py b/flama/ddd/workers/__init__.py index 982674bc..df5bc8f0 100644 --- a/flama/ddd/workers/__init__.py +++ b/flama/ddd/workers/__init__.py @@ -1,3 +1,3 @@ from flama.ddd.workers.base import * # noqa from flama.ddd.workers.http import * # noqa -from flama.ddd.workers.noop import * # noqa +from flama.ddd.workers.worker import * # noqa diff --git a/flama/ddd/workers/base.py b/flama/ddd/workers/base.py index 32cbe8f8..dbb00a42 100644 --- a/flama/ddd/workers/base.py +++ b/flama/ddd/workers/base.py @@ -137,3 +137,43 @@ class BaseWorker(AbstractWorker, metaclass=WorkerType): """ _repositories: t.ClassVar[dict[str, type[BaseRepository]]] + + @abc.abstractmethod + async def set_up(self) -> None: + """First step in starting a unit of work.""" + ... + + @abc.abstractmethod + async def tear_down(self, *, rollback: bool = False) -> None: + """Last step in ending a unit of work. + + :param rollback: If the unit of work should be rolled back. + """ + ... + + @abc.abstractmethod + async def repository_params(self) -> tuple[list[t.Any], dict[str, t.Any]]: + """Get the parameters for initialising the repositories. + + :return: Parameters for initialising the repositories. + """ + ... + + async def begin(self) -> None: + """Start a unit of work.""" + await self.set_up() + + args, kwargs = await self.repository_params() + + for repository, repository_class in self._repositories.items(): + setattr(self, repository, repository_class(*args, **kwargs)) + + async def end(self, *, rollback: bool = False) -> None: + """End a unit of work. + + :param rollback: If the unit of work should be rolled back. + """ + await self.tear_down(rollback=rollback) + + for repository in self._repositories.keys(): + delattr(self, repository) diff --git a/flama/ddd/workers/http.py b/flama/ddd/workers/http.py index 950e7aa9..94b21ba2 100644 --- a/flama/ddd/workers/http.py +++ b/flama/ddd/workers/http.py @@ -38,39 +38,41 @@ def client(self) -> "Client": except AttributeError: raise AttributeError("Client not initialized") - async def begin_transaction(self) -> None: - """Initialize the client with the URL.""" - await self._client.__aenter__() + @client.setter + def client(self, client: "Client") -> None: + """Set the client to interact with an HTTP resource. - async def end_transaction(self) -> None: - """Close and delete the client.""" - await self.client.__aexit__() + :param client: Flama client. + """ + self._client = client - async def begin(self) -> None: - """Start a unit of work. + @client.deleter + def client(self) -> None: + """Delete the client.""" + del self._client - Initialize the client, and create the repositories. - """ + async def set_up(self) -> None: + """Initialize the client with the URL.""" from flama.client import Client - self._client = Client(base_url=self.url, **self._client_kwargs) - - await self.begin_transaction() + self.client = Client(base_url=self.url, **self._client_kwargs) - for repository, repository_class in self._repositories.items(): - setattr(self, repository, repository_class(self._client)) + await self.client.__aenter__() - async def end(self, *, rollback: bool = False) -> None: - """End a unit of work. + async def tear_down(self, *, rollback: bool = False) -> None: + """Close and delete the client. - Close the client, and delete the repositories. + :param rollback: If the unit of work should be rolled back. """ - await self.end_transaction() + await self.client.__aexit__() + del self.client - for repository in self._repositories.keys(): - delattr(self, repository) + async def repository_params(self) -> tuple[list[t.Any], dict[str, t.Any]]: + """Get the parameters for initialising the repositories. - del self._client + :return: Parameters for initialising the repositories. + """ + return [self.client], {} async def commit(self) -> None: ... diff --git a/flama/ddd/workers/noop.py b/flama/ddd/workers/noop.py deleted file mode 100644 index 8afdf52b..00000000 --- a/flama/ddd/workers/noop.py +++ /dev/null @@ -1,27 +0,0 @@ -from flama.ddd.workers.base import BaseWorker - - -class NoopWorker(BaseWorker): - """Worker that does not apply any specific behavior. - - A basic implementation of the worker class that does not apply any specific behavior. - """ - - async def begin(self) -> None: - """Start a unit of work.""" - ... - - async def end(self, *, rollback: bool = False) -> None: - """End a unit of work. - - :param rollback: If the unit of work should be rolled back. - """ - ... - - async def commit(self) -> None: - """Commit the unit of work.""" - ... - - async def rollback(self) -> None: - """Rollback the unit of work.""" - ... diff --git a/flama/ddd/workers/sqlalchemy.py b/flama/ddd/workers/sqlalchemy.py index 8f140c8e..f9d31c71 100644 --- a/flama/ddd/workers/sqlalchemy.py +++ b/flama/ddd/workers/sqlalchemy.py @@ -1,4 +1,5 @@ import logging +import typing as t from flama import exceptions from flama.ddd.workers.base import BaseWorker @@ -37,6 +38,19 @@ def connection(self) -> AsyncConnection: except AttributeError: raise AttributeError("Connection not initialized") + @connection.setter + def connection(self, connection: AsyncConnection) -> None: + """Set the connection to the database. + + :param connection: Connection to the database. + """ + self._connection = connection + + @connection.deleter + def connection(self) -> None: + """Delete the connection to the database.""" + del self._connection + @property def transaction(self) -> AsyncTransaction: """Database transaction. @@ -49,45 +63,43 @@ def transaction(self) -> AsyncTransaction: except AttributeError: raise AttributeError("Transaction not started") - async def begin_transaction(self) -> None: + @transaction.setter + def transaction(self, transaction: AsyncTransaction) -> None: + """Set the transaction. + + :param transaction: Database transaction. + """ + self._transaction = transaction + + @transaction.deleter + def transaction(self) -> None: + """Delete the transaction.""" + del self._transaction + + async def set_up(self) -> None: """Open a connection and begin a transaction.""" - self._connection = await self.app.sqlalchemy.open_connection() - self._transaction = await self.app.sqlalchemy.begin_transaction(self._connection) + self.connection = await self.app.sqlalchemy.open_connection() + self.transaction = await self.app.sqlalchemy.begin_transaction(self._connection) - async def end_transaction(self, *, rollback: bool = False) -> None: + async def tear_down(self, *, rollback: bool = False) -> None: """End a transaction and close the connection. :param rollback: If the transaction should be rolled back. :raises AttributeError: If the connection is not initialized or the transaction is not started. """ await self.app.sqlalchemy.end_transaction(self.transaction, rollback=rollback) - del self._transaction + del self.transaction await self.app.sqlalchemy.close_connection(self.connection) - del self._connection + del self.connection - async def begin(self) -> None: - """Start a unit of work. + async def repository_params(self) -> tuple[list[t.Any], dict[str, t.Any]]: + """Get the parameters for initialising the repositories. - Initialize the connection, begin a transaction, and create the repositories. + :return: Parameters for initialising the repositories. """ - await self.begin_transaction() - - for repository, repository_class in self._repositories.items(): - setattr(self, repository, repository_class(self.connection)) - - async def end(self, *, rollback: bool = False) -> None: - """End a unit of work. - - Close the connection, commit or rollback the transaction, and delete the repositories. - - :param rollback: If the unit of work should be rolled back. - """ - await self.end_transaction(rollback=rollback) - - for repository in self._repositories.keys(): - delattr(self, repository) + return [self.connection], {} async def commit(self): """Commit the unit of work.""" diff --git a/flama/ddd/workers/worker.py b/flama/ddd/workers/worker.py new file mode 100644 index 00000000..f4492dea --- /dev/null +++ b/flama/ddd/workers/worker.py @@ -0,0 +1,38 @@ +import typing as t + +from flama.ddd.workers.base import BaseWorker + +__all__ = ["Worker"] + + +class Worker(BaseWorker): + """Worker that does not apply any specific behavior. + + A basic implementation of the worker class that does not apply any specific behavior. + """ + + async def set_up(self) -> None: + """First step in starting a unit of work.""" + ... + + async def tear_down(self, *, rollback: bool = False) -> None: + """Last step in ending a unit of work. + + :param rollback: If the unit of work should be rolled back. + """ + ... + + async def repository_params(self) -> tuple[list[t.Any], dict[str, t.Any]]: + """Get the parameters for initialising the repositories. + + :return: Parameters for initialising the repositories. + """ + return [], {} + + async def commit(self) -> None: + """Commit the unit of work.""" + ... + + async def rollback(self) -> None: + """Rollback the unit of work.""" + ... diff --git a/flama/resources/workers.py b/flama/resources/workers.py index e77e544e..db341ad9 100644 --- a/flama/resources/workers.py +++ b/flama/resources/workers.py @@ -14,8 +14,8 @@ class Repositories: registered: dict[str, type["SQLAlchemyTableRepository"]] = dataclasses.field(default_factory=dict) initialised: t.Optional[dict[str, "SQLAlchemyTableRepository"]] = None - def init(self, connection: t.Any) -> None: - self.initialised = {r: cls(connection) for r, cls in self.registered.items()} + def init(self, *args: t.Any, **kwargs: t.Any) -> None: + self.initialised = {r: cls(*args, **kwargs) for r, cls in self.registered.items()} def delete(self) -> None: self.initialised = None @@ -67,8 +67,9 @@ async def begin(self) -> None: Initialize the connection, begin a transaction, and create the repositories. """ - await self.begin_transaction() - self._resources_repositories.init(self.connection) + await self.set_up() + args, kwargs = await self.repository_params() + self._resources_repositories.init(*args, **kwargs) async def end(self, *, rollback: bool = False) -> None: """End a unit of work. @@ -77,5 +78,5 @@ async def end(self, *, rollback: bool = False) -> None: :param rollback: If the unit of work should be rolled back. """ - await self.end_transaction(rollback=rollback) + await self.tear_down(rollback=rollback) self._resources_repositories.delete() diff --git a/tests/ddd/test_components.py b/tests/ddd/test_components.py index afd175ec..fcf60fa2 100644 --- a/tests/ddd/test_components.py +++ b/tests/ddd/test_components.py @@ -20,12 +20,15 @@ def worker(self, repository): class FooWorker(BaseWorker): foo: repository - async def begin(self): + async def set_up(self): ... - async def end(self, *, rollback: bool = False): + async def tear_down(self, *, rollback: bool = False): ... + async def repository_params(self): + return [], {} + async def commit(self): ... diff --git a/tests/ddd/workers/test_base.py b/tests/ddd/workers/test_base.py index dc49ad2f..77419ce8 100644 --- a/tests/ddd/workers/test_base.py +++ b/tests/ddd/workers/test_base.py @@ -64,6 +64,15 @@ def worker(self, repository): class FooWorker(BaseWorker): foo: repository + async def set_up(self): + ... + + async def tear_down(self, *, rollback: bool = False): + ... + + async def repository_params(self): + return [], {} + async def begin(self): ... diff --git a/tests/ddd/workers/test_http.py b/tests/ddd/workers/test_http.py index 0aae1823..6e566063 100644 --- a/tests/ddd/workers/test_http.py +++ b/tests/ddd/workers/test_http.py @@ -19,33 +19,33 @@ def test_init(self, app): assert worker._app == app assert worker._url == "foo" - assert not hasattr(worker, "_client") + assert not hasattr(worker, "client") def test_client(self, worker): with pytest.raises(AttributeError, match="Client not initialized"): worker.client - async def test_begin_transaction(self, worker): - worker._client = MagicMock() + async def test_set_up(self, worker): + with patch("flama.client.Client"): + await worker.set_up() + assert worker.client.__aenter__.await_args_list == [call()] - await worker.begin_transaction() - assert worker._client.__aenter__.await_args_list == [call()] + async def test_tear_down(self, worker): + client_mock = MagicMock() + worker.client = client_mock - async def test_end_transaction(self, worker): - worker._client = MagicMock() - - await worker.end_transaction() - assert worker._client.__aexit__.await_args_list == [call()] + await worker.tear_down() + assert client_mock.__aexit__.await_args_list == [call()] async def test_begin(self, worker): - with patch.object(worker, "begin_transaction"), patch("flama.client.Client"): + worker.client = MagicMock() + + with patch.object(worker, "set_up"): assert not hasattr(worker, "bar") - assert not hasattr(worker, "_client") await worker.begin() - assert hasattr(worker, "_client") - assert worker.begin_transaction.await_args_list == [call()] + assert worker.set_up.await_args_list == [call()] assert hasattr(worker, "bar") assert isinstance(worker.bar, HTTPRepository) @@ -58,14 +58,12 @@ async def test_begin(self, worker): ) async def test_end(self, worker, rollback): worker.bar = MagicMock() - worker._client = MagicMock() + worker.client = MagicMock() - with patch.object(worker, "end_transaction"): + with patch.object(worker, "tear_down"): assert hasattr(worker, "bar") - assert hasattr(worker, "_client") await worker.end(rollback=rollback) - assert worker.end_transaction.await_args_list == [call()] + assert worker.tear_down.await_args_list == [call(rollback=rollback)] assert not hasattr(worker, "bar") - assert not hasattr(worker, "_client") diff --git a/tests/ddd/workers/test_sqlalchemy.py b/tests/ddd/workers/test_sqlalchemy.py index 350a419e..72c9d88d 100644 --- a/tests/ddd/workers/test_sqlalchemy.py +++ b/tests/ddd/workers/test_sqlalchemy.py @@ -29,7 +29,7 @@ def test_transaction(self, worker): with pytest.raises(AttributeError, match="Transaction not started"): worker.transaction - async def test_begin_transaction(self, app, worker): + async def test_set_up(self, app, worker): connection_mock = AsyncMock() transaction_mock = AsyncMock() @@ -38,7 +38,7 @@ async def test_begin_transaction(self, app, worker): open_connection=AsyncMock(return_value=connection_mock), begin_transaction=AsyncMock(return_value=transaction_mock), ): - await worker.begin_transaction() + await worker.set_up() assert worker._connection == connection_mock assert worker._transaction == transaction_mock @@ -52,12 +52,12 @@ async def test_begin_transaction(self, app, worker): pytest.param(False, id="commit"), ), ) - async def test_end_transaction(self, app, worker, rollback): + async def test_tear_down(self, app, worker, rollback): worker._connection = connection_mock = AsyncMock() worker._transaction = transaction_mock = AsyncMock() with patch.multiple(app.sqlalchemy, end_transaction=AsyncMock(), close_connection=AsyncMock()): - await worker.end_transaction(rollback=rollback) + await worker.tear_down(rollback=rollback) assert not hasattr(worker, "_transaction") assert not hasattr(worker, "_connection") @@ -67,12 +67,12 @@ async def test_end_transaction(self, app, worker, rollback): async def test_begin(self, worker): worker._connection = AsyncMock() - with patch.object(worker, "begin_transaction"): + with patch.object(worker, "set_up"): assert not hasattr(worker, "bar") await worker.begin() - assert worker.begin_transaction.await_args_list == [call()] + assert worker.set_up.await_args_list == [call()] assert hasattr(worker, "bar") assert isinstance(worker.bar, SQLAlchemyRepository) @@ -86,12 +86,12 @@ async def test_begin(self, worker): async def test_end(self, worker, rollback): worker.bar = MagicMock() - with patch.object(worker, "end_transaction"): + with patch.object(worker, "tear_down"): assert hasattr(worker, "bar") await worker.end(rollback=rollback) - assert worker.end_transaction.await_args_list == [call(rollback=rollback)] + assert worker.tear_down.await_args_list == [call(rollback=rollback)] assert not hasattr(worker, "bar") async def test_commit(self, worker): diff --git a/tests/ddd/workers/test_noop.py b/tests/ddd/workers/test_worker.py similarity index 59% rename from tests/ddd/workers/test_noop.py rename to tests/ddd/workers/test_worker.py index 9614c650..d38e4ef5 100644 --- a/tests/ddd/workers/test_noop.py +++ b/tests/ddd/workers/test_worker.py @@ -1,17 +1,17 @@ import pytest -from flama.ddd.workers import NoopWorker +from flama.ddd.workers import Worker -class TestCaseNoopWorker: +class TestCaseWorker: @pytest.fixture(scope="function") def worker(self, client): - class FooWorker(NoopWorker): + class FooWorker(Worker): ... return FooWorker(client.app) def test_init(self, app): - worker = NoopWorker(app) + worker = Worker(app) assert worker._app == app