Skip to content
112 changes: 40 additions & 72 deletions fastapi_code_generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pathlib
import re
from functools import reduce
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -51,7 +52,7 @@
class CachedPropertyModel(BaseModel):
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
ignored_types = (cached_property,)


class Response(BaseModel):
Expand Down Expand Up @@ -117,8 +118,10 @@ class Operation(CachedPropertyModel):
imports: List[Import] = []
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[str]] = []
arguments: str = ''
snake_case_arguments: str = ''
arguments: List[Argument] = []
plain_parameters: str = ''
plain_arguments: str = ''
snake_case_arguments: List[Argument] = []
request: Optional[Argument] = None
response: str = ''
additional_responses: Dict[Union[str, int], Dict[str, str]] = {}
Expand Down Expand Up @@ -317,11 +320,6 @@ def get_parameter_type(
required=field.required,
)

def get_arguments(self, snake_case: bool, path: List[str]) -> str:
return ", ".join(
argument.argument for argument in self.get_argument_list(snake_case, path)
)

def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]:
arguments: List[Argument] = []

Expand All @@ -334,9 +332,15 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
if parameter_type:
arguments.append(parameter_type)

request = self._temporary_operation.get('_request')
if request:
arguments.append(request)
arguments.append(
Argument(
name='request', # type: ignore
type_hint='Request', # type: ignore
required=True,
)
)

self.imports_for_fastapi.append(Import.from_full_path("fastapi.Request"))

positional_argument: bool = False
for argument in arguments:
Expand All @@ -358,66 +362,21 @@ def parse_request_body(
) -> None:
super().parse_request_body(name, request_body, path)
arguments: List[Argument] = []
for (
media_type,
media_obj,
) in request_body.content.items(): # type: str, MediaObject
if isinstance(
media_obj.schema_, (JsonSchemaObject, ReferenceObject)
): # pragma: no cover
# TODO: support other content-types
if RE_APPLICATION_JSON_PATTERN.match(media_type):
if isinstance(media_obj.schema_, ReferenceObject):
data_type = self.get_ref_data_type(media_obj.schema_.ref)
else:
data_type = self.parse_schema(
name, media_obj.schema_, [*path, media_type]
)
data_type = self._collapse_root_model(data_type)
arguments.append(
# TODO: support multiple body
Argument(
name='body', # type: ignore
type_hint=UsefulStr(data_type.type_hint),
required=request_body.required,
)
)
self.data_types.append(data_type)
elif media_type == 'application/x-www-form-urlencoded':
arguments.append(
# TODO: support form with `Form()`
Argument(
name='request', # type: ignore
type_hint='Request', # type: ignore
required=True,
)
)
self.imports_for_fastapi.append(
Import.from_full_path('starlette.requests.Request')
for media_obj in request_body.content.values():
if isinstance(media_obj.schema_, JsonSchemaObject) and (
media_obj.schema_.format == 'binary'
):
arguments.append(
Argument(
name='file', # type: ignore
type_hint='UploadFile', # type: ignore
required=True,
)
elif media_type == 'application/octet-stream':
arguments.append(
Argument(
name='request', # type: ignore
type_hint='Request', # type: ignore
required=True,
)
)
self.imports_for_fastapi.append(
Import.from_full_path("fastapi.Request")
)
elif media_type == 'multipart/form-data':
arguments.append(
Argument(
name='file', # type: ignore
type_hint='UploadFile', # type: ignore
required=True,
)
)
self.imports_for_fastapi.append(
Import.from_full_path("fastapi.UploadFile")
)
self._temporary_operation['_request'] = arguments[0] if arguments else None
)
self.imports_for_fastapi.append(
Import.from_full_path("fastapi.UploadFile")
)
self._temporary_operation['_request'] = arguments

def parse_responses( # type: ignore[override]
self,
Expand Down Expand Up @@ -467,12 +426,21 @@ def parse_operation(
resolved_path = self.model_resolver.resolve_ref(path)
path_name, method = path[-2:]

self._temporary_operation['arguments'] = self.get_arguments(
self._temporary_operation['arguments'] = self.get_argument_list(
snake_case=False, path=path
)
self._temporary_operation['snake_case_arguments'] = self.get_arguments(
self._temporary_operation['snake_case_arguments'] = self.get_argument_list(
snake_case=True, path=path
)
self._temporary_operation['plain_arguments'] = ",".join(
map(lambda a: a.name, self._temporary_operation['snake_case_arguments'])
)
self._temporary_operation['plain_parameters'] = ",".join(
map(
lambda a: f'{a.name}{": " + a.type_hint if a.type_hint is not None else ""}',
self._temporary_operation['snake_case_arguments'],
)
)
main_operation = self._temporary_operation

# Handle callbacks. This iterates over callbacks, shifting each one
Expand Down