Skip to content

Commit

Permalink
✨ Allow importing ddd without sqlalchemy installed (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Nov 24, 2024
1 parent 43eefb9 commit 259a873
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 39 deletions.
9 changes: 8 additions & 1 deletion flama/config/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import yaml

from flama import exceptions

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover
try:
import tomli
Expand Down Expand Up @@ -88,7 +90,12 @@ def load(self, f: t.Union[str, os.PathLike]) -> dict[str, t.Any]:
:param f: File path.
:return: Dict with the file contents.
"""
assert tomllib is not None, "`tomli` must be installed to use TOMLFileLoader in Python versions older than 3.11"
if tomllib is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.tomli,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
msg="for Python versions lower than 3.11",
)

with open(f, "rb") as fs:
return tomllib.load(fs)
90 changes: 77 additions & 13 deletions flama/ddd/repositories/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import typing as t

from flama.ddd import exceptions
from flama import exceptions
from flama.ddd import exceptions as ddd_exceptions
from flama.ddd.repositories import AbstractRepository

try:
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.exc as sqlalchemy_exceptions
except Exception: # pragma: no cover
raise AssertionError("`sqlalchemy[asyncio]` must be installed to use ddd") from None
sqlalchemy = None
sqlalchemy_exceptions = None


if t.TYPE_CHECKING:
Expand All @@ -16,11 +18,19 @@
except Exception: # pragma: no cover
...

__all__ = ["SQLAlchemyRepository", "SQLAlchemyTableManager", "SQLAlchemyTableRepository"]


class SQLAlchemyRepository(AbstractRepository):
"""Base class for SQLAlchemy repositories. It provides a connection to the database."""

def __init__(self, connection: "AsyncConnection", *args, **kwargs):
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

super().__init__(*args, **kwargs)
self._connection = connection

Expand All @@ -29,7 +39,13 @@ def __eq__(self, other):


class SQLAlchemyTableManager:
def __init__(self, table: sqlalchemy.Table, connection: "AsyncConnection"):
def __init__(self, table: sqlalchemy.Table, connection: "AsyncConnection"): # type: ignore
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

self._connection = connection
self.table = table

Expand All @@ -50,10 +66,16 @@ async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]:
:return: The created elements.
:raises IntegrityError: If the element already exists or cannot be inserted.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

try:
result = await self._connection.execute(sqlalchemy.insert(self.table).values(data).returning(self.table))
except sqlalchemy.exc.IntegrityError as e:
raise exceptions.IntegrityError(str(e))
except sqlalchemy_exceptions.IntegrityError as e:
raise ddd_exceptions.IntegrityError(str(e))
return [dict[str, t.Any](element._asdict()) for element in result]

async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]:
Expand All @@ -72,14 +94,20 @@ async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]:
:raises NotFoundError: If the element does not exist.
:raises MultipleRecordsError: If more than one element is found.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters)

try:
element = (await self._connection.execute(query)).one()
except sqlalchemy.exc.NoResultFound:
raise exceptions.NotFoundError()
except sqlalchemy.exc.MultipleResultsFound:
raise exceptions.MultipleRecordsError()
except sqlalchemy_exceptions.NoResultFound:
raise ddd_exceptions.NotFoundError()
except sqlalchemy_exceptions.MultipleResultsFound:
raise ddd_exceptions.MultipleRecordsError()

return dict[str, t.Any](element._asdict())

Expand All @@ -93,14 +121,20 @@ async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict
:return: The updated elements.
:raises IntegrityError: If the elements cannot be updated.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = (
self._filter_query(sqlalchemy.update(self.table), *clauses, **filters).values(**data).returning(self.table)
)

try:
result = await self._connection.execute(query)
except sqlalchemy.exc.IntegrityError:
raise exceptions.IntegrityError
except sqlalchemy_exceptions.IntegrityError:
raise ddd_exceptions.IntegrityError

return [dict[str, t.Any](element._asdict()) for element in result]

Expand All @@ -118,6 +152,12 @@ async def delete(self, *clauses, **filters) -> None:
:raises NotFoundError: If the element does not exist.
:raises MultipleRecordsError: If more than one element is found.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

await self.retrieve(*clauses, **filters)

query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters)
Expand Down Expand Up @@ -145,6 +185,12 @@ async def list(
:param filters: Filters to filter the elements.
:return: Async iterable of the elements.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters)

if order_by:
Expand Down Expand Up @@ -173,6 +219,12 @@ async def drop(self, *clauses, **filters) -> int:
:param filters: Filters to filter the elements.
:return: The number of elements dropped.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters)

result = await self._connection.execute(query)
Expand All @@ -196,6 +248,12 @@ def _filter_query(self, query, *clauses, **filters):
:param filters: Filters to filter the elements.
:return: The filtered query.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items())

if where_clauses:
Expand All @@ -205,9 +263,15 @@ def _filter_query(self, query, *clauses, **filters):


class SQLAlchemyTableRepository(SQLAlchemyRepository):
_table: t.ClassVar[sqlalchemy.Table]
_table: t.ClassVar[sqlalchemy.Table] # type: ignore

def __init__(self, connection: "AsyncConnection", *args, **kwargs):
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

super().__init__(connection, *args, **kwargs)
self._table_manager = SQLAlchemyTableManager(self._table, connection)

Expand Down
5 changes: 1 addition & 4 deletions flama/ddd/workers/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from flama.ddd.workers.base import AbstractWorker

if t.TYPE_CHECKING:
try:
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncTransaction
except Exception: # pragma: no cover
...
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncTransaction

__all__ = ["SQLAlchemyWorker"]

Expand Down
52 changes: 52 additions & 0 deletions flama/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import enum
import http
import sys
import typing as t

import starlette.exceptions

__all__ = [
"ApplicationError",
"DependencyNotInstalled",
"SQLAlchemyError",
"DecodeError",
"HTTPException",
Expand All @@ -18,11 +21,60 @@
"FrameworkVersionWarning",
]

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover

class StrEnum(str, enum.Enum):
@staticmethod
def _generate_next_value_(name, start, count, last_values):
return name.lower()

enum.StrEnum = StrEnum # type: ignore


class ApplicationError(Exception):
...


class DependencyNotInstalled(ApplicationError):
class Dependency(enum.StrEnum): # type: ignore # PORT: Remove this comment when stop supporting 3.10
pydantic = "pydantic"
marshmallow = "marshmallow"
apispec = "apispec"
typesystem = "typesystem"
sqlalchemy = "sqlalchemy[asyncio]"
httpx = "httpx"
tomli = "tomli"

def __init__(
self,
*,
dependency: t.Optional[t.Union[str, Dependency]] = None,
dependant: t.Optional[str] = None,
msg: str = "",
) -> None:
super().__init__()
self.dependency = self.Dependency(dependency) if dependency else None
self.dependant = dependant
self.msg = msg

def __str__(self) -> str:
if self.dependency:
s = f"Dependency '{self.dependency.value}' must be installed"
if self.dependant:
s += f" to use '{self.dependant}'"
if self.msg:
s += f" ({self.msg})"
else:
s = self.msg

return s

def __repr__(self) -> str:
params = ("msg", "dependency", "dependant")
formatted_params = ", ".join([f"{x}={getattr(self, x)}" for x in params if getattr(self, x)])
return f"{self.__class__.__name__}({formatted_params})"


class SQLAlchemyError(ApplicationError):
...

Expand Down
34 changes: 23 additions & 11 deletions flama/resources/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import typing as t
import uuid

from flama import exceptions
from flama.ddd.repositories import SQLAlchemyTableRepository
from flama.resources import data_structures
from flama.resources.exceptions import ResourceAttributeError
Expand All @@ -11,18 +12,11 @@
import sqlalchemy
from sqlalchemy.dialects import postgresql
except Exception: # pragma: no cover
raise AssertionError("`sqlalchemy[asyncio]` must be installed to use rest resources") from None
sqlalchemy = None
postgresql = None

__all__ = ["RESTResource", "RESTResourceType"]

PK_MAPPING: dict[t.Any, t.Any] = {
sqlalchemy.Integer: int,
sqlalchemy.String: str,
sqlalchemy.Date: datetime.date,
sqlalchemy.DateTime: datetime.datetime,
postgresql.UUID: uuid.UUID,
}


class RESTResourceType(ResourceType):
def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, t.Any]):
Expand All @@ -34,6 +28,11 @@ def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, t.Any]):
:param bases: List of superclasses.
:param namespace: Variables namespace used to create the class.
"""
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy, dependant="RESTResourceType"
)

if not mcs._is_abstract(namespace):
try:
# Get model
Expand Down Expand Up @@ -69,6 +68,12 @@ def _get_model(cls, bases: t.Sequence[t.Any], namespace: dict[str, t.Any]) -> da
:param namespace: Variables namespace used to create the class.
:return: Resource model.
"""
if sqlalchemy is None or postgresql is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{cls.__module__}.{cls.__name__}",
)

model = cls._get_attribute("model", bases, namespace, metadata_namespace="rest")

# Already defined model probably because resource inheritance, so no need to create it
Expand All @@ -89,7 +94,14 @@ def _get_model(cls, bases: t.Sequence[t.Any], namespace: dict[str, t.Any]) -> da

# Check primary key is a valid type
try:
model_pk_type = PK_MAPPING[model_pk.type.__class__]
model_pk_mapping: dict[type, type] = {
sqlalchemy.Integer: int,
sqlalchemy.String: str,
sqlalchemy.Date: datetime.date,
sqlalchemy.DateTime: datetime.datetime,
postgresql.UUID: uuid.UUID,
}
model_pk_type = model_pk_mapping[model_pk.type.__class__]
except KeyError:
raise AttributeError(ResourceAttributeError.PK_WRONG_TYPE)

Expand Down Expand Up @@ -142,7 +154,7 @@ def _get_schemas(cls, name: str, bases: t.Sequence[t.Any], namespace: dict[str,


class RESTResource(Resource, metaclass=RESTResourceType):
model: sqlalchemy.Table
model: sqlalchemy.Table # type: ignore
schema: t.Any
input_schema: t.Any
output_schema: t.Any
Loading

0 comments on commit 259a873

Please sign in to comment.