Skip to content

Commit 88c0b38

Browse files
committed
fix: Run ruff formatting and resolve issues
Run 'ruff check --fix' followed by 'ruff format' and resolve remaining errors manually. For the most part, the manual changes involved one of the 3 following actions: * Split long strings into multiple lines wrapped with parentheses * Refactor code, e.g. retrieve dict values and evaluate conditionals before using the results in f-strings * Move long strings to files (this can be done in many more places around the codebase) Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent e74b6ec commit 88c0b38

File tree

85 files changed

+5726
-2668
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+5726
-2668
lines changed

app/api/api.py

+47-22
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
import logging
21
import asyncio
32
import importlib
3+
import logging
44
import os.path
5-
import api.globals as cms_globals
6-
7-
from typing import Dict, Any, Optional
85
from concurrent.futures import ThreadPoolExecutor
9-
from anyio.lowlevel import RunVar
6+
from typing import Any, Dict, Optional
7+
108
from anyio import CapacityLimiter
9+
from anyio.lowlevel import RunVar
1110
from fastapi import FastAPI, Request
11+
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
1212
from fastapi.openapi.utils import get_openapi
13-
from fastapi.responses import RedirectResponse, HTMLResponse
13+
from fastapi.responses import HTMLResponse, RedirectResponse
1414
from fastapi.staticfiles import StaticFiles
15-
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
1615
from prometheus_fastapi_instrumentator import Instrumentator
1716

17+
from domain import Tags, TagsStreamable
18+
from utils import get_settings
19+
20+
import api.globals as cms_globals
1821
from api.auth.db import make_sure_db_and_tables
1922
from api.auth.users import Props
2023
from api.dependencies import ModelServiceDep
2124
from api.utils import add_exception_handlers, add_rate_limiter
22-
from domain import Tags, TagsStreamable
2325
from management.tracker_client import TrackerClient
24-
from utils import get_settings
25-
2626

2727
logging.getLogger("asyncio").setLevel(logging.ERROR)
2828
logger = logging.getLogger("cms")
@@ -87,25 +87,37 @@ def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> Fast
8787
return app
8888

8989

90-
def _get_app(msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False) -> FastAPI:
91-
tags_metadata = [{"name": tag.name, "description": tag.value} for tag in (Tags if not streamable else TagsStreamable)]
90+
def _get_app(
91+
msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False
92+
) -> FastAPI:
93+
tags_metadata = [
94+
{"name": tag.name, "description": tag.value}
95+
for tag in (Tags if not streamable else TagsStreamable)
96+
]
9297
config = get_settings()
93-
app = FastAPI(title="CogStack ModelServe",
94-
summary="A model serving and governance system for CogStack NLP solutions",
95-
docs_url=None,
96-
redoc_url=None,
97-
debug=(config.DEBUG == "true"),
98-
openapi_tags=tags_metadata)
98+
app = FastAPI(
99+
title="CogStack ModelServe",
100+
summary="A model serving and governance system for CogStack NLP solutions",
101+
docs_url=None,
102+
redoc_url=None,
103+
debug=(config.DEBUG == "true"),
104+
openapi_tags=tags_metadata,
105+
)
99106
add_exception_handlers(app)
100107
instrumentator = Instrumentator(
101-
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app)
108+
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]
109+
).instrument(app)
102110

103111
if msd_overwritten is not None:
104112
cms_globals.model_service_dep = msd_overwritten
105113

106114
cms_globals.props = Props(config.AUTH_USER_ENABLED == "true")
107115

108-
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
116+
app.mount(
117+
"/static",
118+
StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")),
119+
name="static",
120+
)
109121

110122
@app.on_event("startup")
111123
async def on_startup() -> None:
@@ -160,8 +172,11 @@ def custom_openapi() -> Dict[str, Any]:
160172
openapi_schema = get_openapi(
161173
title=f"{cms_globals.model_service_dep().model_name} APIs",
162174
version=cms_globals.model_service_dep().api_version,
163-
description="by CogStack ModelServe, a model serving and governance system for CogStack NLP solutions.",
164-
routes=app.routes
175+
description=(
176+
"by CogStack ModelServe, a model serving and governance system for CogStack NLP"
177+
" solutions."
178+
),
179+
routes=app.routes,
165180
)
166181
openapi_schema["info"]["x-logo"] = {
167182
"url": "https://avatars.githubusercontent.com/u/28688163?s=200&v=4"
@@ -189,69 +204,79 @@ def custom_openapi() -> Dict[str, Any]:
189204

190205
def _load_auth_router(app: FastAPI) -> FastAPI:
191206
from api.routers import authentication
207+
192208
importlib.reload(authentication)
193209
app.include_router(authentication.router)
194210
return app
195211

196212

197213
def _load_model_card(app: FastAPI) -> FastAPI:
198214
from api.routers import model_card
215+
199216
importlib.reload(model_card)
200217
app.include_router(model_card.router)
201218
return app
202219

203220

204221
def _load_invocation_router(app: FastAPI) -> FastAPI:
205222
from api.routers import invocation
223+
206224
importlib.reload(invocation)
207225
app.include_router(invocation.router)
208226
return app
209227

210228

211229
def _load_supervised_training_router(app: FastAPI) -> FastAPI:
212230
from api.routers import supervised_training
231+
213232
importlib.reload(supervised_training)
214233
app.include_router(supervised_training.router)
215234
return app
216235

217236

218237
def _load_evaluation_router(app: FastAPI) -> FastAPI:
219238
from api.routers import evaluation
239+
220240
importlib.reload(evaluation)
221241
app.include_router(evaluation.router)
222242
return app
223243

224244

225245
def _load_preview_router(app: FastAPI) -> FastAPI:
226246
from api.routers import preview
247+
227248
importlib.reload(preview)
228249
app.include_router(preview.router)
229250
return app
230251

231252

232253
def _load_unsupervised_training_router(app: FastAPI) -> FastAPI:
233254
from api.routers import unsupervised_training
255+
234256
importlib.reload(unsupervised_training)
235257
app.include_router(unsupervised_training.router)
236258
return app
237259

238260

239261
def _load_metacat_training_router(app: FastAPI) -> FastAPI:
240262
from api.routers import metacat_training
263+
241264
importlib.reload(metacat_training)
242265
app.include_router(metacat_training.router)
243266
return app
244267

245268

246269
def _load_health_check_router(app: FastAPI) -> FastAPI:
247270
from api.routers import health_check
271+
248272
importlib.reload(health_check)
249273
app.include_router(health_check.router)
250274
return app
251275

252276

253277
def _load_stream_router(app: FastAPI) -> FastAPI:
254278
from api.routers import stream
279+
255280
importlib.reload(stream)
256281
app.include_router(stream.router, prefix="/stream")
257282
return app

app/api/auth/backends.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
from functools import lru_cache
22
from typing import List
3-
from fastapi_users.authentication.transport.base import Transport
3+
4+
from fastapi_users.authentication import (
5+
AuthenticationBackend,
6+
BearerTransport,
7+
CookieTransport,
8+
JWTStrategy,
9+
)
410
from fastapi_users.authentication.strategy.base import Strategy
5-
from fastapi_users.authentication import BearerTransport, JWTStrategy
6-
from fastapi_users.authentication import AuthenticationBackend, CookieTransport
11+
from fastapi_users.authentication.transport.base import Transport
12+
713
from utils import get_settings
814

915

1016
@lru_cache
1117
def get_backends() -> List[AuthenticationBackend]:
1218
return [
13-
AuthenticationBackend(name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy),
14-
AuthenticationBackend(name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy),
19+
AuthenticationBackend(
20+
name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy
21+
),
22+
AuthenticationBackend(
23+
name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy
24+
),
1525
]
1626

1727

@@ -24,4 +34,7 @@ def _get_cookie_transport() -> Transport:
2434

2535

2636
def _get_strategy() -> Strategy:
27-
return JWTStrategy(secret=get_settings().AUTH_JWT_SECRET, lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS)
37+
return JWTStrategy(
38+
secret=get_settings().AUTH_JWT_SECRET,
39+
lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS,
40+
)

app/api/auth/db.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
55
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
66
from sqlalchemy.orm import DeclarativeBase
7+
78
from utils import get_settings
89

910

@@ -29,5 +30,7 @@ async def make_sure_db_and_tables() -> None:
2930
await conn.run_sync(Base.metadata.create_all)
3031

3132

32-
async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
33+
async def get_user_db(
34+
session: AsyncSession = Depends(_get_async_session),
35+
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
3336
yield SQLAlchemyUserDatabase(session, User)

app/api/auth/users.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
import uuid
21
import logging
3-
from typing import Optional, AsyncGenerator, List, Callable
2+
import uuid
3+
from typing import AsyncGenerator, Callable, List, Optional
4+
45
from fastapi import Depends, Request
56
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
6-
from fastapi_users.db import SQLAlchemyUserDatabase
77
from fastapi_users.authentication import AuthenticationBackend
8-
from api.auth.db import User, get_user_db
9-
from api.auth.backends import get_backends
8+
from fastapi_users.db import SQLAlchemyUserDatabase
9+
1010
from utils import get_settings
1111

12+
from api.auth.backends import get_backends
13+
from api.auth.db import User, get_user_db
14+
1215
logger = logging.getLogger("cms")
1316

1417

@@ -19,26 +22,33 @@ class CmsUserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
1922
async def on_after_register(self, user: User, request: Optional[Request] = None) -> None:
2023
logger.info("User %s has registered.", user.id)
2124

22-
async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None) -> None:
25+
async def on_after_forgot_password(
26+
self, user: User, token: str, request: Optional[Request] = None
27+
) -> None:
2328
logger.info("User %s has forgot their password. Reset token: %s", user.id, token)
2429

25-
async def on_after_request_verify(self, user: User, token: str, request: Optional[Request] = None) -> None:
30+
async def on_after_request_verify(
31+
self, user: User, token: str, request: Optional[Request] = None
32+
) -> None:
2633
logger.info("Verification requested for user %s. Verification token: %s", user.id, token)
2734

2835

29-
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)) -> AsyncGenerator:
36+
async def get_user_manager(
37+
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
38+
) -> AsyncGenerator:
3039
yield CmsUserManager(user_db)
3140

3241

3342
class Props(object):
34-
3543
def __init__(self, auth_user_enabled: bool) -> None:
3644
self._auth_backends: List = []
3745
self._fastapi_users = None
3846
self._current_active_user = lambda: None
3947
if auth_user_enabled:
4048
self._auth_backends = get_backends()
41-
self._fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, self.auth_backends)
49+
self._fastapi_users = FastAPIUsers[User, uuid.UUID](
50+
get_user_manager, self.auth_backends
51+
)
4252
self._current_active_user = self._fastapi_users.current_user(active=True)
4353

4454
@property

app/api/dependencies.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
import logging
22
import re
3-
from typing import Union
4-
from typing_extensions import Annotated
3+
from typing import Optional, Union
54

65
from fastapi import HTTPException, Query
76
from starlette.status import HTTP_400_BAD_REQUEST
7+
from typing_extensions import Annotated
88

9-
from typing import Optional
109
from config import Settings
1110
from registry import model_service_registry
12-
from model_services.base import AbstractModelService
11+
1312
from management.model_manager import ModelManager
13+
from model_services.base import AbstractModelService
1414

1515
TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")
1616

1717
logger = logging.getLogger("cms")
1818

1919

2020
class ModelServiceDep(object):
21-
2221
@property
2322
def model_service(self) -> AbstractModelService:
2423
return self._model_sevice
@@ -41,12 +40,11 @@ def __call__(self) -> AbstractModelService:
4140
self._model_sevice = model_service_registry[self._model_type](self._config)
4241
else:
4342
logger.error("Unknown model type: %s", self._model_type)
44-
exit(1) # throw an exception?
43+
exit(1) # throw an exception?
4544
return self._model_sevice
4645

4746

4847
class ModelManagerDep(object):
49-
5048
def __init__(self, model_service: AbstractModelService) -> None:
5149
self._model_manager = ModelManager(model_service.__class__, model_service.service_config)
5250
self._model_manager.model_service = model_service
@@ -56,11 +54,16 @@ def __call__(self) -> ModelManager:
5654

5755

5856
def validate_tracking_id(
59-
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None,
57+
tracking_id: Annotated[
58+
Union[str, None], Query(description="The tracking ID of the requested task")
59+
] = None,
6060
) -> Union[str, None]:
6161
if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None:
6262
raise HTTPException(
6363
status_code=HTTP_400_BAD_REQUEST,
64-
detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256",
64+
detail=(
65+
f"Invalid tracking ID '{tracking_id}',"
66+
" must be an alphanumeric string of length 1 to 256"
67+
),
6568
)
6669
return tracking_id

app/api/routers/authentication.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import logging
2-
import api.globals as cms_globals
2+
33
from fastapi import APIRouter
4+
45
from domain import Tags
6+
7+
import api.globals as cms_globals
8+
59
router = APIRouter()
610
logger = logging.getLogger("cms")
711

0 commit comments

Comments
 (0)