Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c964f45
Add FunDI > FastAPI compatibility layer
kuyugama Aug 2, 2025
4af07ec
Add extra dependencies support
kuyugama Aug 2, 2025
ba9cb3e
Remove unused introspection
kuyugama Aug 2, 2025
956e845
Skip on removed logic
kuyugama Aug 2, 2025
e2dff3a
divide fundi.compat.fastapi into several submodules to ensure maintai…
kuyugama Aug 3, 2025
cfe0a98
Make AsyncExitStack closing after response is sent
kuyugama Aug 4, 2025
f496d40
Separate body and response exit stacks to ensure body will be closed …
kuyugama Aug 4, 2025
a97fafc
Add router
kuyugama Aug 4, 2025
2bb21db
Include fastapi dependency group in dev group
kuyugama Aug 4, 2025
a4e9b23
Add fastapi.security support via ``secured(dependency, scopes)``
kuyugama Aug 5, 2025
ea2393d
Generate dependency scope using it's own dependant, not flat dependan…
kuyugama Aug 12, 2025
2ad9bda
Store request related aliases inside callable info metadata
kuyugama Aug 12, 2025
e533d1d
Add injection tests
kuyugama Aug 12, 2025
fc6ad00
Merge main into fundi>fastapi compatibility layer
kuyugama Aug 15, 2025
84dc95e
Fix virtual context manager nesting
kuyugama Aug 17, 2025
05c8b77
Rename Virtual[Async]ContextManager to Virtual[Async]ContextProvider
kuyugama Aug 24, 2025
25f60bc
Merge pull request #43 from KuyuCode/fix/virtual-context
kuyugama Aug 24, 2025
446d2f5
Improve scan for types
kuyugama Sep 3, 2025
a08d407
Merge pull request #44 from KuyuCode/feat/improve-types-scan
kuyugama Sep 3, 2025
a364aea
Move security_scopes and request related aliases into FastAPIs Dependant
kuyugama Aug 17, 2025
810574a
Merge main into feat/compat/fastapi
kuyugama Sep 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions fundi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .inject import inject, ainject
from .configurable import configurable_dependency, MutableConfigurationWarning
from .util import injection_trace, is_configured, get_configuration, normalize_annotation
from .virtual_context import virtual_context, VirtualContextManager, AsyncVirtualContextManager
from .virtual_context import virtual_context, VirtualContextProvider, AsyncVirtualContextProvider
from .types import CallableInfo, TypeResolver, InjectionTrace, R, Parameter, DependencyConfiguration


Expand All @@ -33,9 +33,9 @@
"injection_trace",
"get_configuration",
"normalize_annotation",
"VirtualContextManager",
"VirtualContextProvider",
"DependencyConfiguration",
"configurable_dependency",
"AsyncVirtualContextManager",
"AsyncVirtualContextProvider",
"MutableConfigurationWarning",
]
13 changes: 13 additions & 0 deletions fundi/compat/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .secured import secured
from .route import FunDIRoute
from .router import FunDIRouter
from .handler import get_request_handler
from .dependant import get_scope_dependant

__all__ = [
"secured",
"FunDIRoute",
"FunDIRouter",
"get_request_handler",
"get_scope_dependant",
]
5 changes: 5 additions & 0 deletions fundi/compat/fastapi/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__all__ = ["METADATA_SECURITY_SCOPES", "METADATA_DEPENDANT", "METADATA_SCOPE_EXTRA"]

METADATA_SECURITY_SCOPES = "fastapi_security_scopes"
METADATA_DEPENDANT = "fastapi_dependant"
METADATA_SCOPE_EXTRA = "scope_extra"
106 changes: 106 additions & 0 deletions fundi/compat/fastapi/dependant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import typing

from fastapi import params
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
from fastapi.dependencies.models import Dependant, SecurityRequirement

from fundi.util import callable_str
from fundi.types import CallableInfo

from .metadata import build_metadata, get_metadata
from .constants import METADATA_DEPENDANT, METADATA_SECURITY_SCOPES

from fastapi.dependencies.utils import (
analyze_param,
add_param_to_fields,
add_non_field_param_to_dependency,
)

MF = typing.TypeVar("MF", bound=ModelField)


def merge(into: list[MF], from_: list[MF]):
names = {field.name for field in into}

for field in from_:
if field.name not in names:
into.append(field)


def update_dependant(source: Dependant, target: Dependant):
merge(target.path_params, source.path_params)
merge(target.query_params, source.query_params)
merge(target.header_params, source.header_params)
merge(target.cookie_params, source.cookie_params)
merge(target.body_params, source.body_params)

target.security_requirements.extend(source.security_requirements)
target.dependencies.extend(source.dependencies)
if source.security_scopes:
if target.security_scopes is None:
target.security_scopes = []

target.security_scopes[::] = set().union(target.security_scopes, source.security_scopes)


def get_scope_dependant(
ci: CallableInfo[typing.Any],
path_param_names: set[str],
path: str,
) -> Dependant:
build_metadata(ci)

dependant = Dependant(path=path)
dependant_metadata = get_metadata(ci)
dependant.security_scopes = dependant_metadata[METADATA_SECURITY_SCOPES]

dependant_metadata.update({METADATA_DEPENDANT: dependant})

flat_dependant = Dependant(
path=path, security_scopes=dependant_metadata[METADATA_SECURITY_SCOPES]
)

for param in ci.parameters:
if param.from_ is not None:
subci = param.from_

sub = get_scope_dependant(subci, path_param_names, path)
update_dependant(sub, flat_dependant)

# This is required to pass security_scopes to dependency.
# Here parameter name and security scopes itself are set.
metadata = get_metadata(subci)

if isinstance(subci.call, SecurityBase):
flat_dependant.security_requirements.append(
SecurityRequirement(subci.call, metadata[METADATA_SECURITY_SCOPES])
)

continue

details = analyze_param(
param_name=param.name,
annotation=param.annotation,
value=param.default,
is_path_param=param.name in path_param_names,
)

if add_non_field_param_to_dependency(
param_name=param.name, type_annotation=param.annotation, dependant=dependant
):
assert (
details.field is None
), f'Non-field parameter shouldn\'t have field: error caused by analysis of the parameter "{param.name}" in {callable_str(ci.call)}'

continue

assert details.field is not None
if isinstance(details.field.field_info, params.Body):
dependant.body_params.append(details.field)
else:
add_param_to_fields(field=details.field, dependant=dependant)

update_dependant(dependant, flat_dependant)

return flat_dependant
112 changes: 112 additions & 0 deletions fundi/compat/fastapi/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import typing
from collections.abc import Coroutine
from contextlib import AsyncExitStack

from fastapi.types import IncEx
from starlette.requests import Request
from fastapi._compat import ModelField
from fastapi.routing import serialize_response
from starlette.background import BackgroundTasks
from starlette.responses import JSONResponse, Response
from fastapi.utils import is_body_allowed_for_status_code
from fastapi.datastructures import Default, DefaultPlaceholder

from .inject import inject
from fundi.types import CallableInfo

from .validate_request_body import validate_body


def get_request_handler(
ci: CallableInfo[typing.Any],
extra_dependencies: list[CallableInfo[typing.Any]],
body_field: ModelField | None = None,
status_code: int | None = None,
response_class: type[Response] | DefaultPlaceholder = Default(JSONResponse),
response_field: ModelField | None = None,
response_model_include: IncEx | None = None,
response_model_exclude: IncEx | None = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: typing.Any | None = None,
embed_body_fields: bool = False,
) -> typing.Callable[[Request], Coroutine[typing.Any, typing.Any, Response]]:

if isinstance(response_class, DefaultPlaceholder):
actual_response_class: type[Response] = response_class.value
else:
actual_response_class = response_class

async def app(request: Request) -> Response:
background_tasks = BackgroundTasks()
stack = AsyncExitStack()
# Close exit stack at after the response is sent
background_tasks.add_task(stack.aclose)

response = Response()
del response.headers["content-length"]
response.status_code = None # pyright: ignore[reportAttributeAccessIssue]

body_stack = AsyncExitStack()
async with body_stack:
body = await validate_body(request, body_stack, body_field)

for dependency in extra_dependencies:
await inject(
dependency,
stack,
request,
body,
dependency_overrides_provider,
embed_body_fields,
background_tasks,
response,
)

raw_response = await inject(
ci,
stack,
request,
body,
dependency_overrides_provider,
embed_body_fields,
background_tasks,
response,
)

if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks

return raw_response

response_args: dict[str, typing.Any] = {"background": background_tasks}

# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
status = response.status_code or status_code
if status is not None:
response_args["status_code"] = status

content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=ci.async_,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""

response.headers.raw.extend(response.headers.raw)

return response

return app
101 changes: 101 additions & 0 deletions fundi/compat/fastapi/inject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import typing
import contextlib
import collections.abc

from starlette.requests import Request
from starlette.responses import Response
from starlette.datastructures import FormData
from fastapi._compat import _normalize_errors # pyright: ignore[reportPrivateUsage]
from starlette.background import BackgroundTasks
from fastapi.exceptions import RequestValidationError
from fastapi.dependencies.utils import solve_dependencies

from fundi.inject import injection_impl
from fundi.util import call_async, call_sync
from fundi.types import CacheKey, CallableInfo

from .metadata import get_metadata
from .types import DependencyOverridesProvider
from .constants import METADATA_DEPENDANT, METADATA_SCOPE_EXTRA, METADATA_SECURITY_SCOPES


async def inject(
info: CallableInfo[typing.Any],
stack: contextlib.AsyncExitStack,
request: Request,
body: FormData | typing.Any | None,
dependency_overrides_provider: DependencyOverridesProvider | None,
embed_body_fields: bool,
background_tasks: BackgroundTasks,
response: Response,
cache: collections.abc.MutableMapping[CacheKey, typing.Any] | None = None,
override: collections.abc.Mapping[typing.Callable[..., typing.Any], typing.Any] | None = None,
) -> typing.Any:
"""
Asynchronously inject dependencies into callable.

:param scope: container with contextual values
:param info: callable information
:param stack: exit stack to properly handle generator dependencies
:param cache: dependency cache
:param override: override dependencies
:return: result of callable
"""
if cache is None:
cache = {}

metadata = get_metadata(info)

fastapi_params = await solve_dependencies(
request=request,
dependant=metadata[METADATA_DEPENDANT],
body=body,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=stack,
embed_body_fields=embed_body_fields,
background_tasks=background_tasks,
response=response,
)

if fastapi_params.errors:
raise RequestValidationError(_normalize_errors(fastapi_params.errors), body=body)

scope = fastapi_params.values

scope_extra: collections.abc.Mapping[str, typing.Any] = metadata.get(METADATA_SCOPE_EXTRA, {})

if scope_extra:
scope = {**scope, **scope_extra}

gen = injection_impl(scope, info, cache, override)

value: typing.Any | None = None

try:
while True:
inner_scope, inner_info, more = gen.send(value)

if more:
value = await inject(
inner_info,
stack,
request,
body,
dependency_overrides_provider,
embed_body_fields,
background_tasks,
response,
cache,
override,
)
continue

if info.async_:
return await call_async(stack, inner_info, inner_scope)

return call_sync(stack, inner_info, inner_scope)
except Exception as exc:
with contextlib.suppress(StopIteration):
gen.throw(type(exc), exc, exc.__traceback__)

raise
Loading