Skip to content

Commit e0cf4ab

Browse files
committed
mypy: analyze types in all possible libraries
Preparation for #646 * Remove `ignore_missing_imports = true` from mypy configuration. * Run mypy in the same environment instead of separate to check types in dependencies like fastapi. * Move mypy dependencies from pre-commit configuration to `setup.cfg`. * Update mypy dependencies there. * Move `rq` from `environment.yml` to `setup.cfg`: conda-forge version: 1.9.0, pypi version : 1.15.1 (two years difference; types were added). * Add libraries with missing types to ignore list in mypy configuration. * Add pydantic mypy plugin. * Allow running mypy without explicit paths. * Update GitHub Actions. * Temporarily add ignore `annotation-unchecked` to make mypy pass. * Fix new mypy issues: * Use https://github.com/hauntsaninja/no_implicit_optional to make `Optional` explicit. * If there is no default, the first `pydantic.Field` argument should be omitted (`None` means that the default argument is `None`). * Refactor `_run_migrations`. There were two different paths: one for normal execution and another one for testing. Simplify arguments and the function code, and introduce a new mock `run_migrations`. * To preserve compatibility, introduce `ChannelWithOptionalName` for the `/channels` patch method. Note that the solution is a bit dirty (I had to use `type: ignore[assignment]`) to minimize the number of models and the diff. * Trivial errors.
1 parent f20db64 commit e0cf4ab

17 files changed

+143
-83
lines changed

.github/workflows/lint.yml

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ jobs:
1717
- name: Add micromamba to GITHUB_PATH
1818
run: echo "${HOME}/micromamba-bin" >> "$GITHUB_PATH"
1919
- run: ln -s "${CONDA_PREFIX}" .venv # Necessary for pyright.
20+
- run: pip install -e .[mypy]
21+
- name: Add mypy to GITHUB_PATH
22+
run: echo "${GITHUB_WORKSPACE}/.venv/bin" >> "$GITHUB_PATH"
2023
- uses: pre-commit/[email protected]
2124
with:
2225
extra_args: --all-files --show-diff-on-failure

.pre-commit-config.yaml

+1-13
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,7 @@ repos:
1919
hooks:
2020
- id: mypy
2121
files: ^quetz/
22-
additional_dependencies:
23-
- sqlalchemy-stubs
24-
- types-click
25-
- types-Jinja2
26-
- types-mock
27-
- types-orjson
28-
- types-pkg-resources
29-
- types-redis
30-
- types-requests
31-
- types-six
32-
- types-toml
33-
- types-ujson
34-
- types-aiofiles
22+
language: system
3523
args: [--show-error-codes]
3624
- repo: https://github.com/Quantco/pre-commit-mirrors-prettier
3725
rev: 2.7.1

environment.yml

-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ dependencies:
4444
- pre-commit
4545
- pytest
4646
- pytest-mock
47-
- rq
4847
- libcflib
4948
- mamba
5049
- conda-content-trust

pyproject.toml

+20-1
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,33 @@ venv = ".venv"
5858
venvPath= "."
5959

6060
[tool.mypy]
61-
ignore_missing_imports = true
61+
packages = [
62+
"quetz"
63+
]
6264
plugins = [
65+
"pydantic.mypy",
6366
"sqlmypy"
6467
]
6568
disable_error_code = [
69+
"annotation-unchecked",
6670
"misc"
6771
]
6872

73+
[[tool.mypy.overrides]]
74+
module = [
75+
"adlfs",
76+
"authlib",
77+
"authlib.*",
78+
"fsspec",
79+
"gcsfs",
80+
"pamela",
81+
"sqlalchemy_utils",
82+
"sqlalchemy_utils.*",
83+
"s3fs",
84+
"xattr"
85+
]
86+
ignore_missing_imports = true
87+
6988
[tool.coverage.run]
7089
omit = [
7190
"quetz/tests/*",

quetz/cli.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,17 @@ def _alembic_config(db_url: str) -> AlembicConfig:
8787

8888

8989
def _run_migrations(
90-
db_url: Optional[str] = None,
91-
alembic_config: Optional[AlembicConfig] = None,
90+
db_url: str,
9291
branch_name: str = "heads",
9392
) -> None:
94-
if db_url:
95-
if db_url.startswith("postgre"):
96-
db_engine = "PostgreSQL"
97-
elif db_url.startswith("sqlite"):
98-
db_engine = "SQLite"
99-
else:
100-
db_engine = db_url.split("/")[0]
101-
logger.info('Running DB migrations on %s', db_engine)
102-
if not alembic_config:
103-
alembic_config = _alembic_config(db_url)
93+
if db_url.startswith("postgre"):
94+
db_engine = "PostgreSQL"
95+
elif db_url.startswith("sqlite"):
96+
db_engine = "SQLite"
97+
else:
98+
db_engine = db_url.split("/")[0]
99+
logger.info('Running DB migrations on %s', db_engine)
100+
alembic_config = _alembic_config(db_url)
104101
command.upgrade(alembic_config, branch_name)
105102

106103

@@ -135,6 +132,7 @@ def _make_migrations(
135132
logger.info('Making DB migrations on %r for %r', db_url, plugin_name)
136133
if not alembic_config and db_url:
137134
alembic_config = _alembic_config(db_url)
135+
assert alembic_config is not None
138136

139137
# find path
140138
if plugin_name == "quetz":
@@ -594,7 +592,7 @@ def start(
594592
uvicorn.run(
595593
"quetz.main:app",
596594
reload=reload,
597-
reload_dirs=(quetz_src,),
595+
reload_dirs=[quetz_src],
598596
port=port,
599597
proxy_headers=proxy_headers,
600598
host=host,

quetz/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ class Config:
230230

231231
_instances: Dict[Optional[str], "Config"] = {}
232232

233-
def __new__(cls, deployment_config: str = None):
233+
def __new__(cls, deployment_config: Optional[str] = None):
234234
if not deployment_config and None in cls._instances:
235235
return cls._instances[None]
236236

@@ -254,7 +254,7 @@ def __getattr__(self, name: str) -> Any:
254254
super().__getattr__(self, name)
255255

256256
@classmethod
257-
def find_file(cls, deployment_config: str = None):
257+
def find_file(cls, deployment_config: Optional[str] = None):
258258
config_file_env = os.getenv(f"{_env_prefix}{_env_config_file}")
259259
deployment_config_files = []
260260
for f in (deployment_config, config_file_env):

quetz/dao.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,8 @@ def create_version(
926926
def get_package_versions(
927927
self,
928928
package,
929-
time_created_ge: datetime = None,
930-
version_match_str: str = None,
929+
time_created_ge: Optional[datetime] = None,
930+
version_match_str: Optional[str] = None,
931931
skip: int = 0,
932932
limit: int = -1,
933933
):

quetz/hooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
@hookspec
14-
def register_router() -> 'fastapi.APIRouter':
14+
def register_router() -> 'fastapi.APIRouter': # type: ignore[empty-body]
1515
"""add extra endpoints to the url tree.
1616
1717
It should return an :py:class:`fastapi.APIRouter` with new endpoints definitions.

quetz/jobs/rest_models.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def parse_job_name(v):
8383
class JobBase(BaseModel):
8484
"""New job spec"""
8585

86-
manifest: str = Field(None, title='Name of the function')
86+
manifest: str = Field(title='Name of the function')
8787

8888
start_at: Optional[datetime] = Field(
8989
None, title="date and time the job should start, if None it starts immediately"
@@ -110,35 +110,35 @@ def validate_job_name(cls, function_name):
110110
class JobCreate(JobBase):
111111
"""Create job spec"""
112112

113-
items_spec: str = Field(..., title='Item selector spec')
113+
items_spec: str = Field(title='Item selector spec')
114114

115115

116116
class JobUpdateModel(BaseModel):
117117
"""Modify job spec items (status and items_spec)"""
118118

119-
items_spec: str = Field(None, title='Item selector spec')
120-
status: JobStatus = Field(None, title='Change status')
119+
items_spec: Optional[str] = Field(None, title='Item selector spec')
120+
status: JobStatus = Field(title='Change status')
121121
force: bool = Field(False, title="force re-running job on all matching packages")
122122

123123

124124
class Job(JobBase):
125-
id: int = Field(None, title='Unique id for job')
126-
owner_id: uuid.UUID = Field(None, title='User id of the owner')
125+
id: int = Field(title='Unique id for job')
126+
owner_id: uuid.UUID = Field(title='User id of the owner')
127127

128-
created: datetime = Field(None, title='Created at')
128+
created: datetime = Field(title='Created at')
129129

130-
status: JobStatus = Field(None, title='Status of the job (running, paused, ...)')
130+
status: JobStatus = Field(title='Status of the job (running, paused, ...)')
131131

132132
items_spec: Optional[str] = Field(None, title='Item selector spec')
133133
model_config = ConfigDict(from_attributes=True)
134134

135135

136136
class Task(BaseModel):
137-
id: int = Field(None, title='Unique id for task')
138-
job_id: int = Field(None, title='ID of the parent job')
139-
package_version: dict = Field(None, title='Package version')
140-
created: datetime = Field(None, title='Created at')
141-
status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)')
137+
id: int = Field(title='Unique id for task')
138+
job_id: int = Field(title='ID of the parent job')
139+
package_version: dict = Field(title='Package version')
140+
created: datetime = Field(title='Created at')
141+
status: TaskStatus = Field(title='Status of the task (running, paused, ...)')
142142

143143
@field_validator("package_version", mode="before")
144144
@classmethod

quetz/main.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def get_users_handler(dao, q, auth, skip, limit):
326326
@api_router.get("/users", response_model=List[rest_models.User], tags=["users"])
327327
def get_users(
328328
dao: Dao = Depends(get_dao),
329-
q: str = None,
329+
q: Optional[str] = None,
330330
auth: authorization.Rules = Depends(get_rules),
331331
):
332332
return get_users_handler(dao, q, auth, 0, -1)
@@ -341,7 +341,7 @@ def get_paginated_users(
341341
dao: Dao = Depends(get_dao),
342342
skip: int = 0,
343343
limit: int = PAGINATION_LIMIT,
344-
q: str = None,
344+
q: Optional[str] = None,
345345
auth: authorization.Rules = Depends(get_rules),
346346
):
347347
return get_users_handler(dao, q, auth, skip, limit)
@@ -521,7 +521,7 @@ def set_user_role(
521521
def get_channels(
522522
public: bool = True,
523523
dao: Dao = Depends(get_dao),
524-
q: str = None,
524+
q: Optional[str] = None,
525525
auth: authorization.Rules = Depends(get_rules),
526526
):
527527
"""List all channels"""
@@ -540,7 +540,7 @@ def get_paginated_channels(
540540
skip: int = 0,
541541
limit: int = PAGINATION_LIMIT,
542542
public: bool = True,
543-
q: str = None,
543+
q: Optional[str] = None,
544544
auth: authorization.Rules = Depends(get_rules),
545545
):
546546
"""List all channels, as a paginated response"""
@@ -780,7 +780,7 @@ def post_channel(
780780
response_model=rest_models.ChannelBase,
781781
)
782782
def patch_channel(
783-
channel_data: rest_models.Channel,
783+
channel_data: rest_models.ChannelWithOptionalName,
784784
dao: Dao = Depends(get_dao),
785785
auth: authorization.Rules = Depends(get_rules),
786786
channel: db_models.Channel = Depends(get_channel_or_fail),
@@ -1054,8 +1054,8 @@ def post_package_member(
10541054
def get_package_versions(
10551055
package: db_models.Package = Depends(get_package_or_fail),
10561056
dao: Dao = Depends(get_dao),
1057-
time_created__ge: datetime.datetime = None,
1058-
version_match_str: str = None,
1057+
time_created__ge: Optional[datetime.datetime] = None,
1058+
version_match_str: Optional[str] = None,
10591059
):
10601060
version_profile_list = dao.get_package_versions(
10611061
package, time_created__ge, version_match_str
@@ -1079,8 +1079,8 @@ def get_paginated_package_versions(
10791079
dao: Dao = Depends(get_dao),
10801080
skip: int = 0,
10811081
limit: int = PAGINATION_LIMIT,
1082-
time_created__ge: datetime.datetime = None,
1083-
version_match_str: str = None,
1082+
time_created__ge: Optional[datetime.datetime] = None,
1083+
version_match_str: Optional[str] = None,
10841084
):
10851085
version_profile_list = dao.get_package_versions(
10861086
package, time_created__ge, version_match_str, skip, limit

quetz/metrics/view.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
from fastapi import FastAPI
34
from prometheus_client import (
45
CONTENT_TYPE_LATEST,
56
REGISTRY,
@@ -9,7 +10,6 @@
910
from prometheus_client.multiprocess import MultiProcessCollector
1011
from starlette.requests import Request
1112
from starlette.responses import Response
12-
from starlette.types import ASGIApp
1313

1414
from .middleware import PrometheusMiddleware
1515

@@ -24,6 +24,6 @@ def metrics(request: Request) -> Response:
2424
return Response(generate_latest(registry), media_type=CONTENT_TYPE_LATEST)
2525

2626

27-
def init(app: ASGIApp):
27+
def init(app: FastAPI):
2828
app.add_middleware(PrometheusMiddleware)
2929
app.add_route("/metricsp", metrics)

quetz/pkgstores.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def file_exists(self, channel: str, destination: str):
116116
def get_filemetadata(self, channel: str, src: str) -> Tuple[int, int, str]:
117117
"""get file metadata: returns (file size, last modified time, etag)"""
118118

119-
@abc.abstractclassmethod
119+
@abc.abstractmethod
120120
def cleanup_temp_files(self, channel: str, dry_run: bool = False):
121121
"""clean up temporary `*.json{HASH}.[bz2|gz]` files from pkgstore"""
122122

quetz/rest_models.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class User(BaseUser):
3737
Profile.model_rebuild()
3838

3939

40-
Role = Field(None, pattern='owner|maintainer|member')
40+
Role = Field(pattern='owner|maintainer|member')
4141

4242

4343
class Member(BaseModel):
@@ -58,7 +58,7 @@ class MirrorMode(str, Enum):
5858

5959

6060
class ChannelBase(BaseModel):
61-
name: str = Field(None, title='The name of the channel', max_length=50)
61+
name: str = Field(title='The name of the channel', max_length=50)
6262
description: Optional[str] = Field(
6363
None, title='The description of the channel', max_length=300
6464
)
@@ -134,7 +134,7 @@ class ChannelMetadata(BaseModel):
134134

135135
class Channel(ChannelBase):
136136
metadata: ChannelMetadata = Field(
137-
default_factory=ChannelMetadata, title="channel metadata", examples={}
137+
default_factory=ChannelMetadata, title="channel metadata", examples=[]
138138
)
139139

140140
actions: Optional[List[ChannelActionEnum]] = Field(
@@ -160,8 +160,14 @@ def check_mirror_params(self) -> "Channel":
160160
return self
161161

162162

163+
class ChannelWithOptionalName(Channel):
164+
name: Optional[str] = Field( # type: ignore[assignment]
165+
None, title='The name of the channel', max_length=50
166+
)
167+
168+
163169
class ChannelMirrorBase(BaseModel):
164-
url: str = Field(None, pattern="^(http|https)://.+")
170+
url: str = Field(pattern="^(http|https)://.+")
165171
api_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
166172
metrics_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
167173
model_config = ConfigDict(from_attributes=True)
@@ -173,7 +179,7 @@ class ChannelMirror(ChannelMirrorBase):
173179

174180
class Package(BaseModel):
175181
name: str = Field(
176-
None, title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
182+
title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
177183
)
178184
summary: Optional[str] = Field(None, title='The summary of the package')
179185
description: Optional[str] = Field(None, title='The description of the package')
@@ -201,18 +207,18 @@ class PackageRole(BaseModel):
201207

202208

203209
class PackageSearch(Package):
204-
channel_name: str = Field(None, title='The channel this package belongs to')
210+
channel_name: str = Field(title='The channel this package belongs to')
205211

206212

207213
class ChannelSearch(BaseModel):
208-
name: str = Field(None, title='The name of the channel', max_length=1500)
214+
name: str = Field(title='The name of the channel', max_length=1500)
209215
description: Optional[str] = Field(None, title='The description of the channel')
210-
private: bool = Field(None, title='The visibility of the channel')
216+
private: bool = Field(title='The visibility of the channel')
211217
model_config = ConfigDict(from_attributes=True)
212218

213219

214220
class PaginatedResponse(BaseModel, Generic[T]):
215-
pagination: Pagination = Field(None, title="Pagination object")
221+
pagination: Pagination = Field(title="Pagination object")
216222
result: List[T] = Field([], title="Result objects")
217223

218224

0 commit comments

Comments
 (0)